diff options
author | 2021-08-25 15:34:33 +0200 | |
---|---|---|
committer | 2021-08-25 15:34:33 +0200 | |
commit | 2dc9fc1626507bb54417fc4a1920b847cafb27a2 (patch) | |
tree | 4ddeac479b923db38090aac8bd9209f3646851c1 /internal/api/client/auth/callback.go | |
parent | Manually approves followers (#146) (diff) | |
download | gotosocial-2dc9fc1626507bb54417fc4a1920b847cafb27a2.tar.xz |
Pg to bun (#148)
* start moving to bun
* changing more stuff
* more
* and yet more
* tests passing
* seems stable now
* more big changes
* small fix
* little fixes
Diffstat (limited to 'internal/api/client/auth/callback.go')
-rw-r--r-- | internal/api/client/auth/callback.go | 30 |
1 files changed, 17 insertions, 13 deletions
diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go index a26838aa3..cbb429352 100644 --- a/internal/api/client/auth/callback.go +++ b/internal/api/client/auth/callback.go @@ -19,6 +19,7 @@ package auth import ( + "context" "errors" "fmt" "net" @@ -80,13 +81,13 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { app := >smodel.Application{ ClientID: clientID, } - if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil { + if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil { m.clearSession(s) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)}) return } - user, err := m.parseUserFromClaims(claims, net.IP(c.ClientIP()), app.ID) + user, err := m.parseUserFromClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID) if err != nil { m.clearSession(s) c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) @@ -103,14 +104,14 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { c.Redirect(http.StatusFound, OauthAuthorizePath) } -func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) { +func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) { if claims.Email == "" { return nil, errors.New("no email returned in claims") } // see if we already have a user for this email address user := >smodel.User{} - err := m.db.GetWhere([]db.Where{{Key: "email", Value: claims.Email}}, user) + err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user) if err == nil { // we do! so we can just return it return user, nil @@ -122,7 +123,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin } // maybe we have an unconfirmed user - err = m.db.GetWhere([]db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user) + err = m.db.GetWhere(ctx, []db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user) if err == nil { // user is unconfirmed so return an error return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email) @@ -137,9 +138,13 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin // however, because we trust the OIDC provider, we should now create a user + account with the provided claims // check if the email address is available for use; if it's not there's nothing we can so - if err := m.db.IsEmailAvailable(claims.Email); err != nil { + emailAvailable, err := m.db.IsEmailAvailable(ctx, claims.Email) + if err != nil { return nil, fmt.Errorf("email %s not available: %s", claims.Email, err) } + if !emailAvailable { + return nil, fmt.Errorf("email %s in use", claims.Email) + } // now we need a username var username string @@ -180,12 +185,11 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin // note that for the first iteration, iString is still "" when the check is made, so our first choice // is still the raw username with no integer stuck on the end for i := 1; !found; i = i + 1 { - if err := m.db.IsUsernameAvailable(username + iString); err != nil { - if strings.Contains(err.Error(), "db error") { - // if there's an actual db error we should return - return nil, fmt.Errorf("error checking username availability: %s", err) - } - } else { + usernameAvailable, err := m.db.IsUsernameAvailable(ctx, username+iString) + if err != nil { + return nil, err + } + if usernameAvailable { // no error so we've found a username that works found = true username = username + iString @@ -209,7 +213,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin password := uuid.NewString() + uuid.NewString() // create the user! this will also create an account and store it in the database so we don't need to do that here - user, err = m.db.NewSignup(username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin) + user, err = m.db.NewSignup(ctx, username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin) if err != nil { return nil, fmt.Errorf("error creating user: %s", err) } |