diff options
author | 2021-08-25 15:34:33 +0200 | |
---|---|---|
committer | 2021-08-25 15:34:33 +0200 | |
commit | 2dc9fc1626507bb54417fc4a1920b847cafb27a2 (patch) | |
tree | 4ddeac479b923db38090aac8bd9209f3646851c1 /internal/api/client/auth | |
parent | Manually approves followers (#146) (diff) | |
download | gotosocial-2dc9fc1626507bb54417fc4a1920b847cafb27a2.tar.xz |
Pg to bun (#148)
* start moving to bun
* changing more stuff
* more
* and yet more
* tests passing
* seems stable now
* more big changes
* small fix
* little fixes
Diffstat (limited to 'internal/api/client/auth')
-rw-r--r-- | internal/api/client/auth/auth_test.go | 16 | ||||
-rw-r--r-- | internal/api/client/auth/authorize.go | 19 | ||||
-rw-r--r-- | internal/api/client/auth/callback.go | 30 | ||||
-rw-r--r-- | internal/api/client/auth/middleware.go | 8 | ||||
-rw-r--r-- | internal/api/client/auth/signin.go | 7 |
5 files changed, 39 insertions, 41 deletions
diff --git a/internal/api/client/auth/auth_test.go b/internal/api/client/auth/auth_test.go index 48d2a2508..3d5170f31 100644 --- a/internal/api/client/auth/auth_test.go +++ b/internal/api/client/auth/auth_test.go @@ -28,7 +28,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/db/pg" + "github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/oauth" "golang.org/x/crypto/bcrypt" @@ -104,7 +104,7 @@ func (suite *AuthTestSuite) SetupTest() { log := logrus.New() log.SetLevel(logrus.TraceLevel) - db, err := pg.NewPostgresService(context.Background(), suite.config, log) + db, err := bundb.NewBunDBService(context.Background(), suite.config, log) if err != nil { logrus.Panicf("error creating database connection: %s", err) } @@ -120,23 +120,23 @@ func (suite *AuthTestSuite) SetupTest() { } for _, m := range models { - if err := suite.db.CreateTable(m); err != nil { + if err := suite.db.CreateTable(context.Background(), m); err != nil { logrus.Panicf("db connection error: %s", err) } } suite.oauthServer = oauth.New(suite.db, log) - if err := suite.db.Put(suite.testAccount); err != nil { + if err := suite.db.Put(context.Background(), suite.testAccount); err != nil { logrus.Panicf("could not insert test account into db: %s", err) } - if err := suite.db.Put(suite.testUser); err != nil { + if err := suite.db.Put(context.Background(), suite.testUser); err != nil { logrus.Panicf("could not insert test user into db: %s", err) } - if err := suite.db.Put(suite.testClient); err != nil { + if err := suite.db.Put(context.Background(), suite.testClient); err != nil { logrus.Panicf("could not insert test client into db: %s", err) } - if err := suite.db.Put(suite.testApplication); err != nil { + if err := suite.db.Put(context.Background(), suite.testApplication); err != nil { logrus.Panicf("could not insert test application into db: %s", err) } @@ -152,7 +152,7 @@ func (suite *AuthTestSuite) TearDownTest() { >smodel.Application{}, } for _, m := range models { - if err := suite.db.DropTable(m); err != nil { + if err := suite.db.DropTable(context.Background(), m); err != nil { logrus.Panicf("error dropping table: %s", err) } } diff --git a/internal/api/client/auth/authorize.go b/internal/api/client/auth/authorize.go index a10408723..0328f3b21 100644 --- a/internal/api/client/auth/authorize.go +++ b/internal/api/client/auth/authorize.go @@ -70,30 +70,23 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"}) return } - app := >smodel.Application{ - ClientID: clientID, - } - if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil { + app := >smodel.Application{} + if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil { m.clearSession(s) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)}) return } // we can also use the userid of the user to fetch their username from the db to greet them nicely <3 - user := >smodel.User{ - ID: userID, - } - if err := m.db.GetByID(user.ID, user); err != nil { + user := >smodel.User{} + if err := m.db.GetByID(c.Request.Context(), user.ID, user); err != nil { m.clearSession(s) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - acct := >smodel.Account{ - ID: user.AccountID, - } - - if err := m.db.GetByID(acct.ID, acct); err != nil { + acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) + if err != nil { m.clearSession(s) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go index a26838aa3..cbb429352 100644 --- a/internal/api/client/auth/callback.go +++ b/internal/api/client/auth/callback.go @@ -19,6 +19,7 @@ package auth import ( + "context" "errors" "fmt" "net" @@ -80,13 +81,13 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { app := >smodel.Application{ ClientID: clientID, } - if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil { + if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil { m.clearSession(s) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)}) return } - user, err := m.parseUserFromClaims(claims, net.IP(c.ClientIP()), app.ID) + user, err := m.parseUserFromClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID) if err != nil { m.clearSession(s) c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) @@ -103,14 +104,14 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { c.Redirect(http.StatusFound, OauthAuthorizePath) } -func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) { +func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) { if claims.Email == "" { return nil, errors.New("no email returned in claims") } // see if we already have a user for this email address user := >smodel.User{} - err := m.db.GetWhere([]db.Where{{Key: "email", Value: claims.Email}}, user) + err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user) if err == nil { // we do! so we can just return it return user, nil @@ -122,7 +123,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin } // maybe we have an unconfirmed user - err = m.db.GetWhere([]db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user) + err = m.db.GetWhere(ctx, []db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user) if err == nil { // user is unconfirmed so return an error return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email) @@ -137,9 +138,13 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin // however, because we trust the OIDC provider, we should now create a user + account with the provided claims // check if the email address is available for use; if it's not there's nothing we can so - if err := m.db.IsEmailAvailable(claims.Email); err != nil { + emailAvailable, err := m.db.IsEmailAvailable(ctx, claims.Email) + if err != nil { return nil, fmt.Errorf("email %s not available: %s", claims.Email, err) } + if !emailAvailable { + return nil, fmt.Errorf("email %s in use", claims.Email) + } // now we need a username var username string @@ -180,12 +185,11 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin // note that for the first iteration, iString is still "" when the check is made, so our first choice // is still the raw username with no integer stuck on the end for i := 1; !found; i = i + 1 { - if err := m.db.IsUsernameAvailable(username + iString); err != nil { - if strings.Contains(err.Error(), "db error") { - // if there's an actual db error we should return - return nil, fmt.Errorf("error checking username availability: %s", err) - } - } else { + usernameAvailable, err := m.db.IsUsernameAvailable(ctx, username+iString) + if err != nil { + return nil, err + } + if usernameAvailable { // no error so we've found a username that works found = true username = username + iString @@ -209,7 +213,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin password := uuid.NewString() + uuid.NewString() // create the user! this will also create an account and store it in the database so we don't need to do that here - user, err = m.db.NewSignup(username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin) + user, err = m.db.NewSignup(ctx, username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin) if err != nil { return nil, fmt.Errorf("error creating user: %s", err) } diff --git a/internal/api/client/auth/middleware.go b/internal/api/client/auth/middleware.go index a734b2ceb..3599c7048 100644 --- a/internal/api/client/auth/middleware.go +++ b/internal/api/client/auth/middleware.go @@ -49,15 +49,15 @@ func (m *Module) OauthTokenMiddleware(c *gin.Context) { // fetch user's and account for this user id user := >smodel.User{} - if err := m.db.GetByID(uid, user); err != nil || user == nil { + if err := m.db.GetByID(c.Request.Context(), uid, user); err != nil || user == nil { l.Warnf("no user found for validated uid %s", uid) return } c.Set(oauth.SessionAuthorizedUser, user) l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user) - acct := >smodel.Account{} - if err := m.db.GetByID(user.AccountID, acct); err != nil || acct == nil { + acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) + if err != nil || acct == nil { l.Warnf("no account found for validated user %s", uid) return } @@ -69,7 +69,7 @@ func (m *Module) OauthTokenMiddleware(c *gin.Context) { if cid := ti.GetClientID(); cid != "" { l.Tracef("authenticated client %s with bearer token, scope is %s", cid, ti.GetScope()) app := >smodel.Application{} - if err := m.db.GetWhere([]db.Where{{Key: "client_id", Value: cid}}, app); err != nil { + if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: "client_id", Value: cid}}, app); err != nil { l.Tracef("no app found for client %s", cid) } c.Set(oauth.SessionAuthorizedApplication, app) diff --git a/internal/api/client/auth/signin.go b/internal/api/client/auth/signin.go index 543505cbd..6b8bb93db 100644 --- a/internal/api/client/auth/signin.go +++ b/internal/api/client/auth/signin.go @@ -19,6 +19,7 @@ package auth import ( + "context" "errors" "net/http" @@ -74,7 +75,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) { } l.Tracef("parsed form: %+v", form) - userid, err := m.ValidatePassword(form.Email, form.Password) + userid, err := m.ValidatePassword(c.Request.Context(), form.Email, form.Password) if err != nil { c.String(http.StatusForbidden, err.Error()) m.clearSession(s) @@ -96,7 +97,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) { // The goal is to authenticate the password against the one for that email // address stored in the database. If OK, we return the userid (a ulid) for that user, // so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db. -func (m *Module) ValidatePassword(email string, password string) (userid string, err error) { +func (m *Module) ValidatePassword(ctx context.Context, email string, password string) (userid string, err error) { l := m.log.WithField("func", "ValidatePassword") // make sure an email/password was provided and bail if not @@ -108,7 +109,7 @@ func (m *Module) ValidatePassword(email string, password string) (userid string, // first we select the user from the database based on email address, bail if no user found for that email gtsUser := >smodel.User{} - if err := m.db.GetWhere([]db.Where{{Key: "email", Value: email}}, gtsUser); err != nil { + if err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: email}}, gtsUser); err != nil { l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err) return incorrectPassword() } |