diff options
Diffstat (limited to 'internal/api/client/auth/callback.go')
-rw-r--r-- | internal/api/client/auth/callback.go | 206 |
1 files changed, 131 insertions, 75 deletions
diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go index cf2c906a5..c97abf7aa 100644 --- a/internal/api/client/auth/callback.go +++ b/internal/api/client/auth/callback.go @@ -24,13 +24,13 @@ import ( "fmt" "net" "net/http" - "strconv" "strings" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/superseriousbusiness/gotosocial/internal/api" + "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -39,6 +39,12 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/validate" ) +// extraInfo wraps a form-submitted username and transmitted name +type extraInfo struct { + Username string `form:"username"` + Name string `form:"name"` // note that this is only used for re-rendering the page in case of an error +} + // CallbackGETHandler parses a token from an external auth provider. func (m *Module) CallbackGETHandler(c *gin.Context) { s := sessions.Default(c) @@ -110,115 +116,165 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { return } - user, errWithCode := m.parseUserFromClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID) + user, errWithCode := m.fetchUserForClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID) if errWithCode != nil { m.clearSession(s) api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) return } + if user == nil { + // no user exists yet - let's ask them for their preferred username + instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost()) + if errWithCode != nil { + api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + // store the claims in the session - that way we know the user is authenticated when processing the form later + s.Set(sessionClaims, claims) + s.Set(sessionAppID, app.ID) + if err := s.Save(); err != nil { + m.clearSession(s) + api.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet) + return + } + c.HTML(http.StatusOK, "finalize.tmpl", gin.H{ + "instance": instance, + "name": claims.Name, + "preferredUsername": claims.PreferredUsername, + }) + return + } s.Set(sessionUserID, user.ID) if err := s.Save(); err != nil { m.clearSession(s) api.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet) return } + c.Redirect(http.StatusFound, OauthAuthorizePath) +} + +// FinalizePOSTHandler registers the user after additional data has been provided +func (m *Module) FinalizePOSTHandler(c *gin.Context) { + s := sessions.Default(c) + form := &extraInfo{} + if err := c.ShouldBind(form); err != nil { + m.clearSession(s) + api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) + return + } + + // since we have multiple possible validation error, `validationError` is a shorthand for rendering them + validationError := func(err error) { + instance, errWithCode := m.processor.InstanceGet(c.Request.Context(), config.GetHost()) + if errWithCode != nil { + api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + c.HTML(http.StatusOK, "finalize.tmpl", gin.H{ + "instance": instance, + "name": form.Name, + "preferredUsername": form.Username, + "error": err, + }) + } + + // check if the username conforms to the spec + if err := validate.Username(form.Username); err != nil { + validationError(err) + return + } + + // see if the username is still available + usernameAvailable, err := m.db.IsUsernameAvailable(c.Request.Context(), form.Username) + if err != nil { + api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) + return + } + if !usernameAvailable { + validationError(fmt.Errorf("Username %s is already taken", form.Username)) + return + } + + // retrieve the information previously set by the oidc logic + appID, ok := s.Get(sessionAppID).(string) + if !ok { + err := fmt.Errorf("key %s was not found in session", sessionAppID) + api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) + return + } + + // retrieve the claims returned by the IDP. Having this present means that we previously already verified these claims + claims, ok := s.Get(sessionClaims).(*oidc.Claims) + if !ok { + err := fmt.Errorf("key %s was not found in session", sessionClaims) + api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, oauth.HelpfulAdvice), m.processor.InstanceGet) + return + } + + // we're now ready to actually create the user + user, errWithCode := m.createUserFromOIDC(c.Request.Context(), claims, form, net.IP(c.ClientIP()), appID) + if errWithCode != nil { + m.clearSession(s) + api.ErrorHandler(c, errWithCode, m.processor.InstanceGet) + return + } + s.Delete(sessionClaims) + s.Delete(sessionAppID) + s.Set(sessionUserID, user.ID) + if err := s.Save(); err != nil { + m.clearSession(s) + api.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGet) + return + } c.Redirect(http.StatusFound, OauthAuthorizePath) } -func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, gtserror.WithCode) { - if claims.Email == "" { - err := errors.New("no email returned in claims") +func (m *Module) fetchUserForClaims(ctx context.Context, claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, gtserror.WithCode) { + if claims.Sub == "" { + err := errors.New("no sub claim found - is your provider OIDC compliant?") return nil, gtserror.NewErrorBadRequest(err, err.Error()) } - - // see if we already have a user for this email address - // if so, we don't need to continue + create one - user, err := m.db.GetUserByEmailAddress(ctx, claims.Email) + user, err := m.db.GetUserByExternalID(ctx, claims.Sub) if err == nil { return user, nil } - if err != db.ErrNoEntries { - err := fmt.Errorf("error checking database for email %s: %s", claims.Email, err) + err := fmt.Errorf("error checking database for externalID %s: %s", claims.Sub, err) return nil, gtserror.NewErrorInternalError(err) } - - // maybe we have an unconfirmed user - err = m.db.GetWhere(ctx, []db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user) - if err == nil { - err := fmt.Errorf("user with email address %s is unconfirmed", claims.Email) - return nil, gtserror.NewErrorForbidden(err, err.Error()) + if !config.GetOIDCLinkExisting() { + return nil, nil } - - if err != db.ErrNoEntries { + // fallback to email if we want to link existing users + user, err = m.db.GetUserByEmailAddress(ctx, claims.Email) + if err == db.ErrNoEntries { + return nil, nil + } else if err != nil { err := fmt.Errorf("error checking database for email %s: %s", claims.Email, err) return nil, gtserror.NewErrorInternalError(err) } + // at this point we have found a matching user but still need to link the newly received external ID - // we don't have a confirmed or unconfirmed user with the claimed email address - // however, because we trust the OIDC provider, we should now create a user + account with the provided claims + user.ExternalID = claims.Sub + err = m.db.UpdateUser(ctx, user, "external_id") + if err != nil { + err := fmt.Errorf("error linking existing user %s: %s", claims.Email, err) + return nil, gtserror.NewErrorInternalError(err) + } + return user, nil +} +func (m *Module) createUserFromOIDC(ctx context.Context, claims *oidc.Claims, extraInfo *extraInfo, ip net.IP, appID string) (*gtsmodel.User, gtserror.WithCode) { // check if the email address is available for use; if it's not there's nothing we can so emailAvailable, err := m.db.IsEmailAvailable(ctx, claims.Email) if err != nil { return nil, gtserror.NewErrorBadRequest(err) } if !emailAvailable { - return nil, gtserror.NewErrorConflict(fmt.Errorf("email address %s is not available", claims.Email)) - } - - // now we need a username - var username string - - // make sure claims.Name is defined since we'll be using that for the username - if claims.Name == "" { - err := errors.New("no name returned in claims") - return nil, gtserror.NewErrorBadRequest(err, err.Error()) - } - - // check if we can just use claims.Name as-is - if err = validate.Username(claims.Name); err == nil { - // the name we have on the claims is already a valid username - username = claims.Name - } else { - // not a valid username so we have to fiddle with it to try to make it valid - // first trim leading and trailing whitespace - trimmed := strings.TrimSpace(claims.Name) - // underscore any spaces in the middle of the name - underscored := strings.ReplaceAll(trimmed, " ", "_") - // lowercase the whole thing - lower := strings.ToLower(underscored) - // see if this is valid.... - if err := validate.Username(lower); err != nil { - err := fmt.Errorf("couldn't parse a valid username from claims.Name value of %s: %s", claims.Name, err) - return nil, gtserror.NewErrorBadRequest(err, err.Error()) - } - // we managed to get a valid username - username = lower - } - - var iString string - var found bool - // if the username isn't available we need to iterate on it until we find one that is - // we should try to do this in a predictable way so we just keep iterating i by one and trying - // the username with that number on the end - // - // note that for the first iteration, iString is still "" when the check is made, so our first choice - // is still the raw username with no integer stuck on the end - for i := 1; !found; i++ { - usernameAvailable, err := m.db.IsUsernameAvailable(ctx, username+iString) - if err != nil { - return nil, gtserror.NewErrorInternalError(err) - } - if usernameAvailable { - // no error so we've found a username that works - found = true - username += iString - continue - } - iString = strconv.Itoa(i) + help := "The email address given to us by your authentication provider already exists in our records and the server administrator has not enabled account migration" + return nil, gtserror.NewErrorConflict(fmt.Errorf("email address %s is not available", claims.Email), help) } // check if the user is in any recognised admin groups @@ -246,7 +302,7 @@ func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, i emailVerified := true // create the user! this will also create an account and store it in the database so we don't need to do that here - user, err = m.db.NewSignup(ctx, username, "", requireApproval, claims.Email, password, ip, "", appID, emailVerified, admin) + user, err := m.db.NewSignup(ctx, extraInfo.Username, "", requireApproval, claims.Email, password, ip, "", appID, emailVerified, claims.Sub, admin) if err != nil { return nil, gtserror.NewErrorInternalError(err) } |