diff options
Diffstat (limited to 'internal/api/client')
| -rw-r--r-- | internal/api/client/auth/auth.go | 13 | ||||
| -rw-r--r-- | internal/api/client/auth/auth_test.go | 94 | ||||
| -rw-r--r-- | internal/api/client/auth/authorize.go | 73 | ||||
| -rw-r--r-- | internal/api/client/auth/authorize_test.go | 113 | 
4 files changed, 279 insertions, 14 deletions
diff --git a/internal/api/client/auth/auth.go b/internal/api/client/auth/auth.go index 67643244b..717d997a3 100644 --- a/internal/api/client/auth/auth.go +++ b/internal/api/client/auth/auth.go @@ -32,10 +32,23 @@ import (  const (  	// AuthSignInPath is the API path for users to sign in through  	AuthSignInPath = "/auth/sign_in" + +	// CheckYourEmailPath users land here after registering a new account, instructs them to confirm thier email +	CheckYourEmailPath = "/check_your_email" + +	// WaitForApprovalPath users land here after confirming thier email but before an admin approves thier account +	// (if such is required) +	WaitForApprovalPath = "/wait_for_approval" + +	// AccountDisabledPath users land here when thier account is suspended by an admin +	AccountDisabledPath = "/account_disabled" +  	// OauthTokenPath is the API path to use for granting token requests to users with valid credentials  	OauthTokenPath = "/oauth/token" +  	// OauthAuthorizePath is the API path for authorization requests (eg., authorize this app to act on my behalf as a user)  	OauthAuthorizePath = "/oauth/authorize" +  	// CallbackPath is the API path for receiving callback tokens from external OIDC providers  	CallbackPath = oidc.CallbackPath diff --git a/internal/api/client/auth/auth_test.go b/internal/api/client/auth/auth_test.go index a0ee8892d..fdf1b6baf 100644 --- a/internal/api/client/auth/auth_test.go +++ b/internal/api/client/auth/auth_test.go @@ -18,4 +18,96 @@  package auth_test -// TODO +import ( +	"context" +	"fmt" +	"net/http/httptest" + +	"github.com/gin-contrib/sessions" +	"github.com/gin-contrib/sessions/memstore" +	"github.com/gin-gonic/gin" +	"github.com/spf13/viper" +	"github.com/stretchr/testify/suite" +	"github.com/superseriousbusiness/gotosocial/internal/api/client/auth" +	"github.com/superseriousbusiness/gotosocial/internal/config" +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/oauth" +	"github.com/superseriousbusiness/gotosocial/internal/oidc" +	"github.com/superseriousbusiness/gotosocial/internal/router" +	"github.com/superseriousbusiness/gotosocial/testrig" +) + +type AuthStandardTestSuite struct { +	suite.Suite +	db          db.DB +	idp         oidc.IDP +	oauthServer oauth.Server + +	// standard suite models +	testTokens       map[string]*gtsmodel.Token +	testClients      map[string]*gtsmodel.Client +	testApplications map[string]*gtsmodel.Application +	testUsers        map[string]*gtsmodel.User +	testAccounts     map[string]*gtsmodel.Account + +	// module being tested +	authModule *auth.Module +} + +const ( +	sessionUserID   = "userid" +	sessionClientID = "client_id" +) + +func (suite *AuthStandardTestSuite) SetupSuite() { +	suite.testTokens = testrig.NewTestTokens() +	suite.testClients = testrig.NewTestClients() +	suite.testApplications = testrig.NewTestApplications() +	suite.testUsers = testrig.NewTestUsers() +	suite.testAccounts = testrig.NewTestAccounts() +} + +func (suite *AuthStandardTestSuite) SetupTest() { +	testrig.InitTestConfig() +	suite.db = testrig.NewTestDB() +	testrig.InitTestLog() + +	suite.oauthServer = testrig.NewTestOauthServer(suite.db) +	var err error +	suite.idp, err = oidc.NewIDP(context.Background()) +	if err != nil { +		panic(err) +	} +	suite.authModule = auth.New(suite.db, suite.oauthServer, suite.idp).(*auth.Module) +	testrig.StandardDBSetup(suite.db, nil) +} + +func (suite *AuthStandardTestSuite) TearDownTest() { +	testrig.StandardDBTeardown(suite.db) +} + +func (suite *AuthStandardTestSuite) newContext(requestMethod string, requestPath string) (*gin.Context, *httptest.ResponseRecorder) { +	// create the recorder and gin test context +	recorder := httptest.NewRecorder() +	ctx, engine := gin.CreateTestContext(recorder) + +	// load templates into the engine +	testrig.ConfigureTemplatesWithGin(engine) + +	// create the request +	protocol := viper.GetString(config.Keys.Protocol) +	host := viper.GetString(config.Keys.Host) +	baseURI := fmt.Sprintf("%s://%s", protocol, host) +	requestURI := fmt.Sprintf("%s/%s", baseURI, requestPath) +	ctx.Request = httptest.NewRequest(requestMethod, requestURI, nil) // the endpoint we're hitting +	ctx.Request.Header.Set("accept", "text/html") + +	// trigger the session middleware on the context +	store := memstore.NewStore(make([]byte, 32), make([]byte, 32)) +	store.Options(router.SessionOptions()) +	sessionMiddleware := sessions.Sessions("gotosocial-localhost", store) +	sessionMiddleware(ctx) + +	return ctx, recorder +} diff --git a/internal/api/client/auth/authorize.go b/internal/api/client/auth/authorize.go index 99f3cca68..387b83c1e 100644 --- a/internal/api/client/auth/authorize.go +++ b/internal/api/client/auth/authorize.go @@ -44,7 +44,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {  	s := sessions.Default(c)  	if _, err := api.NegotiateAccept(c, api.HTMLAcceptHeaders...); err != nil { -		c.JSON(http.StatusNotAcceptable, gin.H{"error": err.Error()}) +		c.HTML(http.StatusNotAcceptable, "error.tmpl", gin.H{"error": err.Error()})  		return  	} @@ -57,7 +57,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {  		if err := c.Bind(form); err != nil {  			l.Debugf("invalid auth form: %s", err)  			m.clearSession(s) -			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) +			c.HTML(http.StatusBadRequest, "error.tmpl", gin.H{"error": err.Error()})  			return  		}  		l.Debugf("parsed auth form: %+v", form) @@ -65,7 +65,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {  		if err := extractAuthForm(s, form); err != nil {  			l.Debugf(fmt.Sprintf("error parsing form at /oauth/authorize: %s", err))  			m.clearSession(s) -			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) +			c.HTML(http.StatusBadRequest, "error.tmpl", gin.H{"error": err.Error()})  			return  		}  		c.Redirect(http.StatusSeeOther, AuthSignInPath) @@ -75,28 +75,33 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {  	// We can use the client_id on the session to retrieve info about the app associated with the client_id  	clientID, ok := s.Get(sessionClientID).(string)  	if !ok || clientID == "" { -		c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"}) +		c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": "no client_id found in session"})  		return  	}  	app := >smodel.Application{}  	if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: clientID}}, app); err != nil {  		m.clearSession(s) -		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)}) +		c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{ +			"error": fmt.Sprintf("no application found for client id %s", clientID), +		})  		return  	} -	// we can also use the userid of the user to fetch their username from the db to greet them nicely <3 +	// redirect the user if they have not confirmed their email yet, thier account has not been approved yet, +	// or thier account has been disabled.  	user := >smodel.User{}  	if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil {  		m.clearSession(s) -		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) +		c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": err.Error()})  		return  	} -  	acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)  	if err != nil {  		m.clearSession(s) -		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) +		c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": err.Error()}) +		return +	} +	if !ensureUserIsAuthorizedOrRedirect(c, user, acct) {  		return  	} @@ -104,13 +109,13 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {  	redirect, ok := s.Get(sessionRedirectURI).(string)  	if !ok || redirect == "" {  		m.clearSession(s) -		c.JSON(http.StatusInternalServerError, gin.H{"error": "no redirect_uri found in session"}) +		c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": "no redirect_uri found in session"})  		return  	}  	scope, ok := s.Get(sessionScope).(string)  	if !ok || scope == "" {  		m.clearSession(s) -		c.JSON(http.StatusInternalServerError, gin.H{"error": "no scope found in session"}) +		c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": "no scope found in session"})  		return  	} @@ -170,10 +175,28 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {  		errs = append(errs, "session missing userid")  	} +	// redirect the user if they have not confirmed their email yet, thier account has not been approved yet, +	// or thier account has been disabled. +	user := >smodel.User{} +	if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil { +		m.clearSession(s) +		c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": err.Error()}) +		return +	} +	acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) +	if err != nil { +		m.clearSession(s) +		c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": err.Error()}) +		return +	} +	if !ensureUserIsAuthorizedOrRedirect(c, user, acct) { +		return +	} +  	m.clearSession(s)  	if len(errs) != 0 { -		c.JSON(http.StatusBadRequest, gin.H{"error": strings.Join(errs, ": ")}) +		c.HTML(http.StatusBadRequest, "error.tmpl", gin.H{"error": strings.Join(errs, ": ")})  		return  	} @@ -190,7 +213,7 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {  	// and proceed with authorization using the oauth2 library  	if err := m.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil { -		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) +		c.HTML(http.StatusBadRequest, "error.tmpl", gin.H{"error": err.Error()})  	}  } @@ -216,3 +239,27 @@ func extractAuthForm(s sessions.Session, form *model.OAuthAuthorize) error {  	s.Set(sessionState, uuid.NewString())  	return s.Save()  } + +func ensureUserIsAuthorizedOrRedirect(ctx *gin.Context, user *gtsmodel.User, account *gtsmodel.Account) bool { +	if user.ConfirmedAt.IsZero() { +		ctx.Redirect(http.StatusSeeOther, CheckYourEmailPath) +		return false +	} + +	if !user.Approved { +		ctx.Redirect(http.StatusSeeOther, WaitForApprovalPath) +		return false +	} + +	if user.Disabled { +		ctx.Redirect(http.StatusSeeOther, AccountDisabledPath) +		return false +	} + +	if !account.SuspendedAt.IsZero() { +		ctx.Redirect(http.StatusSeeOther, AccountDisabledPath) +		return false +	} + +	return true +} diff --git a/internal/api/client/auth/authorize_test.go b/internal/api/client/auth/authorize_test.go new file mode 100644 index 000000000..8f16702da --- /dev/null +++ b/internal/api/client/auth/authorize_test.go @@ -0,0 +1,113 @@ +package auth_test + +import ( +	"context" +	"fmt" +	"net/http" +	"testing" +	"time" + +	"codeberg.org/gruf/go-errors" +	"github.com/gin-contrib/sessions" +	"github.com/stretchr/testify/suite" +	"github.com/superseriousbusiness/gotosocial/internal/api/client/auth" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type AuthAuthorizeTestSuite struct { +	AuthStandardTestSuite +} + +type authorizeHandlerTestCase struct { +	description            string +	mutateUserAccount      func(*gtsmodel.User, *gtsmodel.Account) +	expectedStatusCode     int +	expectedLocationHeader string +} + +func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() { + +	var tests = []authorizeHandlerTestCase{ +		{ +			description: "user has their email unconfirmed", +			mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) { +				// nothing to do, weed_lord420 already has their email unconfirmed +			}, +			expectedStatusCode:     http.StatusSeeOther, +			expectedLocationHeader: auth.CheckYourEmailPath, +		}, +		{ +			description: "user has their email confirmed but is not approved", +			mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) { +				user.ConfirmedAt = time.Now() +				user.Email = user.UnconfirmedEmail +			}, +			expectedStatusCode:     http.StatusSeeOther, +			expectedLocationHeader: auth.WaitForApprovalPath, +		}, +		{ +			description: "user has their email confirmed and is approved, but User entity has been disabled", +			mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) { +				user.ConfirmedAt = time.Now() +				user.Email = user.UnconfirmedEmail +				user.Approved = true +				user.Disabled = true +			}, +			expectedStatusCode:     http.StatusSeeOther, +			expectedLocationHeader: auth.AccountDisabledPath, +		}, +		{ +			description: "user has their email confirmed and is approved, but Account entity has been suspended", +			mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) { +				user.ConfirmedAt = time.Now() +				user.Email = user.UnconfirmedEmail +				user.Approved = true +				user.Disabled = false +				account.SuspendedAt = time.Now() +			}, +			expectedStatusCode:     http.StatusSeeOther, +			expectedLocationHeader: auth.AccountDisabledPath, +		}, +	} + +	doTest := func(testCase authorizeHandlerTestCase) { +		ctx, recorder := suite.newContext(http.MethodGet, auth.OauthAuthorizePath) + +		user := suite.testUsers["unconfirmed_account"] +		account := suite.testAccounts["unconfirmed_account"] + +		testSession := sessions.Default(ctx) +		testSession.Set(sessionUserID, user.ID) +		testSession.Set(sessionClientID, suite.testApplications["application_1"].ClientID) +		if err := testSession.Save(); err != nil { +			panic(errors.WrapMsgf(err, "failed on case: %s", testCase.description)) +		} + +		testCase.mutateUserAccount(user, account) + +		testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, user.Disabled, account.SuspendedAt) + +		user.UpdatedAt = time.Now() +		err := suite.db.UpdateByPrimaryKey(context.Background(), user) +		suite.NoError(err) +		_, err = suite.db.UpdateAccount(context.Background(), account) +		suite.NoError(err) + +		// call the handler +		suite.authModule.AuthorizeGETHandler(ctx) + +		// 1. we should have a redirect +		suite.Equal(testCase.expectedStatusCode, recorder.Code, fmt.Sprintf("failed on case: %s", testCase.description)) + +		// 2. we should have a redirect to the check your email path, as this user has not confirmed their email yet. +		suite.Equal(testCase.expectedLocationHeader, recorder.Header().Get("Location"), fmt.Sprintf("failed on case: %s", testCase.description)) +	} + +	for _, testCase := range tests { +		doTest(testCase) +	} +} + +func TestAccountUpdateTestSuite(t *testing.T) { +	suite.Run(t, new(AuthAuthorizeTestSuite)) +}  | 
