diff options
Diffstat (limited to 'internal/api/auth/callback.go')
| -rw-r--r-- | internal/api/auth/callback.go | 40 |
1 files changed, 20 insertions, 20 deletions
diff --git a/internal/api/auth/callback.go b/internal/api/auth/callback.go index 2dc36fac8..5003910e9 100644 --- a/internal/api/auth/callback.go +++ b/internal/api/auth/callback.go @@ -60,7 +60,7 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { returnedInternalState := c.Query(callbackStateParam) if returnedInternalState == "" { - m.clearSession(s) + m.mustClearSession(s) err := fmt.Errorf("%s parameter not found on callback query", callbackStateParam) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) return @@ -69,14 +69,14 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { savedInternalStateI := s.Get(sessionInternalState) savedInternalState, ok := savedInternalStateI.(string) if !ok { - m.clearSession(s) + m.mustClearSession(s) err := fmt.Errorf("key %s was not found in session", sessionInternalState) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) return } if returnedInternalState != savedInternalState { - m.clearSession(s) + m.mustClearSession(s) err := errors.New("mismatch between callback state and saved state") apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) return @@ -85,7 +85,7 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { // retrieve stored claims using code code := c.Query(callbackCodeParam) if code == "" { - m.clearSession(s) + m.mustClearSession(s) err := fmt.Errorf("%s parameter not found on callback query", callbackCodeParam) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) return @@ -93,7 +93,7 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { claims, errWithCode := m.idp.HandleCallback(c.Request.Context(), code) if errWithCode != nil { - m.clearSession(s) + m.mustClearSession(s) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } @@ -102,15 +102,15 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { // info about the app associated with the client_id clientID, ok := s.Get(sessionClientID).(string) if !ok || clientID == "" { - m.clearSession(s) + m.mustClearSession(s) err := fmt.Errorf("key %s was not found in session", sessionClientID) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGetV1) return } - app, err := m.db.GetApplicationByClientID(c.Request.Context(), clientID) + app, err := m.state.DB.GetApplicationByClientID(c.Request.Context(), clientID) if err != nil { - m.clearSession(s) + m.mustClearSession(s) safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID) var errWithCode gtserror.WithCode if err == db.ErrNoEntries { @@ -124,7 +124,7 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { user, errWithCode := m.fetchUserForClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID) if errWithCode != nil { - m.clearSession(s) + m.mustClearSession(s) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } @@ -140,7 +140,7 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { s.Set(sessionClaims, claims) s.Set(sessionAppID, app.ID) if err := s.Save(); err != nil { - m.clearSession(s) + m.mustClearSession(s) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGetV1) return } @@ -173,7 +173,7 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { s.Set(sessionUserID, user.ID) if err := s.Save(); err != nil { - m.clearSession(s) + m.mustClearSession(s) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGetV1) return } @@ -186,7 +186,7 @@ func (m *Module) FinalizePOSTHandler(c *gin.Context) { form := &extraInfo{} if err := c.ShouldBind(form); err != nil { - m.clearSession(s) + m.mustClearSession(s) apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGetV1) return } @@ -219,7 +219,7 @@ func (m *Module) FinalizePOSTHandler(c *gin.Context) { } // see if the username is still available - usernameAvailable, err := m.db.IsUsernameAvailable(c.Request.Context(), form.Username) + usernameAvailable, err := m.state.DB.IsUsernameAvailable(c.Request.Context(), form.Username) if err != nil { apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGetV1) return @@ -248,7 +248,7 @@ func (m *Module) FinalizePOSTHandler(c *gin.Context) { // we're now ready to actually create the user user, errWithCode := m.createUserFromOIDC(c.Request.Context(), claims, form, net.IP(c.ClientIP()), appID) if errWithCode != nil { - m.clearSession(s) + m.mustClearSession(s) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } @@ -256,7 +256,7 @@ func (m *Module) FinalizePOSTHandler(c *gin.Context) { s.Delete(sessionAppID) s.Set(sessionUserID, user.ID) if err := s.Save(); err != nil { - m.clearSession(s) + m.mustClearSession(s) apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGetV1) return } @@ -268,7 +268,7 @@ func (m *Module) fetchUserForClaims(ctx context.Context, claims *oidc.Claims, ip err := errors.New("no sub claim found - is your provider OIDC compliant?") return nil, gtserror.NewErrorBadRequest(err, err.Error()) } - user, err := m.db.GetUserByExternalID(ctx, claims.Sub) + user, err := m.state.DB.GetUserByExternalID(ctx, claims.Sub) if err == nil { return user, nil } @@ -280,7 +280,7 @@ func (m *Module) fetchUserForClaims(ctx context.Context, claims *oidc.Claims, ip return nil, nil } // fallback to email if we want to link existing users - user, err = m.db.GetUserByEmailAddress(ctx, claims.Email) + user, err = m.state.DB.GetUserByEmailAddress(ctx, claims.Email) if err == db.ErrNoEntries { return nil, nil } else if err != nil { @@ -290,7 +290,7 @@ func (m *Module) fetchUserForClaims(ctx context.Context, claims *oidc.Claims, ip // at this point we have found a matching user but still need to link the newly received external ID user.ExternalID = claims.Sub - err = m.db.UpdateUser(ctx, user, "external_id") + err = m.state.DB.UpdateUser(ctx, user, "external_id") if err != nil { err := fmt.Errorf("error linking existing user %s: %s", claims.Email, err) return nil, gtserror.NewErrorInternalError(err) @@ -300,7 +300,7 @@ func (m *Module) fetchUserForClaims(ctx context.Context, claims *oidc.Claims, ip func (m *Module) createUserFromOIDC(ctx context.Context, claims *oidc.Claims, extraInfo *extraInfo, ip net.IP, appID string) (*gtsmodel.User, gtserror.WithCode) { // Check if the claimed email address is available for use. - emailAvailable, err := m.db.IsEmailAvailable(ctx, claims.Email) + emailAvailable, err := m.state.DB.IsEmailAvailable(ctx, claims.Email) if err != nil { err := gtserror.Newf("db error checking email availability: %w", err) return nil, gtserror.NewErrorInternalError(err) @@ -354,7 +354,7 @@ func (m *Module) createUserFromOIDC(ctx context.Context, claims *oidc.Claims, ex // Create the user! This will also create an account and // store it in the database, so we don't need to do that. - user, err := m.db.NewSignup(ctx, gtsmodel.NewSignup{ + user, err := m.state.DB.NewSignup(ctx, gtsmodel.NewSignup{ Username: extraInfo.Username, Email: claims.Email, Password: password, |
