summaryrefslogtreecommitdiff
path: root/internal/api/client/auth/callback.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api/client/auth/callback.go')
-rw-r--r--internal/api/client/auth/callback.go30
1 files changed, 17 insertions, 13 deletions
diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go
index a26838aa3..cbb429352 100644
--- a/internal/api/client/auth/callback.go
+++ b/internal/api/client/auth/callback.go
@@ -19,6 +19,7 @@
package auth
import (
+ "context"
"errors"
"fmt"
"net"
@@ -80,13 +81,13 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
app := &gtsmodel.Application{
ClientID: clientID,
}
- if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
+ if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)})
return
}
- user, err := m.parseUserFromClaims(claims, net.IP(c.ClientIP()), app.ID)
+ user, err := m.parseUserFromClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID)
if err != nil {
m.clearSession(s)
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
@@ -103,14 +104,14 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
c.Redirect(http.StatusFound, OauthAuthorizePath)
}
-func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) {
+func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) {
if claims.Email == "" {
return nil, errors.New("no email returned in claims")
}
// see if we already have a user for this email address
user := &gtsmodel.User{}
- err := m.db.GetWhere([]db.Where{{Key: "email", Value: claims.Email}}, user)
+ err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user)
if err == nil {
// we do! so we can just return it
return user, nil
@@ -122,7 +123,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
}
// maybe we have an unconfirmed user
- err = m.db.GetWhere([]db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user)
+ err = m.db.GetWhere(ctx, []db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user)
if err == nil {
// user is unconfirmed so return an error
return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email)
@@ -137,9 +138,13 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
// however, because we trust the OIDC provider, we should now create a user + account with the provided claims
// check if the email address is available for use; if it's not there's nothing we can so
- if err := m.db.IsEmailAvailable(claims.Email); err != nil {
+ emailAvailable, err := m.db.IsEmailAvailable(ctx, claims.Email)
+ if err != nil {
return nil, fmt.Errorf("email %s not available: %s", claims.Email, err)
}
+ if !emailAvailable {
+ return nil, fmt.Errorf("email %s in use", claims.Email)
+ }
// now we need a username
var username string
@@ -180,12 +185,11 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
// note that for the first iteration, iString is still "" when the check is made, so our first choice
// is still the raw username with no integer stuck on the end
for i := 1; !found; i = i + 1 {
- if err := m.db.IsUsernameAvailable(username + iString); err != nil {
- if strings.Contains(err.Error(), "db error") {
- // if there's an actual db error we should return
- return nil, fmt.Errorf("error checking username availability: %s", err)
- }
- } else {
+ usernameAvailable, err := m.db.IsUsernameAvailable(ctx, username+iString)
+ if err != nil {
+ return nil, err
+ }
+ if usernameAvailable {
// no error so we've found a username that works
found = true
username = username + iString
@@ -209,7 +213,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
password := uuid.NewString() + uuid.NewString()
// create the user! this will also create an account and store it in the database so we don't need to do that here
- user, err = m.db.NewSignup(username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin)
+ user, err = m.db.NewSignup(ctx, username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin)
if err != nil {
return nil, fmt.Errorf("error creating user: %s", err)
}