summaryrefslogtreecommitdiff
path: root/internal/api
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api')
-rw-r--r--internal/api/client/auth/authorize.go8
-rw-r--r--internal/api/client/auth/authorize_test.go10
-rw-r--r--internal/api/client/auth/callback.go3
-rw-r--r--internal/api/client/auth/signin.go6
-rw-r--r--internal/api/security/tokencheck.go23
5 files changed, 26 insertions, 24 deletions
diff --git a/internal/api/client/auth/authorize.go b/internal/api/client/auth/authorize.go
index 83cddd9b5..b345f9b01 100644
--- a/internal/api/client/auth/authorize.go
+++ b/internal/api/client/auth/authorize.go
@@ -94,8 +94,8 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
return
}
- user := &gtsmodel.User{}
- if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil {
+ user, err := m.db.GetUserByID(c.Request.Context(), userID)
+ if err != nil {
m.clearSession(s)
safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
var errWithCode gtserror.WithCode
@@ -213,8 +213,8 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
return
}
- user := &gtsmodel.User{}
- if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil {
+ user, err := m.db.GetUserByID(c.Request.Context(), userID)
+ if err != nil {
m.clearSession(s)
safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
var errWithCode gtserror.WithCode
diff --git a/internal/api/client/auth/authorize_test.go b/internal/api/client/auth/authorize_test.go
index eab893416..fcc4b8caa 100644
--- a/internal/api/client/auth/authorize_test.go
+++ b/internal/api/client/auth/authorize_test.go
@@ -76,8 +76,11 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
doTest := func(testCase authorizeHandlerTestCase) {
ctx, recorder := suite.newContext(http.MethodGet, auth.OauthAuthorizePath, nil, "")
- user := suite.testUsers["unconfirmed_account"]
- account := suite.testAccounts["unconfirmed_account"]
+ user := &gtsmodel.User{}
+ account := &gtsmodel.Account{}
+
+ *user = *suite.testUsers["unconfirmed_account"]
+ *account = *suite.testAccounts["unconfirmed_account"]
testSession := sessions.Default(ctx)
testSession.Set(sessionUserID, user.ID)
@@ -91,8 +94,7 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, *user.Disabled, account.SuspendedAt)
updatingColumns = append(updatingColumns, "updated_at")
- user.UpdatedAt = time.Now()
- err := suite.db.UpdateByPrimaryKey(context.Background(), user, updatingColumns...)
+ _, err := suite.db.UpdateUser(context.Background(), user, updatingColumns...)
suite.NoError(err)
_, err = suite.db.UpdateAccount(context.Background(), account)
suite.NoError(err)
diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go
index 96a73a52f..daee2ae31 100644
--- a/internal/api/client/auth/callback.go
+++ b/internal/api/client/auth/callback.go
@@ -134,8 +134,7 @@ func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, i
// see if we already have a user for this email address
// if so, we don't need to continue + create one
- user := &gtsmodel.User{}
- err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user)
+ user, err := m.db.GetUserByEmailAddress(ctx, claims.Email)
if err == nil {
return user, nil
}
diff --git a/internal/api/client/auth/signin.go b/internal/api/client/auth/signin.go
index 58f3fad7e..06b601b10 100644
--- a/internal/api/client/auth/signin.go
+++ b/internal/api/client/auth/signin.go
@@ -28,9 +28,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"golang.org/x/crypto/bcrypt"
)
@@ -119,8 +117,8 @@ func (m *Module) ValidatePassword(ctx context.Context, email string, password st
return incorrectPassword(err)
}
- user := &gtsmodel.User{}
- if err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: email}}, user); err != nil {
+ user, err := m.db.GetUserByEmailAddress(ctx, email)
+ if err != nil {
err := fmt.Errorf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
return incorrectPassword(err)
}
diff --git a/internal/api/security/tokencheck.go b/internal/api/security/tokencheck.go
index 3df7ee943..9f2b7f36e 100644
--- a/internal/api/security/tokencheck.go
+++ b/internal/api/security/tokencheck.go
@@ -52,8 +52,8 @@ func (m *Module) TokenCheck(c *gin.Context) {
log.Tracef("authenticated user %s with bearer token, scope is %s", userID, ti.GetScope())
// fetch user for this token
- user := &gtsmodel.User{}
- if err := m.db.GetByID(ctx, userID, user); err != nil {
+ user, err := m.db.GetUserByID(ctx, userID)
+ if err != nil {
if err != db.ErrNoEntries {
log.Errorf("database error looking for user with id %s: %s", userID, err)
return
@@ -80,22 +80,25 @@ func (m *Module) TokenCheck(c *gin.Context) {
c.Set(oauth.SessionAuthorizedUser, user)
// fetch account for this token
- acct, err := m.db.GetAccountByID(ctx, user.AccountID)
- if err != nil {
- if err != db.ErrNoEntries {
- log.Errorf("database error looking for account with id %s: %s", user.AccountID, err)
+ if user.Account == nil {
+ acct, err := m.db.GetAccountByID(ctx, user.AccountID)
+ if err != nil {
+ if err != db.ErrNoEntries {
+ log.Errorf("database error looking for account with id %s: %s", user.AccountID, err)
+ return
+ }
+ log.Warnf("no account found for userID %s", userID)
return
}
- log.Warnf("no account found for userID %s", userID)
- return
+ user.Account = acct
}
- if !acct.SuspendedAt.IsZero() {
+ if !user.Account.SuspendedAt.IsZero() {
log.Warnf("authenticated user %s's account (accountId=%s) has been suspended", userID, user.AccountID)
return
}
- c.Set(oauth.SessionAuthorizedAccount, acct)
+ c.Set(oauth.SessionAuthorizedAccount, user.Account)
}
// check for application token