diff options
author | 2023-01-02 13:10:50 +0100 | |
---|---|---|
committer | 2023-01-02 12:10:50 +0000 | |
commit | 941893a774c83802afdc4cc76e1d30c59b6c5585 (patch) | |
tree | 6e7296146dedfeac8e83655157270f41e190724b /internal/api/auth | |
parent | [chore]: Bump github.com/abema/go-mp4 from 0.8.0 to 0.9.0 (#1287) (diff) | |
download | gotosocial-941893a774c83802afdc4cc76e1d30c59b6c5585.tar.xz |
[chore] The Big Middleware and API Refactor (tm) (#1250)
* interim commit: start refactoring middlewares into package under router
* another interim commit, this is becoming a big job
* another fucking massive interim commit
* refactor bookmarks to new style
* ambassador, wiz zeze commits you are spoiling uz
* she compiles, we're getting there
* we're just normal men; we're just innocent men
* apiutil
* whoopsie
* i'm glad noone reads commit msgs haha :blob_sweat:
* use that weirdo go-bytesize library for maxMultipartMemory
* fix media module paths
Diffstat (limited to 'internal/api/auth')
-rw-r--r-- | internal/api/auth/auth.go | 117 | ||||
-rw-r--r-- | internal/api/auth/auth_test.go | 129 | ||||
-rw-r--r-- | internal/api/auth/authorize.go | 335 | ||||
-rw-r--r-- | internal/api/auth/authorize_test.go | 118 | ||||
-rw-r--r-- | internal/api/auth/callback.go | 317 | ||||
-rw-r--r-- | internal/api/auth/oob.go | 113 | ||||
-rw-r--r-- | internal/api/auth/signin.go | 145 | ||||
-rw-r--r-- | internal/api/auth/token.go | 115 | ||||
-rw-r--r-- | internal/api/auth/token_test.go | 215 |
9 files changed, 1604 insertions, 0 deletions
diff --git a/internal/api/auth/auth.go b/internal/api/auth/auth.go new file mode 100644 index 000000000..7ce992466 --- /dev/null +++ b/internal/api/auth/auth.go @@ -0,0 +1,117 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package auth + +import ( + "net/http" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/oidc" + "github.com/superseriousbusiness/gotosocial/internal/processing" +) + +const ( + /* + paths prefixed with 'auth' + */ + + // AuthSignInPath is the API path for users to sign in through + AuthSignInPath = "/sign_in" + // AuthCheckYourEmailPath users land here after registering a new account, instructs them to confirm their email + AuthCheckYourEmailPath = "/check_your_email" + // AuthWaitForApprovalPath users land here after confirming their email + // but before an admin approves their account (if such is required) + AuthWaitForApprovalPath = "/wait_for_approval" + // AuthAccountDisabledPath users land here when their account is suspended by an admin + AuthAccountDisabledPath = "/account_disabled" + // AuthCallbackPath is the API path for receiving callback tokens from external OIDC providers + AuthCallbackPath = "/callback" + + /* + paths prefixed with 'oauth' + */ + + // OauthTokenPath is the API path to use for granting token requests to users with valid credentials + OauthTokenPath = "/token" // #nosec G101 else we get a hardcoded credentials warning + // OauthAuthorizePath is the API path for authorization requests (eg., authorize this app to act on my behalf as a user) + OauthAuthorizePath = "/authorize" + // OauthFinalizePath is the API path for completing user registration with additional user details + OauthFinalizePath = "/finalize" + // OauthOobTokenPath is the path for serving an html representation of an oob token page. + OauthOobTokenPath = "/oob" // #nosec G101 else we get a hardcoded credentials warning + + /* + params / session keys + */ + + callbackStateParam = "state" + callbackCodeParam = "code" + sessionUserID = "userid" + sessionClientID = "client_id" + sessionRedirectURI = "redirect_uri" + sessionForceLogin = "force_login" + sessionResponseType = "response_type" + sessionScope = "scope" + sessionInternalState = "internal_state" + sessionClientState = "client_state" + sessionClaims = "claims" + sessionAppID = "app_id" +) + +type Module struct { + db db.DB + processor processing.Processor + idp oidc.IDP +} + +// New returns an Auth module which provides both 'oauth' and 'auth' endpoints. +// +// It is safe to pass a nil idp if oidc is disabled. +func New(db db.DB, processor processing.Processor, idp oidc.IDP) *Module { + return &Module{ + db: db, + processor: processor, + idp: idp, + } +} + +// RouteAuth routes all paths that should have an 'auth' prefix +func (m *Module) RouteAuth(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) { + attachHandler(http.MethodGet, AuthSignInPath, m.SignInGETHandler) + attachHandler(http.MethodPost, AuthSignInPath, m.SignInPOSTHandler) + attachHandler(http.MethodGet, AuthCallbackPath, m.CallbackGETHandler) +} + +// RouteOauth routes all paths that should have an 'oauth' prefix +func (m *Module) RouteOauth(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) { + attachHandler(http.MethodPost, OauthTokenPath, m.TokenPOSTHandler) + attachHandler(http.MethodGet, OauthAuthorizePath, m.AuthorizeGETHandler) + attachHandler(http.MethodPost, OauthAuthorizePath, m.AuthorizePOSTHandler) + attachHandler(http.MethodPost, OauthFinalizePath, m.FinalizePOSTHandler) + attachHandler(http.MethodGet, OauthOobTokenPath, m.OobHandler) +} + +func (m *Module) clearSession(s sessions.Session) { + s.Clear() + if err := s.Save(); err != nil { + panic(err) + } +} diff --git a/internal/api/auth/auth_test.go b/internal/api/auth/auth_test.go new file mode 100644 index 000000000..cb92850d0 --- /dev/null +++ b/internal/api/auth/auth_test.go @@ -0,0 +1,129 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package auth_test + +import ( + "bytes" + "fmt" + "net/http/httptest" + + "github.com/gin-contrib/sessions" + "github.com/gin-contrib/sessions/memstore" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/api/auth" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/email" + "github.com/superseriousbusiness/gotosocial/internal/federation" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/media" + "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/middleware" + "github.com/superseriousbusiness/gotosocial/internal/oidc" + "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type AuthStandardTestSuite struct { + suite.Suite + db db.DB + storage *storage.Driver + mediaManager media.Manager + federator federation.Federator + processor processing.Processor + emailSender email.Sender + idp oidc.IDP + + // standard suite models + testTokens map[string]*gtsmodel.Token + testClients map[string]*gtsmodel.Client + testApplications map[string]*gtsmodel.Application + testUsers map[string]*gtsmodel.User + testAccounts map[string]*gtsmodel.Account + + // module being tested + authModule *auth.Module +} + +const ( + sessionUserID = "userid" + sessionClientID = "client_id" +) + +func (suite *AuthStandardTestSuite) SetupSuite() { + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() +} + +func (suite *AuthStandardTestSuite) SetupTest() { + testrig.InitTestConfig() + testrig.InitTestLog() + + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + + suite.db = testrig.NewTestDB() + suite.storage = testrig.NewInMemoryStorage() + suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) + suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil) + suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.authModule = auth.New(suite.db, suite.processor, suite.idp) + testrig.StandardDBSetup(suite.db, suite.testAccounts) +} + +func (suite *AuthStandardTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) +} + +func (suite *AuthStandardTestSuite) newContext(requestMethod string, requestPath string, requestBody []byte, bodyContentType string) (*gin.Context, *httptest.ResponseRecorder) { + // create the recorder and gin test context + recorder := httptest.NewRecorder() + ctx, engine := testrig.CreateGinTestContext(recorder, nil) + + // load templates into the engine + testrig.ConfigureTemplatesWithGin(engine, "../../../web/template") + + // create the request + protocol := config.GetProtocol() + host := config.GetHost() + baseURI := fmt.Sprintf("%s://%s", protocol, host) + requestURI := fmt.Sprintf("%s/%s", baseURI, requestPath) + + ctx.Request = httptest.NewRequest(requestMethod, requestURI, bytes.NewReader(requestBody)) // the endpoint we're hitting + ctx.Request.Header.Set("accept", "text/html") + + if bodyContentType != "" { + ctx.Request.Header.Set("Content-Type", bodyContentType) + } + + // trigger the session middleware on the context + store := memstore.NewStore(make([]byte, 32), make([]byte, 32)) + store.Options(middleware.SessionOptions()) + sessionMiddleware := sessions.Sessions("gotosocial-localhost", store) + sessionMiddleware(ctx) + + return ctx, recorder +} diff --git a/internal/api/auth/authorize.go b/internal/api/auth/authorize.go new file mode 100644 index 000000000..e504f6be2 --- /dev/null +++ b/internal/api/auth/authorize.go @@ -0,0 +1,335 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package auth + +import ( + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// AuthorizeGETHandler should be served as GET at https://example.org/oauth/authorize +// The idea here is to present an oauth authorize page to the user, with a button +// that they have to click to accept. +func (m *Module) AuthorizeGETHandler(c *gin.Context) { + s := sessions.Default(c) + + if _, err := apiutil.NegotiateAccept(c, apiutil.HTMLAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) + return + } + + // UserID will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow + // If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page. + userID, ok := s.Get(sessionUserID).(string) + if !ok || userID == "" { + form := &apimodel.OAuthAuthorize{} + if err := c.ShouldBind(form); err != nil { + m.clearSession(s) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) + return + } + + if errWithCode := saveAuthFormToSession(s, form); errWithCode != nil { + m.clearSession(s) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + + c.Redirect(http.StatusSeeOther, "/auth"+AuthSignInPath) + return + } + + // use session information to validate app, user, and account for this request + clientID, ok := s.Get(sessionClientID).(string) + if !ok || clientID == "" { + m.clearSession(s) + err := fmt.Errorf("key %s was not found in session", sessionClientID) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) + return + } + + app := >smodel.Application{} + if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: clientID}}, app); err != nil { + m.clearSession(s) + safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID) + var errWithCode gtserror.WithCode + if err == db.ErrNoEntries { + errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice) + } else { + errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) + } + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + + user, err := m.db.GetUserByID(c.Request.Context(), userID) + if err != nil { + m.clearSession(s) + safe := fmt.Sprintf("user with id %s could not be retrieved", userID) + var errWithCode gtserror.WithCode + if err == db.ErrNoEntries { + errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice) + } else { + errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) + } + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + + acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) + if err != nil { + m.clearSession(s) + safe := fmt.Sprintf("account with id %s could not be retrieved", user.AccountID) + var errWithCode gtserror.WithCode + if err == db.ErrNoEntries { + errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice) + } else { + errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) + } + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + + if ensureUserIsAuthorizedOrRedirect(c, user, acct) { + return + } + + // Finally we should also get the redirect and scope of this particular request, as stored in the session. + redirect, ok := s.Get(sessionRedirectURI).(string) + if !ok || redirect == "" { + m.clearSession(s) + err := fmt.Errorf("key %s was not found in session", sessionRedirectURI) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) + return + } + + scope, ok := s.Get(sessionScope).(string) + if !ok || scope == "" { + m.clearSession(s) + err := fmt.Errorf("key %s was not found in session", sessionScope) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) + return + } + + instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost()) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + + // the authorize template will display a form to the user where they can get some information + // about the app that's trying to authorize, and the scope of the request. + // They can then approve it if it looks OK to them, which will POST to the AuthorizePOSTHandler + c.HTML(http.StatusOK, "authorize.tmpl", gin.H{ + "appname": app.Name, + "appwebsite": app.Website, + "redirect": redirect, + "scope": scope, + "user": acct.Username, + "instance": instance, + }) +} + +// AuthorizePOSTHandler should be served as POST at https://example.org/oauth/authorize +// At this point we assume that the user has A) logged in and B) accepted that the app should act for them, +// so we should proceed with the authentication flow and generate an oauth token for them if we can. +func (m *Module) AuthorizePOSTHandler(c *gin.Context) { + s := sessions.Default(c) + + // We need to retrieve the original form submitted to the authorizeGEThandler, and + // recreate it on the request so that it can be used further by the oauth2 library. + errs := []string{} + + forceLogin, ok := s.Get(sessionForceLogin).(string) + if !ok { + forceLogin = "false" + } + + responseType, ok := s.Get(sessionResponseType).(string) + if !ok || responseType == "" { + errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionResponseType)) + } + + clientID, ok := s.Get(sessionClientID).(string) + if !ok || clientID == "" { + errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionClientID)) + } + + redirectURI, ok := s.Get(sessionRedirectURI).(string) + if !ok || redirectURI == "" { + errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionRedirectURI)) + } + + scope, ok := s.Get(sessionScope).(string) + if !ok { + errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionScope)) + } + + var clientState string + if s, ok := s.Get(sessionClientState).(string); ok { + clientState = s + } + + userID, ok := s.Get(sessionUserID).(string) + if !ok { + errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionUserID)) + } + + if len(errs) != 0 { + errs = append(errs, oauth.HelpfulAdvice) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(errors.New("one or more missing keys on session during AuthorizePOSTHandler"), errs...), m.processor.InstanceGet) + return + } + + user, err := m.db.GetUserByID(c.Request.Context(), userID) + if err != nil { + m.clearSession(s) + safe := fmt.Sprintf("user with id %s could not be retrieved", userID) + var errWithCode gtserror.WithCode + if err == db.ErrNoEntries { + errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice) + } else { + errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) + } + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + + acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) + if err != nil { + m.clearSession(s) + safe := fmt.Sprintf("account with id %s could not be retrieved", user.AccountID) + var errWithCode gtserror.WithCode + if err == db.ErrNoEntries { + errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice) + } else { + errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) + } + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + + if ensureUserIsAuthorizedOrRedirect(c, user, acct) { + return + } + + if redirectURI != oauth.OOBURI { + // we're done with the session now, so just clear it out + m.clearSession(s) + } + + // we have to set the values on the request form + // so that they're picked up by the oauth server + c.Request.Form = url.Values{ + sessionForceLogin: {forceLogin}, + sessionResponseType: {responseType}, + sessionClientID: {clientID}, + sessionRedirectURI: {redirectURI}, + sessionScope: {scope}, + sessionUserID: {userID}, + } + + if clientState != "" { + c.Request.Form.Set("state", clientState) + } + + if errWithCode := m.processor.OAuthHandleAuthorizeRequest(c.Writer, c.Request); errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + } +} + +// saveAuthFormToSession checks the given OAuthAuthorize form, +// and stores the values in the form into the session. +func saveAuthFormToSession(s sessions.Session, form *apimodel.OAuthAuthorize) gtserror.WithCode { + if form == nil { + err := errors.New("OAuthAuthorize form was nil") + return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice) + } + + if form.ResponseType == "" { + err := errors.New("field response_type was not set on OAuthAuthorize form") + return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice) + } + + if form.ClientID == "" { + err := errors.New("field client_id was not set on OAuthAuthorize form") + return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice) + } + + if form.RedirectURI == "" { + err := errors.New("field redirect_uri was not set on OAuthAuthorize form") + return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice) + } + + // set default scope to read + if form.Scope == "" { + form.Scope = "read" + } + + // save these values from the form so we can use them elsewhere in the session + s.Set(sessionForceLogin, form.ForceLogin) + s.Set(sessionResponseType, form.ResponseType) + s.Set(sessionClientID, form.ClientID) + s.Set(sessionRedirectURI, form.RedirectURI) + s.Set(sessionScope, form.Scope) + s.Set(sessionInternalState, uuid.NewString()) + s.Set(sessionClientState, form.State) + + if err := s.Save(); err != nil { + err := fmt.Errorf("error saving form values onto session: %s", err) + return gtserror.NewErrorInternalError(err, oauth.HelpfulAdvice) + } + + return nil +} + +func ensureUserIsAuthorizedOrRedirect(ctx *gin.Context, user *gtsmodel.User, account *gtsmodel.Account) (redirected bool) { + if user.ConfirmedAt.IsZero() { + ctx.Redirect(http.StatusSeeOther, "/auth"+AuthCheckYourEmailPath) + redirected = true + return + } + + if !*user.Approved { + ctx.Redirect(http.StatusSeeOther, "/auth"+AuthWaitForApprovalPath) + redirected = true + return + } + + if *user.Disabled || !account.SuspendedAt.IsZero() { + ctx.Redirect(http.StatusSeeOther, "/auth"+AuthAccountDisabledPath) + redirected = true + return + } + + return +} diff --git a/internal/api/auth/authorize_test.go b/internal/api/auth/authorize_test.go new file mode 100644 index 000000000..ff65d041b --- /dev/null +++ b/internal/api/auth/authorize_test.go @@ -0,0 +1,118 @@ +package auth_test + +import ( + "context" + "fmt" + "net/http" + "testing" + "time" + + "github.com/gin-contrib/sessions" + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/api/auth" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type AuthAuthorizeTestSuite struct { + AuthStandardTestSuite +} + +type authorizeHandlerTestCase struct { + description string + mutateUserAccount func(*gtsmodel.User, *gtsmodel.Account) []string + expectedStatusCode int + expectedLocationHeader string +} + +func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() { + tests := []authorizeHandlerTestCase{ + { + description: "user has their email unconfirmed", + mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) []string { + user.ConfirmedAt = time.Time{} + return []string{"confirmed_at"} + }, + expectedStatusCode: http.StatusSeeOther, + expectedLocationHeader: "/auth" + auth.AuthCheckYourEmailPath, + }, + { + description: "user has their email confirmed but is not approved", + mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) []string { + user.ConfirmedAt = time.Now() + user.Email = user.UnconfirmedEmail + return []string{"confirmed_at", "email"} + }, + expectedStatusCode: http.StatusSeeOther, + expectedLocationHeader: "/auth" + auth.AuthWaitForApprovalPath, + }, + { + description: "user has their email confirmed and is approved, but User entity has been disabled", + mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) []string { + user.ConfirmedAt = time.Now() + user.Email = user.UnconfirmedEmail + user.Approved = testrig.TrueBool() + user.Disabled = testrig.TrueBool() + return []string{"confirmed_at", "email", "approved", "disabled"} + }, + expectedStatusCode: http.StatusSeeOther, + expectedLocationHeader: "/auth" + auth.AuthAccountDisabledPath, + }, + { + description: "user has their email confirmed and is approved, but Account entity has been suspended", + mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) []string { + user.ConfirmedAt = time.Now() + user.Email = user.UnconfirmedEmail + user.Approved = testrig.TrueBool() + user.Disabled = testrig.FalseBool() + account.SuspendedAt = time.Now() + return []string{"confirmed_at", "email", "approved", "disabled"} + }, + expectedStatusCode: http.StatusSeeOther, + expectedLocationHeader: "/auth" + auth.AuthAccountDisabledPath, + }, + } + + doTest := func(testCase authorizeHandlerTestCase) { + ctx, recorder := suite.newContext(http.MethodGet, auth.OauthAuthorizePath, nil, "") + + user := >smodel.User{} + account := >smodel.Account{} + + *user = *suite.testUsers["unconfirmed_account"] + *account = *suite.testAccounts["unconfirmed_account"] + + testSession := sessions.Default(ctx) + testSession.Set(sessionUserID, user.ID) + testSession.Set(sessionClientID, suite.testApplications["application_1"].ClientID) + if err := testSession.Save(); err != nil { + panic(fmt.Errorf("failed on case %s: %w", testCase.description, err)) + } + + columns := testCase.mutateUserAccount(user, account) + + testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, *user.Disabled, account.SuspendedAt) + + err := suite.db.UpdateUser(context.Background(), user, columns...) + suite.NoError(err) + err = suite.db.UpdateAccount(context.Background(), account) + suite.NoError(err) + + // call the handler + suite.authModule.AuthorizeGETHandler(ctx) + + // 1. we should have a redirect + suite.Equal(testCase.expectedStatusCode, recorder.Code, fmt.Sprintf("failed on case: %s", testCase.description)) + + // 2. we should have a redirect to the check your email path, as this user has not confirmed their email yet. + suite.Equal(testCase.expectedLocationHeader, recorder.Header().Get("Location"), fmt.Sprintf("failed on case: %s", testCase.description)) + } + + for _, testCase := range tests { + doTest(testCase) + } +} + +func TestAccountUpdateTestSuite(t *testing.T) { + suite.Run(t, new(AuthAuthorizeTestSuite)) +} diff --git a/internal/api/auth/callback.go b/internal/api/auth/callback.go new file mode 100644 index 000000000..d344b5d5f --- /dev/null +++ b/internal/api/auth/callback.go @@ -0,0 +1,317 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package auth + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "strings" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/oidc" + "github.com/superseriousbusiness/gotosocial/internal/validate" +) + +// extraInfo wraps a form-submitted username and transmitted name +type extraInfo struct { + Username string `form:"username"` + Name string `form:"name"` // note that this is only used for re-rendering the page in case of an error +} + +// CallbackGETHandler parses a token from an external auth provider. +func (m *Module) CallbackGETHandler(c *gin.Context) { + if !config.GetOIDCEnabled() { + err := errors.New("oidc is not enabled for this server") + apiutil.ErrorHandler(c, gtserror.NewErrorNotFound(err, err.Error()), m.processor.InstanceGet) + return + } + + s := sessions.Default(c) + + // check the query vs session state parameter to mitigate csrf + // https://auth0.com/docs/secure/attack-protection/state-parameters + + returnedInternalState := c.Query(callbackStateParam) + if returnedInternalState == "" { + m.clearSession(s) + err := fmt.Errorf("%s parameter not found on callback query", callbackStateParam) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) + return + } + + savedInternalStateI := s.Get(sessionInternalState) + savedInternalState, ok := savedInternalStateI.(string) + if !ok { + m.clearSession(s) + err := fmt.Errorf("key %s was not found in session", sessionInternalState) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) + return + } + + if returnedInternalState != savedInternalState { + m.clearSession(s) + err := errors.New("mismatch between callback state and saved state") + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) + return + } + + // retrieve stored claims using code + code := c.Query(callbackCodeParam) + if code == "" { + m.clearSession(s) + err := fmt.Errorf("%s parameter not found on callback query", callbackCodeParam) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) + return + } + + claims, errWithCode := m.idp.HandleCallback(c.Request.Context(), code) + if errWithCode != nil { + m.clearSession(s) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + + // We can use the client_id on the session to retrieve + // info about the app associated with the client_id + clientID, ok := s.Get(sessionClientID).(string) + if !ok || clientID == "" { + m.clearSession(s) + err := fmt.Errorf("key %s was not found in session", sessionClientID) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) + return + } + + app := >smodel.Application{} + if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: clientID}}, app); err != nil { + m.clearSession(s) + safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID) + var errWithCode gtserror.WithCode + if err == db.ErrNoEntries { + errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice) + } else { + errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) + } + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + + user, errWithCode := m.fetchUserForClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID) + if errWithCode != nil { + m.clearSession(s) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + if user == nil { + // no user exists yet - let's ask them for their preferred username + instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost()) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + + // store the claims in the session - that way we know the user is authenticated when processing the form later + s.Set(sessionClaims, claims) + s.Set(sessionAppID, app.ID) + if err := s.Save(); err != nil { + m.clearSession(s) + apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet) + return + } + c.HTML(http.StatusOK, "finalize.tmpl", gin.H{ + "instance": instance, + "name": claims.Name, + "preferredUsername": claims.PreferredUsername, + }) + return + } + s.Set(sessionUserID, user.ID) + if err := s.Save(); err != nil { + m.clearSession(s) + apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet) + return + } + c.Redirect(http.StatusFound, "/oauth"+OauthAuthorizePath) +} + +// FinalizePOSTHandler registers the user after additional data has been provided +func (m *Module) FinalizePOSTHandler(c *gin.Context) { + s := sessions.Default(c) + + form := &extraInfo{} + if err := c.ShouldBind(form); err != nil { + m.clearSession(s) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) + return + } + + // since we have multiple possible validation error, `validationError` is a shorthand for rendering them + validationError := func(err error) { + instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost()) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + c.HTML(http.StatusOK, "finalize.tmpl", gin.H{ + "instance": instance, + "name": form.Name, + "preferredUsername": form.Username, + "error": err, + }) + } + + // check if the username conforms to the spec + if err := validate.Username(form.Username); err != nil { + validationError(err) + return + } + + // see if the username is still available + usernameAvailable, err := m.db.IsUsernameAvailable(c.Request.Context(), form.Username) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) + return + } + if !usernameAvailable { + validationError(fmt.Errorf("Username %s is already taken", form.Username)) + return + } + + // retrieve the information previously set by the oidc logic + appID, ok := s.Get(sessionAppID).(string) + if !ok { + err := fmt.Errorf("key %s was not found in session", sessionAppID) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) + return + } + + // retrieve the claims returned by the IDP. Having this present means that we previously already verified these claims + claims, ok := s.Get(sessionClaims).(*oidc.Claims) + if !ok { + err := fmt.Errorf("key %s was not found in session", sessionClaims) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) + return + } + + // we're now ready to actually create the user + user, errWithCode := m.createUserFromOIDC(c.Request.Context(), claims, form, net.IP(c.ClientIP()), appID) + if errWithCode != nil { + m.clearSession(s) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + s.Delete(sessionClaims) + s.Delete(sessionAppID) + s.Set(sessionUserID, user.ID) + if err := s.Save(); err != nil { + m.clearSession(s) + apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet) + return + } + c.Redirect(http.StatusFound, "/oauth"+OauthAuthorizePath) +} + +func (m *Module) fetchUserForClaims(ctx context.Context, claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, gtserror.WithCode) { + if claims.Sub == "" { + err := errors.New("no sub claim found - is your provider OIDC compliant?") + return nil, gtserror.NewErrorBadRequest(err, err.Error()) + } + user, err := m.db.GetUserByExternalID(ctx, claims.Sub) + if err == nil { + return user, nil + } + if err != db.ErrNoEntries { + err := fmt.Errorf("error checking database for externalID %s: %s", claims.Sub, err) + return nil, gtserror.NewErrorInternalError(err) + } + if !config.GetOIDCLinkExisting() { + return nil, nil + } + // fallback to email if we want to link existing users + user, err = m.db.GetUserByEmailAddress(ctx, claims.Email) + if err == db.ErrNoEntries { + return nil, nil + } else if err != nil { + err := fmt.Errorf("error checking database for email %s: %s", claims.Email, err) + return nil, gtserror.NewErrorInternalError(err) + } + // at this point we have found a matching user but still need to link the newly received external ID + + user.ExternalID = claims.Sub + err = m.db.UpdateUser(ctx, user, "external_id") + if err != nil { + err := fmt.Errorf("error linking existing user %s: %s", claims.Email, err) + return nil, gtserror.NewErrorInternalError(err) + } + return user, nil +} + +func (m *Module) createUserFromOIDC(ctx context.Context, claims *oidc.Claims, extraInfo *extraInfo, ip net.IP, appID string) (*gtsmodel.User, gtserror.WithCode) { + // check if the email address is available for use; if it's not there's nothing we can so + emailAvailable, err := m.db.IsEmailAvailable(ctx, claims.Email) + if err != nil { + return nil, gtserror.NewErrorBadRequest(err) + } + if !emailAvailable { + help := "The email address given to us by your authentication provider already exists in our records and the server administrator has not enabled account migration" + return nil, gtserror.NewErrorConflict(fmt.Errorf("email address %s is not available", claims.Email), help) + } + + // check if the user is in any recognised admin groups + var admin bool + for _, g := range claims.Groups { + if strings.EqualFold(g, "admin") || strings.EqualFold(g, "admins") { + admin = true + } + } + + // We still need to set *a* password even if it's not a password the user will end up using, so set something random. + // We'll just set two uuids on top of each other, which should be long + random enough to baffle any attempts to crack. + // + // If the user ever wants to log in using gts password rather than oidc flow, they'll have to request a password reset, which is fine + password := uuid.NewString() + uuid.NewString() + + // Since this user is created via oidc, which has been set up by the admin, we can assume that the account is already + // implicitly approved, and that the email address has already been verified: otherwise, we end up in situations where + // the admin first approves the user in OIDC, and then has to approve them again in GoToSocial, which doesn't make sense. + // + // In other words, if a user logs in via OIDC, they should be able to use their account straight away. + // + // See: https://github.com/superseriousbusiness/gotosocial/issues/357 + requireApproval := false + emailVerified := true + + // create the user! this will also create an account and store it in the database so we don't need to do that here + user, err := m.db.NewSignup(ctx, extraInfo.Username, "", requireApproval, claims.Email, password, ip, "", appID, emailVerified, claims.Sub, admin) + if err != nil { + return nil, gtserror.NewErrorInternalError(err) + } + + return user, nil +} diff --git a/internal/api/auth/oob.go b/internal/api/auth/oob.go new file mode 100644 index 000000000..97f9c0f8c --- /dev/null +++ b/internal/api/auth/oob.go @@ -0,0 +1,113 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package auth + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +func (m *Module) OobHandler(c *gin.Context) { + host := config.GetHost() + instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), host) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + + instanceGet := func(ctx context.Context, domain string) (*apimodel.Instance, gtserror.WithCode) { + return instance, nil + } + + oobToken := c.Query("code") + if oobToken == "" { + err := errors.New("no 'code' query value provided in callback redirect") + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice), instanceGet) + return + } + + s := sessions.Default(c) + + errs := []string{} + + scope, ok := s.Get(sessionScope).(string) + if !ok { + errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionScope)) + } + + userID, ok := s.Get(sessionUserID).(string) + if !ok { + errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionUserID)) + } + + if len(errs) != 0 { + errs = append(errs, oauth.HelpfulAdvice) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(errors.New("one or more missing keys on session during OobHandler"), errs...), m.processor.InstanceGet) + return + } + + user, err := m.db.GetUserByID(c.Request.Context(), userID) + if err != nil { + m.clearSession(s) + safe := fmt.Sprintf("user with id %s could not be retrieved", userID) + var errWithCode gtserror.WithCode + if err == db.ErrNoEntries { + errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice) + } else { + errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) + } + apiutil.ErrorHandler(c, errWithCode, instanceGet) + return + } + + acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) + if err != nil { + m.clearSession(s) + safe := fmt.Sprintf("account with id %s could not be retrieved", user.AccountID) + var errWithCode gtserror.WithCode + if err == db.ErrNoEntries { + errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice) + } else { + errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) + } + apiutil.ErrorHandler(c, errWithCode, instanceGet) + return + } + + // we're done with the session now, so just clear it out + m.clearSession(s) + + c.HTML(http.StatusOK, "oob.tmpl", gin.H{ + "instance": instance, + "user": acct.Username, + "oobToken": oobToken, + "scope": scope, + }) +} diff --git a/internal/api/auth/signin.go b/internal/api/auth/signin.go new file mode 100644 index 000000000..bae33a43b --- /dev/null +++ b/internal/api/auth/signin.go @@ -0,0 +1,145 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package auth + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "golang.org/x/crypto/bcrypt" +) + +// login just wraps a form-submitted username (we want an email) and password +type login struct { + Email string `form:"username"` + Password string `form:"password"` +} + +// SignInGETHandler should be served at https://example.org/auth/sign_in. +// The idea is to present a sign in page to the user, where they can enter their username and password. +// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler. +// If an idp provider is set, then the user will be redirected to that to do their sign in. +func (m *Module) SignInGETHandler(c *gin.Context) { + if _, err := apiutil.NegotiateAccept(c, apiutil.HTMLAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) + return + } + + if !config.GetOIDCEnabled() { + instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost()) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + + // no idp provider, use our own funky little sign in page + c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{ + "instance": instance, + }) + return + } + + // idp provider is in use, so redirect to it + s := sessions.Default(c) + + internalStateI := s.Get(sessionInternalState) + internalState, ok := internalStateI.(string) + if !ok { + m.clearSession(s) + err := fmt.Errorf("key %s was not found in session", sessionInternalState) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) + return + } + + c.Redirect(http.StatusSeeOther, m.idp.AuthCodeURL(internalState)) +} + +// SignInPOSTHandler should be served at https://example.org/auth/sign_in. +// The idea is to present a sign in page to the user, where they can enter their username and password. +// The handler will then redirect to the auth handler served at /auth +func (m *Module) SignInPOSTHandler(c *gin.Context) { + s := sessions.Default(c) + + form := &login{} + if err := c.ShouldBind(form); err != nil { + m.clearSession(s) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) + return + } + + userid, errWithCode := m.ValidatePassword(c.Request.Context(), form.Email, form.Password) + if errWithCode != nil { + // don't clear session here, so the user can just press back and try again + // if they accidentally gave the wrong password or something + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + + s.Set(sessionUserID, userid) + if err := s.Save(); err != nil { + err := fmt.Errorf("error saving user id onto session: %s", err) + apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err, oauth.HelpfulAdvice), m.processor.InstanceGet) + } + + c.Redirect(http.StatusFound, "/oauth"+OauthAuthorizePath) +} + +// ValidatePassword takes an email address and a password. +// The goal is to authenticate the password against the one for that email +// address stored in the database. If OK, we return the userid (a ulid) for that user, +// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db. +func (m *Module) ValidatePassword(ctx context.Context, email string, password string) (string, gtserror.WithCode) { + if email == "" || password == "" { + err := errors.New("email or password was not provided") + return incorrectPassword(err) + } + + user, err := m.db.GetUserByEmailAddress(ctx, email) + if err != nil { + err := fmt.Errorf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err) + return incorrectPassword(err) + } + + if user.EncryptedPassword == "" { + err := fmt.Errorf("encrypted password for user %s was empty for some reason", user.Email) + return incorrectPassword(err) + } + + if err := bcrypt.CompareHashAndPassword([]byte(user.EncryptedPassword), []byte(password)); err != nil { + err := fmt.Errorf("password hash didn't match for user %s during login attempt: %s", user.Email, err) + return incorrectPassword(err) + } + + return user.ID, nil +} + +// incorrectPassword wraps the given error in a gtserror.WithCode, and returns +// only a generic 'safe' error message to the user, to not give any info away. +func incorrectPassword(err error) (string, gtserror.WithCode) { + safeErr := fmt.Errorf("password/email combination was incorrect") + return "", gtserror.NewErrorUnauthorized(err, safeErr.Error(), oauth.HelpfulAdvice) +} diff --git a/internal/api/auth/token.go b/internal/api/auth/token.go new file mode 100644 index 000000000..17c4d8d8b --- /dev/null +++ b/internal/api/auth/token.go @@ -0,0 +1,115 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package auth + +import ( + "net/http" + "net/url" + + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + + "github.com/gin-gonic/gin" +) + +type tokenRequestForm struct { + GrantType *string `form:"grant_type" json:"grant_type" xml:"grant_type"` + Code *string `form:"code" json:"code" xml:"code"` + RedirectURI *string `form:"redirect_uri" json:"redirect_uri" xml:"redirect_uri"` + ClientID *string `form:"client_id" json:"client_id" xml:"client_id"` + ClientSecret *string `form:"client_secret" json:"client_secret" xml:"client_secret"` + Scope *string `form:"scope" json:"scope" xml:"scope"` +} + +// TokenPOSTHandler should be served as a POST at https://example.org/oauth/token +// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs. +func (m *Module) TokenPOSTHandler(c *gin.Context) { + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGet) + return + } + + help := []string{} + + form := &tokenRequestForm{} + if err := c.ShouldBind(form); err != nil { + apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), err.Error())) + return + } + + c.Request.Form = url.Values{} + + var grantType string + if form.GrantType != nil { + grantType = *form.GrantType + c.Request.Form.Set("grant_type", grantType) + } else { + help = append(help, "grant_type was not set in the token request form, but must be set to authorization_code or client_credentials") + } + + if form.ClientID != nil { + c.Request.Form.Set("client_id", *form.ClientID) + } else { + help = append(help, "client_id was not set in the token request form") + } + + if form.ClientSecret != nil { + c.Request.Form.Set("client_secret", *form.ClientSecret) + } else { + help = append(help, "client_secret was not set in the token request form") + } + + if form.RedirectURI != nil { + c.Request.Form.Set("redirect_uri", *form.RedirectURI) + } else { + help = append(help, "redirect_uri was not set in the token request form") + } + + var code string + if form.Code != nil { + if grantType != "authorization_code" { + help = append(help, "a code was provided in the token request form, but grant_type was not set to authorization_code") + } else { + code = *form.Code + c.Request.Form.Set("code", code) + } + } else if grantType == "authorization_code" { + help = append(help, "code was not set in the token request form, but must be set since grant_type is authorization_code") + } + + if form.Scope != nil { + c.Request.Form.Set("scope", *form.Scope) + } + + if len(help) != 0 { + apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), help...)) + return + } + + token, errWithCode := m.processor.OAuthHandleTokenRequest(c.Request) + if errWithCode != nil { + apiutil.OAuthErrorHandler(c, errWithCode) + return + } + + c.Header("Cache-Control", "no-store") + c.Header("Pragma", "no-cache") + c.JSON(http.StatusOK, token) +} diff --git a/internal/api/auth/token_test.go b/internal/api/auth/token_test.go new file mode 100644 index 000000000..50bbd6918 --- /dev/null +++ b/internal/api/auth/token_test.go @@ -0,0 +1,215 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package auth_test + +import ( + "context" + "encoding/json" + "io/ioutil" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/suite" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type TokenTestSuite struct { + AuthStandardTestSuite +} + +func (suite *TokenTestSuite) TestPOSTTokenEmptyForm() { + ctx, recorder := suite.newContext(http.MethodPost, "oauth/token", []byte{}, "") + ctx.Request.Header.Set("accept", "application/json") + + suite.authModule.TokenPOSTHandler(ctx) + + suite.Equal(http.StatusBadRequest, recorder.Code) + + result := recorder.Result() + defer result.Body.Close() + + b, err := ioutil.ReadAll(result.Body) + suite.NoError(err) + + suite.Equal(`{"error":"invalid_request","error_description":"Bad Request: grant_type was not set in the token request form, but must be set to authorization_code or client_credentials: client_id was not set in the token request form: client_secret was not set in the token request form: redirect_uri was not set in the token request form"}`, string(b)) +} + +func (suite *TokenTestSuite) TestRetrieveClientCredentialsOK() { + testClient := suite.testClients["local_account_1"] + + requestBody, w, err := testrig.CreateMultipartFormData( + "", "", + map[string]string{ + "grant_type": "client_credentials", + "client_id": testClient.ID, + "client_secret": testClient.Secret, + "redirect_uri": "http://localhost:8080", + }) + if err != nil { + panic(err) + } + bodyBytes := requestBody.Bytes() + + ctx, recorder := suite.newContext(http.MethodPost, "oauth/token", bodyBytes, w.FormDataContentType()) + ctx.Request.Header.Set("accept", "application/json") + + suite.authModule.TokenPOSTHandler(ctx) + + suite.Equal(http.StatusOK, recorder.Code) + + result := recorder.Result() + defer result.Body.Close() + + b, err := ioutil.ReadAll(result.Body) + suite.NoError(err) + + t := &apimodel.Token{} + err = json.Unmarshal(b, t) + suite.NoError(err) + + suite.Equal("Bearer", t.TokenType) + suite.NotEmpty(t.AccessToken) + suite.NotEmpty(t.CreatedAt) + suite.WithinDuration(time.Now(), time.Unix(t.CreatedAt, 0), 1*time.Minute) + + // there should be a token in the database now too + dbToken := >smodel.Token{} + err = suite.db.GetWhere(context.Background(), []db.Where{{Key: "access", Value: t.AccessToken}}, dbToken) + suite.NoError(err) + suite.NotNil(dbToken) +} + +func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeOK() { + testClient := suite.testClients["local_account_1"] + testUserAuthorizationToken := suite.testTokens["local_account_1_user_authorization_token"] + + requestBody, w, err := testrig.CreateMultipartFormData( + "", "", + map[string]string{ + "grant_type": "authorization_code", + "client_id": testClient.ID, + "client_secret": testClient.Secret, + "redirect_uri": "http://localhost:8080", + "code": testUserAuthorizationToken.Code, + }) + if err != nil { + panic(err) + } + bodyBytes := requestBody.Bytes() + + ctx, recorder := suite.newContext(http.MethodPost, "oauth/token", bodyBytes, w.FormDataContentType()) + ctx.Request.Header.Set("accept", "application/json") + + suite.authModule.TokenPOSTHandler(ctx) + + suite.Equal(http.StatusOK, recorder.Code) + + result := recorder.Result() + defer result.Body.Close() + + b, err := ioutil.ReadAll(result.Body) + suite.NoError(err) + + t := &apimodel.Token{} + err = json.Unmarshal(b, t) + suite.NoError(err) + + suite.Equal("Bearer", t.TokenType) + suite.NotEmpty(t.AccessToken) + suite.NotEmpty(t.CreatedAt) + suite.WithinDuration(time.Now(), time.Unix(t.CreatedAt, 0), 1*time.Minute) + + dbToken := >smodel.Token{} + err = suite.db.GetWhere(context.Background(), []db.Where{{Key: "access", Value: t.AccessToken}}, dbToken) + suite.NoError(err) + suite.NotNil(dbToken) +} + +func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeNoCode() { + testClient := suite.testClients["local_account_1"] + + requestBody, w, err := testrig.CreateMultipartFormData( + "", "", + map[string]string{ + "grant_type": "authorization_code", + "client_id": testClient.ID, + "client_secret": testClient.Secret, + "redirect_uri": "http://localhost:8080", + }) + if err != nil { + panic(err) + } + bodyBytes := requestBody.Bytes() + + ctx, recorder := suite.newContext(http.MethodPost, "oauth/token", bodyBytes, w.FormDataContentType()) + ctx.Request.Header.Set("accept", "application/json") + + suite.authModule.TokenPOSTHandler(ctx) + + suite.Equal(http.StatusBadRequest, recorder.Code) + + result := recorder.Result() + defer result.Body.Close() + + b, err := ioutil.ReadAll(result.Body) + suite.NoError(err) + + suite.Equal(`{"error":"invalid_request","error_description":"Bad Request: code was not set in the token request form, but must be set since grant_type is authorization_code"}`, string(b)) +} + +func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeWrongGrantType() { + testClient := suite.testClients["local_account_1"] + + requestBody, w, err := testrig.CreateMultipartFormData( + "", "", + map[string]string{ + "grant_type": "client_credentials", + "client_id": testClient.ID, + "client_secret": testClient.Secret, + "redirect_uri": "http://localhost:8080", + "code": "peepeepoopoo", + }) + if err != nil { + panic(err) + } + bodyBytes := requestBody.Bytes() + + ctx, recorder := suite.newContext(http.MethodPost, "oauth/token", bodyBytes, w.FormDataContentType()) + ctx.Request.Header.Set("accept", "application/json") + + suite.authModule.TokenPOSTHandler(ctx) + + suite.Equal(http.StatusBadRequest, recorder.Code) + + result := recorder.Result() + defer result.Body.Close() + + b, err := ioutil.ReadAll(result.Body) + suite.NoError(err) + + suite.Equal(`{"error":"invalid_request","error_description":"Bad Request: a code was provided in the token request form, but grant_type was not set to authorization_code"}`, string(b)) +} + +func TestTokenTestSuite(t *testing.T) { + suite.Run(t, &TokenTestSuite{}) +} |