summaryrefslogtreecommitdiff
path: root/internal/api/client/auth
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api/client/auth')
-rw-r--r--internal/api/client/auth/authorize.go8
-rw-r--r--internal/api/client/auth/authorize_test.go10
-rw-r--r--internal/api/client/auth/callback.go3
-rw-r--r--internal/api/client/auth/signin.go6
4 files changed, 13 insertions, 14 deletions
diff --git a/internal/api/client/auth/authorize.go b/internal/api/client/auth/authorize.go
index 83cddd9b5..b345f9b01 100644
--- a/internal/api/client/auth/authorize.go
+++ b/internal/api/client/auth/authorize.go
@@ -94,8 +94,8 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
return
}
- user := &gtsmodel.User{}
- if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil {
+ user, err := m.db.GetUserByID(c.Request.Context(), userID)
+ if err != nil {
m.clearSession(s)
safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
var errWithCode gtserror.WithCode
@@ -213,8 +213,8 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
return
}
- user := &gtsmodel.User{}
- if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil {
+ user, err := m.db.GetUserByID(c.Request.Context(), userID)
+ if err != nil {
m.clearSession(s)
safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
var errWithCode gtserror.WithCode
diff --git a/internal/api/client/auth/authorize_test.go b/internal/api/client/auth/authorize_test.go
index eab893416..fcc4b8caa 100644
--- a/internal/api/client/auth/authorize_test.go
+++ b/internal/api/client/auth/authorize_test.go
@@ -76,8 +76,11 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
doTest := func(testCase authorizeHandlerTestCase) {
ctx, recorder := suite.newContext(http.MethodGet, auth.OauthAuthorizePath, nil, "")
- user := suite.testUsers["unconfirmed_account"]
- account := suite.testAccounts["unconfirmed_account"]
+ user := &gtsmodel.User{}
+ account := &gtsmodel.Account{}
+
+ *user = *suite.testUsers["unconfirmed_account"]
+ *account = *suite.testAccounts["unconfirmed_account"]
testSession := sessions.Default(ctx)
testSession.Set(sessionUserID, user.ID)
@@ -91,8 +94,7 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, *user.Disabled, account.SuspendedAt)
updatingColumns = append(updatingColumns, "updated_at")
- user.UpdatedAt = time.Now()
- err := suite.db.UpdateByPrimaryKey(context.Background(), user, updatingColumns...)
+ _, err := suite.db.UpdateUser(context.Background(), user, updatingColumns...)
suite.NoError(err)
_, err = suite.db.UpdateAccount(context.Background(), account)
suite.NoError(err)
diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go
index 96a73a52f..daee2ae31 100644
--- a/internal/api/client/auth/callback.go
+++ b/internal/api/client/auth/callback.go
@@ -134,8 +134,7 @@ func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, i
// see if we already have a user for this email address
// if so, we don't need to continue + create one
- user := &gtsmodel.User{}
- err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user)
+ user, err := m.db.GetUserByEmailAddress(ctx, claims.Email)
if err == nil {
return user, nil
}
diff --git a/internal/api/client/auth/signin.go b/internal/api/client/auth/signin.go
index 58f3fad7e..06b601b10 100644
--- a/internal/api/client/auth/signin.go
+++ b/internal/api/client/auth/signin.go
@@ -28,9 +28,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"golang.org/x/crypto/bcrypt"
)
@@ -119,8 +117,8 @@ func (m *Module) ValidatePassword(ctx context.Context, email string, password st
return incorrectPassword(err)
}
- user := &gtsmodel.User{}
- if err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: email}}, user); err != nil {
+ user, err := m.db.GetUserByEmailAddress(ctx, email)
+ if err != nil {
err := fmt.Errorf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
return incorrectPassword(err)
}