diff options
Diffstat (limited to 'internal/api')
-rw-r--r-- | internal/api/client/auth/authorize.go | 8 | ||||
-rw-r--r-- | internal/api/client/auth/authorize_test.go | 10 | ||||
-rw-r--r-- | internal/api/client/auth/callback.go | 3 | ||||
-rw-r--r-- | internal/api/client/auth/signin.go | 6 | ||||
-rw-r--r-- | internal/api/security/tokencheck.go | 23 |
5 files changed, 26 insertions, 24 deletions
diff --git a/internal/api/client/auth/authorize.go b/internal/api/client/auth/authorize.go index 83cddd9b5..b345f9b01 100644 --- a/internal/api/client/auth/authorize.go +++ b/internal/api/client/auth/authorize.go @@ -94,8 +94,8 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { return } - user := >smodel.User{} - if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil { + user, err := m.db.GetUserByID(c.Request.Context(), userID) + if err != nil { m.clearSession(s) safe := fmt.Sprintf("user with id %s could not be retrieved", userID) var errWithCode gtserror.WithCode @@ -213,8 +213,8 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) { return } - user := >smodel.User{} - if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil { + user, err := m.db.GetUserByID(c.Request.Context(), userID) + if err != nil { m.clearSession(s) safe := fmt.Sprintf("user with id %s could not be retrieved", userID) var errWithCode gtserror.WithCode diff --git a/internal/api/client/auth/authorize_test.go b/internal/api/client/auth/authorize_test.go index eab893416..fcc4b8caa 100644 --- a/internal/api/client/auth/authorize_test.go +++ b/internal/api/client/auth/authorize_test.go @@ -76,8 +76,11 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() { doTest := func(testCase authorizeHandlerTestCase) { ctx, recorder := suite.newContext(http.MethodGet, auth.OauthAuthorizePath, nil, "") - user := suite.testUsers["unconfirmed_account"] - account := suite.testAccounts["unconfirmed_account"] + user := >smodel.User{} + account := >smodel.Account{} + + *user = *suite.testUsers["unconfirmed_account"] + *account = *suite.testAccounts["unconfirmed_account"] testSession := sessions.Default(ctx) testSession.Set(sessionUserID, user.ID) @@ -91,8 +94,7 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() { testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, *user.Disabled, account.SuspendedAt) updatingColumns = append(updatingColumns, "updated_at") - user.UpdatedAt = time.Now() - err := suite.db.UpdateByPrimaryKey(context.Background(), user, updatingColumns...) + _, err := suite.db.UpdateUser(context.Background(), user, updatingColumns...) suite.NoError(err) _, err = suite.db.UpdateAccount(context.Background(), account) suite.NoError(err) diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go index 96a73a52f..daee2ae31 100644 --- a/internal/api/client/auth/callback.go +++ b/internal/api/client/auth/callback.go @@ -134,8 +134,7 @@ func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, i // see if we already have a user for this email address // if so, we don't need to continue + create one - user := >smodel.User{} - err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user) + user, err := m.db.GetUserByEmailAddress(ctx, claims.Email) if err == nil { return user, nil } diff --git a/internal/api/client/auth/signin.go b/internal/api/client/auth/signin.go index 58f3fad7e..06b601b10 100644 --- a/internal/api/client/auth/signin.go +++ b/internal/api/client/auth/signin.go @@ -28,9 +28,7 @@ import ( "github.com/gin-gonic/gin" "github.com/superseriousbusiness/gotosocial/internal/api" "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "golang.org/x/crypto/bcrypt" ) @@ -119,8 +117,8 @@ func (m *Module) ValidatePassword(ctx context.Context, email string, password st return incorrectPassword(err) } - user := >smodel.User{} - if err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: email}}, user); err != nil { + user, err := m.db.GetUserByEmailAddress(ctx, email) + if err != nil { err := fmt.Errorf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err) return incorrectPassword(err) } diff --git a/internal/api/security/tokencheck.go b/internal/api/security/tokencheck.go index 3df7ee943..9f2b7f36e 100644 --- a/internal/api/security/tokencheck.go +++ b/internal/api/security/tokencheck.go @@ -52,8 +52,8 @@ func (m *Module) TokenCheck(c *gin.Context) { log.Tracef("authenticated user %s with bearer token, scope is %s", userID, ti.GetScope()) // fetch user for this token - user := >smodel.User{} - if err := m.db.GetByID(ctx, userID, user); err != nil { + user, err := m.db.GetUserByID(ctx, userID) + if err != nil { if err != db.ErrNoEntries { log.Errorf("database error looking for user with id %s: %s", userID, err) return @@ -80,22 +80,25 @@ func (m *Module) TokenCheck(c *gin.Context) { c.Set(oauth.SessionAuthorizedUser, user) // fetch account for this token - acct, err := m.db.GetAccountByID(ctx, user.AccountID) - if err != nil { - if err != db.ErrNoEntries { - log.Errorf("database error looking for account with id %s: %s", user.AccountID, err) + if user.Account == nil { + acct, err := m.db.GetAccountByID(ctx, user.AccountID) + if err != nil { + if err != db.ErrNoEntries { + log.Errorf("database error looking for account with id %s: %s", user.AccountID, err) + return + } + log.Warnf("no account found for userID %s", userID) return } - log.Warnf("no account found for userID %s", userID) - return + user.Account = acct } - if !acct.SuspendedAt.IsZero() { + if !user.Account.SuspendedAt.IsZero() { log.Warnf("authenticated user %s's account (accountId=%s) has been suspended", userID, user.AccountID) return } - c.Set(oauth.SessionAuthorizedAccount, acct) + c.Set(oauth.SessionAuthorizedAccount, user.Account) } // check for application token |