From 365b5753419238bb96bc3f9b744d380ff20cbafc Mon Sep 17 00:00:00 2001 From: tobi <31960611+tsmethurst@users.noreply.github.com> Date: Mon, 7 Apr 2025 16:14:41 +0200 Subject: [feature] add TOTP two-factor authentication (2FA) (#3960) * [feature] add TOTP two-factor authentication (2FA) * use byteutil.S2B to avoid allocations when comparing + generating password hashes * don't bother with string conversion for consts * use io.ReadFull * use MustGenerateSecret for backup codes * rename util functions --- internal/api/auth/auth.go | 79 ++++----- internal/api/auth/auth_test.go | 2 +- internal/api/auth/authorize.go | 368 ++++++++++++++++++----------------------- internal/api/auth/callback.go | 40 ++--- internal/api/auth/oob.go | 89 +++------- internal/api/auth/signin.go | 271 ++++++++++++++++++++++++------ internal/api/auth/util.go | 152 +++++++++++++++++ 7 files changed, 611 insertions(+), 390 deletions(-) create mode 100644 internal/api/auth/util.go (limited to 'internal/api/auth') diff --git a/internal/api/auth/auth.go b/internal/api/auth/auth.go index e0e8058d6..f9dcb87ea 100644 --- a/internal/api/auth/auth.go +++ b/internal/api/auth/auth.go @@ -20,11 +20,10 @@ package auth import ( "net/http" - "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/oidc" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" ) const ( @@ -32,61 +31,58 @@ const ( paths prefixed with 'auth' */ - // AuthSignInPath is the API path for users to sign in through - AuthSignInPath = "/sign_in" - // AuthCheckYourEmailPath users land here after registering a new account, instructs them to confirm their email - AuthCheckYourEmailPath = "/check_your_email" - // AuthWaitForApprovalPath users land here after confirming their email - // but before an admin approves their account (if such is required) + AuthSignInPath = "/sign_in" + Auth2FAPath = "/2fa" + AuthCheckYourEmailPath = "/check_your_email" AuthWaitForApprovalPath = "/wait_for_approval" - // AuthAccountDisabledPath users land here when their account is suspended by an admin AuthAccountDisabledPath = "/account_disabled" - // AuthCallbackPath is the API path for receiving callback tokens from external OIDC providers - AuthCallbackPath = "/callback" + AuthCallbackPath = "/callback" /* paths prefixed with 'oauth' */ - // OauthTokenPath is the API path to use for granting token requests to users with valid credentials - OauthTokenPath = "/token" // #nosec G101 else we get a hardcoded credentials warning - // OauthAuthorizePath is the API path for authorization requests (eg., authorize this app to act on my behalf as a user) OauthAuthorizePath = "/authorize" - // OauthFinalizePath is the API path for completing user registration with additional user details - OauthFinalizePath = "/finalize" - // OauthOobTokenPath is the path for serving an html representation of an oob token page. - OauthOobTokenPath = "/oob" // #nosec G101 else we get a hardcoded credentials warning + OauthFinalizePath = "/finalize" + OauthOOBTokenPath = "/oob" // #nosec G101 else we get a hardcoded credentials warning + OauthTokenPath = "/token" // #nosec G101 else we get a hardcoded credentials warning /* params / session keys */ - callbackStateParam = "state" - callbackCodeParam = "code" - sessionUserID = "userid" - sessionClientID = "client_id" - sessionRedirectURI = "redirect_uri" - sessionForceLogin = "force_login" - sessionResponseType = "response_type" - sessionScope = "scope" - sessionInternalState = "internal_state" - sessionClientState = "client_state" - sessionClaims = "claims" - sessionAppID = "app_id" + callbackStateParam = "state" + callbackCodeParam = "code" + sessionUserID = "userid" + sessionUserIDAwaiting2FA = "userid_awaiting_2fa" + sessionClientID = "client_id" + sessionRedirectURI = "redirect_uri" + sessionForceLogin = "force_login" + sessionResponseType = "response_type" + sessionScope = "scope" + sessionInternalState = "internal_state" + sessionClientState = "client_state" + sessionClaims = "claims" + sessionAppID = "app_id" ) type Module struct { - db db.DB + state *state.State processor *processing.Processor idp oidc.IDP } -// New returns an Auth module which provides both 'oauth' and 'auth' endpoints. +// New returns an Auth module which provides +// both 'oauth' and 'auth' endpoints. // // It is safe to pass a nil idp if oidc is disabled. -func New(db db.DB, processor *processing.Processor, idp oidc.IDP) *Module { +func New( + state *state.State, + processor *processing.Processor, + idp oidc.IDP, +) *Module { return &Module{ - db: db, + state: state, processor: processor, idp: idp, } @@ -96,21 +92,16 @@ func New(db db.DB, processor *processing.Processor, idp oidc.IDP) *Module { func (m *Module) RouteAuth(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) { attachHandler(http.MethodGet, AuthSignInPath, m.SignInGETHandler) attachHandler(http.MethodPost, AuthSignInPath, m.SignInPOSTHandler) + attachHandler(http.MethodGet, Auth2FAPath, m.TwoFactorCodeGETHandler) + attachHandler(http.MethodPost, Auth2FAPath, m.TwoFactorCodePOSTHandler) attachHandler(http.MethodGet, AuthCallbackPath, m.CallbackGETHandler) } -// RouteOauth routes all paths that should have an 'oauth' prefix -func (m *Module) RouteOauth(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) { +// RouteOAuth routes all paths that should have an 'oauth' prefix +func (m *Module) RouteOAuth(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) { attachHandler(http.MethodPost, OauthTokenPath, m.TokenPOSTHandler) attachHandler(http.MethodGet, OauthAuthorizePath, m.AuthorizeGETHandler) attachHandler(http.MethodPost, OauthAuthorizePath, m.AuthorizePOSTHandler) attachHandler(http.MethodPost, OauthFinalizePath, m.FinalizePOSTHandler) - attachHandler(http.MethodGet, OauthOobTokenPath, m.OobHandler) -} - -func (m *Module) clearSession(s sessions.Session) { - s.Clear() - if err := s.Save(); err != nil { - panic(err) - } + attachHandler(http.MethodGet, OauthOOBTokenPath, m.OOBTokenGETHandler) } diff --git a/internal/api/auth/auth_test.go b/internal/api/auth/auth_test.go index 3bf3ec593..4b7ea2f5f 100644 --- a/internal/api/auth/auth_test.go +++ b/internal/api/auth/auth_test.go @@ -96,7 +96,7 @@ func (suite *AuthStandardTestSuite) SetupTest() { testrig.NewNoopWebPushSender(), suite.mediaManager, ) - suite.authModule = auth.New(suite.db, suite.processor, suite.idp) + suite.authModule = auth.New(&suite.state, suite.processor, suite.idp) testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StartNoopWorkers(&suite.state) diff --git a/internal/api/auth/authorize.go b/internal/api/auth/authorize.go index e4694de57..3676fd417 100644 --- a/internal/api/auth/authorize.go +++ b/internal/api/auth/authorize.go @@ -18,8 +18,6 @@ package auth import ( - "errors" - "fmt" "net/http" "net/url" @@ -28,119 +26,79 @@ import ( "github.com/google/uuid" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -// AuthorizeGETHandler should be served as GET at https://example.org/oauth/authorize -// The idea here is to present an oauth authorize page to the user, with a button -// that they have to click to accept. +// AuthorizeGETHandler should be served as +// GET at https://example.org/oauth/authorize. +// +// The idea here is to present an authorization +// page to the user, informing them of the scopes +// the application is requesting, with a button +// that they have to click to give it permission. func (m *Module) AuthorizeGETHandler(c *gin.Context) { - s := sessions.Default(c) - if _, err := apiutil.NegotiateAccept(c, apiutil.HTMLAcceptHeaders...); err != nil { apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) return } - // UserID will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow - // If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page. - userID, ok := s.Get(sessionUserID).(string) - if !ok || userID == "" { - form := &apimodel.OAuthAuthorize{} - if err := c.ShouldBind(form); err != nil { - m.clearSession(s) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGetV1) - return - } - - if errWithCode := saveAuthFormToSession(s, form); errWithCode != nil { - m.clearSession(s) - apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) - return - } - - c.Redirect(http.StatusSeeOther, "/auth"+AuthSignInPath) - return - } + s := sessions.Default(c) - // use session information to validate app, user, and account for this request - clientID, ok := s.Get(sessionClientID).(string) - if !ok || clientID == "" { - m.clearSession(s) - err := fmt.Errorf("key %s was not found in session", sessionClientID) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGetV1) + // UserID will be set in the session by + // AuthorizePOSTHandler if the caller has + // already gone through the auth flow. + // + // If it's not set, then we don't yet know + // yet who the user is, so send them to the + // sign in page first. + if userID, ok := s.Get(sessionUserID).(string); !ok || userID == "" { + m.redirectAuthFormToSignIn(c) return } - app, err := m.db.GetApplicationByClientID(c.Request.Context(), clientID) - if err != nil { - m.clearSession(s) - safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID) - var errWithCode gtserror.WithCode - if err == db.ErrNoEntries { - errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice) - } else { - errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) - } - apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + user := m.mustUserFromSession(c, s) + if user == nil { + // Error already + // written. return } - user, err := m.db.GetUserByID(c.Request.Context(), userID) - if err != nil { - m.clearSession(s) - safe := fmt.Sprintf("user with id %s could not be retrieved", userID) - var errWithCode gtserror.WithCode - if err == db.ErrNoEntries { - errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice) - } else { - errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) - } - apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + // If the user is unconfirmed, waiting approval, + // or suspended, redirect to an appropriate help page. + if !m.validateUser(c, user) { + // Already + // redirected. return } - acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) - if err != nil { - m.clearSession(s) - safe := fmt.Sprintf("account with id %s could not be retrieved", user.AccountID) - var errWithCode gtserror.WithCode - if err == db.ErrNoEntries { - errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice) - } else { - errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) - } + // Everything looks OK. + // Start preparing to render the html template. + instance, errWithCode := m.processor.InstanceGetV1(c.Request.Context()) + if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } - if ensureUserIsAuthorizedOrRedirect(c, user, acct) { - return - } - - // Finally we should also get the redirect and scope of this particular request, as stored in the session. - redirect, ok := s.Get(sessionRedirectURI).(string) - if !ok || redirect == "" { - m.clearSession(s) - err := fmt.Errorf("key %s was not found in session", sessionRedirectURI) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGetV1) + redirectURI := m.mustStringFromSession(c, s, sessionRedirectURI) + if redirectURI == "" { + // Error already + // written. return } - scope, ok := s.Get(sessionScope).(string) - if !ok || scope == "" { - m.clearSession(s) - err := fmt.Errorf("key %s was not found in session", sessionScope) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGetV1) + scope := m.mustStringFromSession(c, s, sessionScope) + if scope == "" { + // Error already + // written. return } - instance, errWithCode := m.processor.InstanceGetV1(c.Request.Context()) - if errWithCode != nil { - apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + app := m.mustAppFromSession(c, s) + if app == nil { + // Error already + // written. return } @@ -150,158 +108,145 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { // and the scope of the request. They can then // approve it if it looks OK to them, which // will POST to the AuthorizePOSTHandler. - page := apiutil.WebPage{ + apiutil.TemplateWebPage(c, apiutil.WebPage{ Template: "authorize.tmpl", Instance: instance, Extra: map[string]any{ "appname": app.Name, "appwebsite": app.Website, - "redirect": redirect, + "redirect": redirectURI, "scope": scope, - "user": acct.Username, + "user": user.Account.Username, }, - } - - apiutil.TemplateWebPage(c, page) + }) } -// AuthorizePOSTHandler should be served as POST at https://example.org/oauth/authorize -// At this point we assume that the user has A) logged in and B) accepted that the app should act for them, -// so we should proceed with the authentication flow and generate an oauth token for them if we can. +// AuthorizePOSTHandler should be served as +// POST at https://example.org/oauth/authorize. +// +// At this point we assume that the user has signed +// in and permitted the app to act on their behalf. +// We should proceed with the authentication flow +// and generate an oauth code at the redirect URI. func (m *Module) AuthorizePOSTHandler(c *gin.Context) { - s := sessions.Default(c) - - // We need to retrieve the original form submitted to the authorizeGEThandler, and - // recreate it on the request so that it can be used further by the oauth2 library. - errs := []string{} - - forceLogin, ok := s.Get(sessionForceLogin).(string) - if !ok { - forceLogin = "false" - } - - responseType, ok := s.Get(sessionResponseType).(string) - if !ok || responseType == "" { - errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionResponseType)) - } - clientID, ok := s.Get(sessionClientID).(string) - if !ok || clientID == "" { - errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionClientID)) - } + // We need to use the session cookie to + // recreate the original form submitted + // to the authorizeGEThandler so that it + // can be validated by the oauth2 library. + s := sessions.Default(c) - redirectURI, ok := s.Get(sessionRedirectURI).(string) - if !ok || redirectURI == "" { - errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionRedirectURI)) + responseType := m.mustStringFromSession(c, s, sessionResponseType) + if responseType == "" { + // Error already + // written. + return } - scope, ok := s.Get(sessionScope).(string) - if !ok { - errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionScope)) + clientID := m.mustStringFromSession(c, s, sessionClientID) + if clientID == "" { + // Error already + // written. + return } - var clientState string - if s, ok := s.Get(sessionClientState).(string); ok { - clientState = s + redirectURI := m.mustStringFromSession(c, s, sessionRedirectURI) + if redirectURI == "" { + // Error already + // written. + return } - userID, ok := s.Get(sessionUserID).(string) - if !ok { - errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionUserID)) + scope := m.mustStringFromSession(c, s, sessionScope) + if scope == "" { + // Error already + // written. + return } - if len(errs) != 0 { - errs = append(errs, oauth.HelpfulAdvice) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(errors.New("one or more missing keys on session during AuthorizePOSTHandler"), errs...), m.processor.InstanceGetV1) + user := m.mustUserFromSession(c, s) + if user == nil { + // Error already + // written. return } - user, err := m.db.GetUserByID(c.Request.Context(), userID) - if err != nil { - m.clearSession(s) - safe := fmt.Sprintf("user with id %s could not be retrieved", userID) - var errWithCode gtserror.WithCode - if err == db.ErrNoEntries { - errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice) - } else { - errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) - } - apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) - return + // Force login is optional with default of "false". + forceLogin, ok := s.Get(sessionForceLogin).(string) + if !ok || forceLogin == "" { + forceLogin = "false" } - acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) - if err != nil { - m.clearSession(s) - safe := fmt.Sprintf("account with id %s could not be retrieved", user.AccountID) - var errWithCode gtserror.WithCode - if err == db.ErrNoEntries { - errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice) - } else { - errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) - } - apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) - return + // Client state is optional with default of "". + var clientState string + if cs, ok := s.Get(sessionClientState).(string); ok { + clientState = cs } - if ensureUserIsAuthorizedOrRedirect(c, user, acct) { + // If the user is unconfirmed, waiting approval, + // or suspended, redirect to an appropriate help page. + if !m.validateUser(c, user) { + // Already + // redirected. return } + // If we're redirecting to our OOB token handler, + // we need to keep the session around so the OOB + // handler can extract values from it. Otherwise, + // we're going to be redirecting somewhere else + // so we can safely clear the session now. if redirectURI != oauth.OOBURI { - // we're done with the session now, so just clear it out - m.clearSession(s) + m.mustClearSession(s) } - // we have to set the values on the request form - // so that they're picked up by the oauth server + // Set values on the request form so that + // they're picked up by the oauth server. c.Request.Form = url.Values{ - sessionForceLogin: {forceLogin}, sessionResponseType: {responseType}, sessionClientID: {clientID}, sessionRedirectURI: {redirectURI}, sessionScope: {scope}, - sessionUserID: {userID}, + sessionUserID: {user.ID}, + sessionForceLogin: {forceLogin}, } if clientState != "" { + // If client state was submitted, + // set it on the form so it can be + // fed back to the client via a query + // param at the eventual redirect URL. c.Request.Form.Set("state", clientState) } - if errWithCode := m.processor.OAuthHandleAuthorizeRequest(c.Writer, c.Request); errWithCode != nil { + // If OAuthHandleAuthorizeRequest is successful, + // it'll handle any further redirects for us, + // but we do still need to handle any errors. + errWithCode := m.processor.OAuthHandleAuthorizeRequest(c.Writer, c.Request) + if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) } } -// saveAuthFormToSession checks the given OAuthAuthorize form, -// and stores the values in the form into the session. -func saveAuthFormToSession(s sessions.Session, form *apimodel.OAuthAuthorize) gtserror.WithCode { - if form == nil { - err := errors.New("OAuthAuthorize form was nil") - return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice) - } - - if form.ResponseType == "" { - err := errors.New("field response_type was not set on OAuthAuthorize form") - return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice) - } - - if form.ClientID == "" { - err := errors.New("field client_id was not set on OAuthAuthorize form") - return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice) - } +// redirectAuthFormToSignIn binds an OAuthAuthorize form, +// stores the values in the form into the session, and +// redirects the user to the sign in page. +func (m *Module) redirectAuthFormToSignIn(c *gin.Context) { + s := sessions.Default(c) - if form.RedirectURI == "" { - err := errors.New("field redirect_uri was not set on OAuthAuthorize form") - return gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice) + form := &apimodel.OAuthAuthorize{} + if err := c.ShouldBind(form); err != nil { + m.clearSessionWithBadRequest(c, s, err, err.Error(), oauth.HelpfulAdvice) + return } - // set default scope to read + // Set default scope to read. if form.Scope == "" { form.Scope = "read" } - // save these values from the form so we can use them elsewhere in the session + // Save these values from the form so we + // can use them elsewhere in the session. s.Set(sessionForceLogin, form.ForceLogin) s.Set(sessionResponseType, form.ResponseType) s.Set(sessionClientID, form.ClientID) @@ -310,32 +255,43 @@ func saveAuthFormToSession(s sessions.Session, form *apimodel.OAuthAuthorize) gt s.Set(sessionInternalState, uuid.NewString()) s.Set(sessionClientState, form.State) - if err := s.Save(); err != nil { - err := fmt.Errorf("error saving form values onto session: %s", err) - return gtserror.NewErrorInternalError(err, oauth.HelpfulAdvice) - } - - return nil + m.mustSaveSession(s) + c.Redirect(http.StatusSeeOther, "/auth"+AuthSignInPath) } -func ensureUserIsAuthorizedOrRedirect(ctx *gin.Context, user *gtsmodel.User, account *gtsmodel.Account) (redirected bool) { - if user.ConfirmedAt.IsZero() { - ctx.Redirect(http.StatusSeeOther, "/auth"+AuthCheckYourEmailPath) - redirected = true - return - } - - if !*user.Approved { - ctx.Redirect(http.StatusSeeOther, "/auth"+AuthWaitForApprovalPath) - redirected = true - return - } - - if *user.Disabled || !account.SuspendedAt.IsZero() { - ctx.Redirect(http.StatusSeeOther, "/auth"+AuthAccountDisabledPath) - redirected = true - return +// validateUser checks if the given user: +// +// 1. Has a confirmed email address. +// 2. Has been approved. +// 3. Is not disabled or suspended. +// +// If all looks OK, returns true. Otherwise, +// redirects to a help page and returns false. +func (m *Module) validateUser( + c *gin.Context, + user *gtsmodel.User, +) bool { + switch { + case user.ConfirmedAt.IsZero(): + // User email not confirmed yet. + const redirectTo = "/auth" + AuthCheckYourEmailPath + c.Redirect(http.StatusSeeOther, redirectTo) + return false + + case !*user.Approved: + // User signup not approved yet. + const redirectTo = "/auth" + AuthWaitForApprovalPath + c.Redirect(http.StatusSeeOther, redirectTo) + return false + + case *user.Disabled || !user.Account.SuspendedAt.IsZero(): + // User disabled or suspended. + const redirectTo = "/auth" + AuthAccountDisabledPath + c.Redirect(http.StatusSeeOther, redirectTo) + return false + + default: + // All good. + return true } - - return } diff --git a/internal/api/auth/callback.go b/internal/api/auth/callback.go index 2dc36fac8..5003910e9 100644 --- a/internal/api/auth/callback.go +++ b/internal/api/auth/callback.go @@ -60,7 +60,7 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { returnedInternalState := c.Query(callbackStateParam) if returnedInternalState == "" { - m.clearSession(s) + m.mustClearSession(s) err := fmt.Errorf("%s parameter not found on callback query", callbackStateParam) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) return @@ -69,14 +69,14 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { savedInternalStateI := s.Get(sessionInternalState) savedInternalState, ok := savedInternalStateI.(string) if !ok { - m.clearSession(s) + m.mustClearSession(s) err := fmt.Errorf("key %s was not found in session", sessionInternalState) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) return } if returnedInternalState != savedInternalState { - m.clearSession(s) + m.mustClearSession(s) err := errors.New("mismatch between callback state and saved state") apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) return @@ -85,7 +85,7 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { // retrieve stored claims using code code := c.Query(callbackCodeParam) if code == "" { - m.clearSession(s) + m.mustClearSession(s) err := fmt.Errorf("%s parameter not found on callback query", callbackCodeParam) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) return @@ -93,7 +93,7 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { claims, errWithCode := m.idp.HandleCallback(c.Request.Context(), code) if errWithCode != nil { - m.clearSession(s) + m.mustClearSession(s) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } @@ -102,15 +102,15 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { // info about the app associated with the client_id clientID, ok := s.Get(sessionClientID).(string) if !ok || clientID == "" { - m.clearSession(s) + m.mustClearSession(s) err := fmt.Errorf("key %s was not found in session", sessionClientID) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGetV1) return } - app, err := m.db.GetApplicationByClientID(c.Request.Context(), clientID) + app, err := m.state.DB.GetApplicationByClientID(c.Request.Context(), clientID) if err != nil { - m.clearSession(s) + m.mustClearSession(s) safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID) var errWithCode gtserror.WithCode if err == db.ErrNoEntries { @@ -124,7 +124,7 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { user, errWithCode := m.fetchUserForClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID) if errWithCode != nil { - m.clearSession(s) + m.mustClearSession(s) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } @@ -140,7 +140,7 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { s.Set(sessionClaims, claims) s.Set(sessionAppID, app.ID) if err := s.Save(); err != nil { - m.clearSession(s) + m.mustClearSession(s) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGetV1) return } @@ -173,7 +173,7 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { s.Set(sessionUserID, user.ID) if err := s.Save(); err != nil { - m.clearSession(s) + m.mustClearSession(s) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGetV1) return } @@ -186,7 +186,7 @@ func (m *Module) FinalizePOSTHandler(c *gin.Context) { form := &extraInfo{} if err := c.ShouldBind(form); err != nil { - m.clearSession(s) + m.mustClearSession(s) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGetV1) return } @@ -219,7 +219,7 @@ func (m *Module) FinalizePOSTHandler(c *gin.Context) { } // see if the username is still available - usernameAvailable, err := m.db.IsUsernameAvailable(c.Request.Context(), form.Username) + usernameAvailable, err := m.state.DB.IsUsernameAvailable(c.Request.Context(), form.Username) if err != nil { apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGetV1) return @@ -248,7 +248,7 @@ func (m *Module) FinalizePOSTHandler(c *gin.Context) { // we're now ready to actually create the user user, errWithCode := m.createUserFromOIDC(c.Request.Context(), claims, form, net.IP(c.ClientIP()), appID) if errWithCode != nil { - m.clearSession(s) + m.mustClearSession(s) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } @@ -256,7 +256,7 @@ func (m *Module) FinalizePOSTHandler(c *gin.Context) { s.Delete(sessionAppID) s.Set(sessionUserID, user.ID) if err := s.Save(); err != nil { - m.clearSession(s) + m.mustClearSession(s) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGetV1) return } @@ -268,7 +268,7 @@ func (m *Module) fetchUserForClaims(ctx context.Context, claims *oidc.Claims, ip err := errors.New("no sub claim found - is your provider OIDC compliant?") return nil, gtserror.NewErrorBadRequest(err, err.Error()) } - user, err := m.db.GetUserByExternalID(ctx, claims.Sub) + user, err := m.state.DB.GetUserByExternalID(ctx, claims.Sub) if err == nil { return user, nil } @@ -280,7 +280,7 @@ func (m *Module) fetchUserForClaims(ctx context.Context, claims *oidc.Claims, ip return nil, nil } // fallback to email if we want to link existing users - user, err = m.db.GetUserByEmailAddress(ctx, claims.Email) + user, err = m.state.DB.GetUserByEmailAddress(ctx, claims.Email) if err == db.ErrNoEntries { return nil, nil } else if err != nil { @@ -290,7 +290,7 @@ func (m *Module) fetchUserForClaims(ctx context.Context, claims *oidc.Claims, ip // at this point we have found a matching user but still need to link the newly received external ID user.ExternalID = claims.Sub - err = m.db.UpdateUser(ctx, user, "external_id") + err = m.state.DB.UpdateUser(ctx, user, "external_id") if err != nil { err := fmt.Errorf("error linking existing user %s: %s", claims.Email, err) return nil, gtserror.NewErrorInternalError(err) @@ -300,7 +300,7 @@ func (m *Module) fetchUserForClaims(ctx context.Context, claims *oidc.Claims, ip func (m *Module) createUserFromOIDC(ctx context.Context, claims *oidc.Claims, extraInfo *extraInfo, ip net.IP, appID string) (*gtsmodel.User, gtserror.WithCode) { // Check if the claimed email address is available for use. - emailAvailable, err := m.db.IsEmailAvailable(ctx, claims.Email) + emailAvailable, err := m.state.DB.IsEmailAvailable(ctx, claims.Email) if err != nil { err := gtserror.Newf("db error checking email availability: %w", err) return nil, gtserror.NewErrorInternalError(err) @@ -354,7 +354,7 @@ func (m *Module) createUserFromOIDC(ctx context.Context, claims *oidc.Claims, ex // Create the user! This will also create an account and // store it in the database, so we don't need to do that. - user, err := m.db.NewSignup(ctx, gtsmodel.NewSignup{ + user, err := m.state.DB.NewSignup(ctx, gtsmodel.NewSignup{ Username: extraInfo.Username, Email: claims.Email, Password: password, diff --git a/internal/api/auth/oob.go b/internal/api/auth/oob.go index 8c7b1f2a5..c723a1cb5 100644 --- a/internal/api/auth/oob.go +++ b/internal/api/auth/oob.go @@ -18,97 +18,56 @@ package auth import ( - "context" "errors" - "fmt" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" - apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -func (m *Module) OobHandler(c *gin.Context) { - instance, errWithCode := m.processor.InstanceGetV1(c.Request.Context()) - if errWithCode != nil { - apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) - return - } - - instanceGet := func(ctx context.Context) (*apimodel.InstanceV1, gtserror.WithCode) { - return instance, nil - } +// OOBTokenGETHandler parses the OAuth code from the query +// params and serves a nice little HTML page showing the code. +func (m *Module) OOBTokenGETHandler(c *gin.Context) { + s := sessions.Default(c) oobToken := c.Query("code") if oobToken == "" { - err := errors.New("no 'code' query value provided in callback redirect") - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error(), oauth.HelpfulAdvice), instanceGet) + const errText = "no 'code' query value provided in callback redirect" + m.clearSessionWithBadRequest(c, s, errors.New(errText), errText) return } - s := sessions.Default(c) - - errs := []string{} - - scope, ok := s.Get(sessionScope).(string) - if !ok { - errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionScope)) - } - - userID, ok := s.Get(sessionUserID).(string) - if !ok { - errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionUserID)) - } - - if len(errs) != 0 { - errs = append(errs, oauth.HelpfulAdvice) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(errors.New("one or more missing keys on session during OobHandler"), errs...), m.processor.InstanceGetV1) + user := m.mustUserFromSession(c, s) + if user == nil { + // Error already + // written. return } - user, err := m.db.GetUserByID(c.Request.Context(), userID) - if err != nil { - m.clearSession(s) - safe := fmt.Sprintf("user with id %s could not be retrieved", userID) - var errWithCode gtserror.WithCode - if err == db.ErrNoEntries { - errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice) - } else { - errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) - } - apiutil.ErrorHandler(c, errWithCode, instanceGet) + scope := m.mustStringFromSession(c, s, sessionScope) + if scope == "" { + // Error already + // written. return } - acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) - if err != nil { - m.clearSession(s) - safe := fmt.Sprintf("account with id %s could not be retrieved", user.AccountID) - var errWithCode gtserror.WithCode - if err == db.ErrNoEntries { - errWithCode = gtserror.NewErrorBadRequest(err, safe, oauth.HelpfulAdvice) - } else { - errWithCode = gtserror.NewErrorInternalError(err, safe, oauth.HelpfulAdvice) - } - apiutil.ErrorHandler(c, errWithCode, instanceGet) + // We're done with + // the session now. + m.mustClearSession(s) + + instance, errWithCode := m.processor.InstanceGetV1(c.Request.Context()) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } - // we're done with the session now, so just clear it out - m.clearSession(s) - - page := apiutil.WebPage{ + apiutil.TemplateWebPage(c, apiutil.WebPage{ Template: "oob.tmpl", Instance: instance, Extra: map[string]any{ - "user": acct.Username, + "user": user.Account.Username, "oobToken": oobToken, "scope": scope, }, - } - - apiutil.TemplateWebPage(c, page) + }) } diff --git a/internal/api/auth/signin.go b/internal/api/auth/signin.go index a8713d05f..2820255db 100644 --- a/internal/api/auth/signin.go +++ b/internal/api/auth/signin.go @@ -22,104 +22,143 @@ import ( "errors" "fmt" "net/http" + "slices" + "strings" + "codeberg.org/gruf/go-byteutil" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/pquerna/otp/totp" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/oauth" "golang.org/x/crypto/bcrypt" ) -// signIn just wraps a form-submitted username (we want an email) and password -type signIn struct { - Email string `form:"username"` - Password string `form:"password"` -} - -// SignInGETHandler should be served at https://example.org/auth/sign_in. -// The idea is to present a sign in page to the user, where they can enter their username and password. -// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler. -// If an idp provider is set, then the user will be redirected to that to do their sign in. +// SignInGETHandler should be served at +// GET https://example.org/auth/sign_in. +// +// The idea is to present a friendly sign-in +// page to the user, where they can enter their +// username and password. +// +// When submitted, the form will POST to the sign- +// in page, which will be handled by SignInPOSTHandler. +// +// If an idp provider is set, then the user will +// be redirected to that to do their sign in. func (m *Module) SignInGETHandler(c *gin.Context) { if _, err := apiutil.NegotiateAccept(c, apiutil.HTMLAcceptHeaders...); err != nil { apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) return } - if !config.GetOIDCEnabled() { - instance, errWithCode := m.processor.InstanceGetV1(c.Request.Context()) - if errWithCode != nil { - apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + if config.GetOIDCEnabled() { + // IDP provider is in use, so redirect to it + // instead of serving our own sign in page. + // + // We need the internal state to know where + // to redirect to. + internalState := m.mustStringFromSession( + c, + sessions.Default(c), + sessionInternalState, + ) + if internalState == "" { + // Error already + // written. return } - page := apiutil.WebPage{ - Template: "sign-in.tmpl", - Instance: instance, - } - - apiutil.TemplateWebPage(c, page) + c.Redirect(http.StatusSeeOther, m.idp.AuthCodeURL(internalState)) return } - // idp provider is in use, so redirect to it - s := sessions.Default(c) - - internalStateI := s.Get(sessionInternalState) - internalState, ok := internalStateI.(string) - if !ok { - m.clearSession(s) - err := fmt.Errorf("key %s was not found in session", sessionInternalState) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + // IDP provider is not in use. + // Render our own cute little page. + instance, errWithCode := m.processor.InstanceGetV1(c.Request.Context()) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } - c.Redirect(http.StatusSeeOther, m.idp.AuthCodeURL(internalState)) + apiutil.TemplateWebPage(c, apiutil.WebPage{ + Template: "sign-in.tmpl", + Instance: instance, + }) } -// SignInPOSTHandler should be served at https://example.org/auth/sign_in. -// The idea is to present a sign in page to the user, where they can enter their username and password. -// The handler will then redirect to the auth handler served at /auth +// SignInPOSTHandler should be served at +// POST https://example.org/auth/sign_in. +// +// The handler will check the submitted credentials, +// then redirect either to the 2fa form, or straight +// to the authorize page served at /oauth/authorize. func (m *Module) SignInPOSTHandler(c *gin.Context) { s := sessions.Default(c) - form := &signIn{} + // Parse email + password. + form := &struct { + Email string `form:"username" validate:"required"` + Password string `form:"password" validate:"required"` + }{} if err := c.ShouldBind(form); err != nil { - m.clearSession(s) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGetV1) + m.clearSessionWithBadRequest(c, s, err, oauth.HelpfulAdvice) return } - userid, errWithCode := m.ValidatePassword(c.Request.Context(), form.Email, form.Password) + user, errWithCode := m.validatePassword( + c.Request.Context(), + form.Email, + form.Password, + ) if errWithCode != nil { - // don't clear session here, so the user can just press back and try again - // if they accidentally gave the wrong password or something + // Don't clear session here yet, so the user + // can just press back and try again if they + // accidentally gave the wrong password, without + // having to do the whole sign in flow again! apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } - s.Set(sessionUserID, userid) - if err := s.Save(); err != nil { - err := fmt.Errorf("error saving user id onto session: %s", err) - apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err, oauth.HelpfulAdvice), m.processor.InstanceGetV1) + // Whether or not 2fa is enabled, we want + // to save the session when we're done here. + defer m.mustSaveSession(s) + + if user.TwoFactorEnabled() { + // If this user has 2FA enabled, redirect + // to the 2FA page and have them submit + // a code from their authenticator app. + s.Set(sessionUserIDAwaiting2FA, user.ID) + c.Redirect(http.StatusFound, "/auth"+Auth2FAPath) + return } + // If the user doesn't have 2fa enabled, + // redirect straight to the OAuth authorize page. + s.Set(sessionUserID, user.ID) c.Redirect(http.StatusFound, "/oauth"+OauthAuthorizePath) } -// ValidatePassword takes an email address and a password. -// The goal is to authenticate the password against the one for that email -// address stored in the database. If OK, we return the userid (a ulid) for that user, -// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db. -func (m *Module) ValidatePassword(ctx context.Context, email string, password string) (string, gtserror.WithCode) { +// validatePassword takes an email address and a password. +// The func authenticates the password against the one for +// that email address stored in the database. +// +// If OK, it returns the user, so that it can be used in +// further OAuth flows to generate a token etc. +func (m *Module) validatePassword( + ctx context.Context, + email string, + password string, +) (*gtsmodel.User, gtserror.WithCode) { if email == "" || password == "" { err := errors.New("email or password was not provided") return incorrectPassword(err) } - user, err := m.db.GetUserByEmailAddress(ctx, email) + user, err := m.state.DB.GetUserByEmailAddress(ctx, email) if err != nil { err := fmt.Errorf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err) return incorrectPassword(err) @@ -130,17 +169,141 @@ func (m *Module) ValidatePassword(ctx context.Context, email string, password st return incorrectPassword(err) } - if err := bcrypt.CompareHashAndPassword([]byte(user.EncryptedPassword), []byte(password)); err != nil { + if err := bcrypt.CompareHashAndPassword( + byteutil.S2B(user.EncryptedPassword), + byteutil.S2B(password), + ); err != nil { err := fmt.Errorf("password hash didn't match for user %s during sign in attempt: %s", user.Email, err) return incorrectPassword(err) } - return user.ID, nil + return user, nil } // incorrectPassword wraps the given error in a gtserror.WithCode, and returns // only a generic 'safe' error message to the user, to not give any info away. -func incorrectPassword(err error) (string, gtserror.WithCode) { - safeErr := fmt.Errorf("password/email combination was incorrect") - return "", gtserror.NewErrorUnauthorized(err, safeErr.Error(), oauth.HelpfulAdvice) +func incorrectPassword(err error) (*gtsmodel.User, gtserror.WithCode) { + const errText = "password/email combination was incorrect" + return nil, gtserror.NewErrorUnauthorized(err, errText, oauth.HelpfulAdvice) +} + +// TwoFactorCodeGETHandler should be served at +// GET https://example.org/auth/2fa. +// +// The 2fa template displays a simple form asking the +// user to input a code from their authenticator app. +func (m *Module) TwoFactorCodeGETHandler(c *gin.Context) { + s := sessions.Default(c) + + user := m.mustUserFromSession(c, s) + if user == nil { + // Error already + // written. + return + } + + instance, errWithCode := m.processor.InstanceGetV1(c.Request.Context()) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + apiutil.TemplateWebPage(c, apiutil.WebPage{ + Template: "2fa.tmpl", + Instance: instance, + Extra: map[string]any{ + "user": user.Account.Username, + }, + }) +} + +// TwoFactorCodePOSTHandler should be served at +// POST https://example.org/auth/2fa. +// +// The idea is to handle a submitted 2fa code, validate it, +// and if valid redirect to the /oauth/authorize page that +// the user would get to if they didn't have 2fa enabled. +func (m *Module) TwoFactorCodePOSTHandler(c *gin.Context) { + s := sessions.Default(c) + + user := m.mustUserFromSession(c, s) + if user == nil { + // Error already + // written. + return + } + + // Parse 2fa code. + form := &struct { + Code string `form:"code" validate:"required"` + }{} + if err := c.ShouldBind(form); err != nil { + m.clearSessionWithBadRequest(c, s, err, oauth.HelpfulAdvice) + return + } + + valid, err := m.validate2FACode(c, user, form.Code) + if err != nil { + m.clearSessionWithInternalError(c, s, err, oauth.HelpfulAdvice) + return + } + + if !valid { + // Don't clear session here yet, so the user + // can just press back and try again if they + // accidentally gave the wrong code, without + // having to do the whole sign in flow again! + const errText = "2fa code invalid or timed out, press back and try again; " + + "if issues persist, pester your instance admin to check the server clock" + errWithCode := gtserror.NewErrorBadRequest(errors.New(errText), errText) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + // Code looks good! Redirect + // to the OAuth authorize page. + s.Set(sessionUserID, user.ID) + m.mustSaveSession(s) + c.Redirect(http.StatusFound, "/oauth"+OauthAuthorizePath) +} + +func (m *Module) validate2FACode(c *gin.Context, user *gtsmodel.User, code string) (bool, error) { + code = strings.TrimSpace(code) + if len(code) <= 6 { + // This is a normal authenticator + // app code, just try to validate it. + return totp.Validate(code, user.TwoFactorSecret), nil + } + + // This is a one-time recovery code. + // Check against the user's stored codes. + for i := 0; i < len(user.TwoFactorBackups); i++ { + err := bcrypt.CompareHashAndPassword( + byteutil.S2B(user.TwoFactorBackups[i]), + byteutil.S2B(code), + ) + if err != nil { + // Doesn't match, + // try next. + continue + } + + // We have a match. + // Remove this one-time code from the user's backups. + user.TwoFactorBackups = slices.Delete(user.TwoFactorBackups, i, i+1) + if err := m.state.DB.UpdateUser( + c.Request.Context(), + user, + "two_factor_backups", + ); err != nil { + return false, err + } + + // So valid bestie! + return true, nil + } + + // Not a valid one-time + // recovery code. + return false, nil } diff --git a/internal/api/auth/util.go b/internal/api/auth/util.go new file mode 100644 index 000000000..f1aed0bc3 --- /dev/null +++ b/internal/api/auth/util.go @@ -0,0 +1,152 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package auth + +import ( + "errors" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +func (m *Module) mustClearSession(s sessions.Session) { + s.Clear() + m.mustSaveSession(s) +} + +func (m *Module) mustSaveSession(s sessions.Session) { + if err := s.Save(); err != nil { + panic(err) + } +} + +// mustUserFromSession returns a *gtsmodel.User by checking the +// session for a user id and fetching the user from the database. +// +// On failure, the function clears session state, writes an internal +// error to the response writer, and returns nil. Callers should always +// return immediately if receiving nil back from this function! +func (m *Module) mustUserFromSession( + c *gin.Context, + s sessions.Session, +) *gtsmodel.User { + // Try "userid" key first, fall + // back to "userid_awaiting_2fa". + var userID string + for _, key := range [2]string{ + sessionUserID, + sessionUserIDAwaiting2FA, + } { + var ok bool + userID, ok = s.Get(key).(string) + if ok && userID != "" { + // Got it. + break + } + } + + if userID == "" { + const safe = "neither userid nor userid_awaiting_2fa keys found in session" + m.clearSessionWithInternalError(c, s, errors.New(safe), safe, oauth.HelpfulAdvice) + return nil + } + + user, err := m.state.DB.GetUserByID(c.Request.Context(), userID) + if err != nil { + safe := "db error getting user " + userID + m.clearSessionWithInternalError(c, s, err, safe, oauth.HelpfulAdvice) + return nil + } + + return user +} + +// mustAppFromSession returns a *gtsmodel.Application by checking the +// session for an application keyid and fetching the app from the database. +// +// On failure, the function clears session state, writes an internal +// error to the response writer, and returns nil. Callers should always +// return immediately if receiving nil back from this function! +func (m *Module) mustAppFromSession( + c *gin.Context, + s sessions.Session, +) *gtsmodel.Application { + clientID, ok := s.Get(sessionClientID).(string) + if !ok { + const safe = "key client_id not found in session" + m.clearSessionWithInternalError(c, s, errors.New(safe), safe, oauth.HelpfulAdvice) + return nil + } + + app, err := m.state.DB.GetApplicationByClientID(c.Request.Context(), clientID) + if err != nil { + safe := "db error getting app for clientID " + clientID + m.clearSessionWithInternalError(c, s, err, safe, oauth.HelpfulAdvice) + return nil + } + + return app +} + +// mustStringFromSession returns the string value +// corresponding to the given session key, if any is set. +// +// On failure (nothing set), the function clears session +// state, writes an internal error to the response writer, +// and returns nil. Callers should always return immediately +// if receiving nil back from this function! +func (m *Module) mustStringFromSession( + c *gin.Context, + s sessions.Session, + key string, +) string { + v, ok := s.Get(key).(string) + if !ok { + safe := "key " + key + " not found in session" + m.clearSessionWithInternalError(c, s, errors.New(safe), safe, oauth.HelpfulAdvice) + return "" + } + + return v +} + +func (m *Module) clearSessionWithInternalError( + c *gin.Context, + s sessions.Session, + err error, + helpText ...string, +) { + m.mustClearSession(s) + errWithCode := gtserror.NewErrorInternalError(err, helpText...) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) +} + +func (m *Module) clearSessionWithBadRequest( + c *gin.Context, + s sessions.Session, + err error, + helpText ...string, +) { + m.mustClearSession(s) + errWithCode := gtserror.NewErrorBadRequest(err, helpText...) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) +} -- cgit v1.2.3