summaryrefslogtreecommitdiff
path: root/internal/api/client/auth
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2021-08-25 15:34:33 +0200
committerLibravatar GitHub <noreply@github.com>2021-08-25 15:34:33 +0200
commit2dc9fc1626507bb54417fc4a1920b847cafb27a2 (patch)
tree4ddeac479b923db38090aac8bd9209f3646851c1 /internal/api/client/auth
parentManually approves followers (#146) (diff)
downloadgotosocial-2dc9fc1626507bb54417fc4a1920b847cafb27a2.tar.xz
Pg to bun (#148)
* start moving to bun * changing more stuff * more * and yet more * tests passing * seems stable now * more big changes * small fix * little fixes
Diffstat (limited to 'internal/api/client/auth')
-rw-r--r--internal/api/client/auth/auth_test.go16
-rw-r--r--internal/api/client/auth/authorize.go19
-rw-r--r--internal/api/client/auth/callback.go30
-rw-r--r--internal/api/client/auth/middleware.go8
-rw-r--r--internal/api/client/auth/signin.go7
5 files changed, 39 insertions, 41 deletions
diff --git a/internal/api/client/auth/auth_test.go b/internal/api/client/auth/auth_test.go
index 48d2a2508..3d5170f31 100644
--- a/internal/api/client/auth/auth_test.go
+++ b/internal/api/client/auth/auth_test.go
@@ -28,7 +28,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/db/pg"
+ "github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"golang.org/x/crypto/bcrypt"
@@ -104,7 +104,7 @@ func (suite *AuthTestSuite) SetupTest() {
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
- db, err := pg.NewPostgresService(context.Background(), suite.config, log)
+ db, err := bundb.NewBunDBService(context.Background(), suite.config, log)
if err != nil {
logrus.Panicf("error creating database connection: %s", err)
}
@@ -120,23 +120,23 @@ func (suite *AuthTestSuite) SetupTest() {
}
for _, m := range models {
- if err := suite.db.CreateTable(m); err != nil {
+ if err := suite.db.CreateTable(context.Background(), m); err != nil {
logrus.Panicf("db connection error: %s", err)
}
}
suite.oauthServer = oauth.New(suite.db, log)
- if err := suite.db.Put(suite.testAccount); err != nil {
+ if err := suite.db.Put(context.Background(), suite.testAccount); err != nil {
logrus.Panicf("could not insert test account into db: %s", err)
}
- if err := suite.db.Put(suite.testUser); err != nil {
+ if err := suite.db.Put(context.Background(), suite.testUser); err != nil {
logrus.Panicf("could not insert test user into db: %s", err)
}
- if err := suite.db.Put(suite.testClient); err != nil {
+ if err := suite.db.Put(context.Background(), suite.testClient); err != nil {
logrus.Panicf("could not insert test client into db: %s", err)
}
- if err := suite.db.Put(suite.testApplication); err != nil {
+ if err := suite.db.Put(context.Background(), suite.testApplication); err != nil {
logrus.Panicf("could not insert test application into db: %s", err)
}
@@ -152,7 +152,7 @@ func (suite *AuthTestSuite) TearDownTest() {
&gtsmodel.Application{},
}
for _, m := range models {
- if err := suite.db.DropTable(m); err != nil {
+ if err := suite.db.DropTable(context.Background(), m); err != nil {
logrus.Panicf("error dropping table: %s", err)
}
}
diff --git a/internal/api/client/auth/authorize.go b/internal/api/client/auth/authorize.go
index a10408723..0328f3b21 100644
--- a/internal/api/client/auth/authorize.go
+++ b/internal/api/client/auth/authorize.go
@@ -70,30 +70,23 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"})
return
}
- app := &gtsmodel.Application{
- ClientID: clientID,
- }
- if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
+ app := &gtsmodel.Application{}
+ if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)})
return
}
// we can also use the userid of the user to fetch their username from the db to greet them nicely <3
- user := &gtsmodel.User{
- ID: userID,
- }
- if err := m.db.GetByID(user.ID, user); err != nil {
+ user := &gtsmodel.User{}
+ if err := m.db.GetByID(c.Request.Context(), user.ID, user); err != nil {
m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
- acct := &gtsmodel.Account{
- ID: user.AccountID,
- }
-
- if err := m.db.GetByID(acct.ID, acct); err != nil {
+ acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
+ if err != nil {
m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go
index a26838aa3..cbb429352 100644
--- a/internal/api/client/auth/callback.go
+++ b/internal/api/client/auth/callback.go
@@ -19,6 +19,7 @@
package auth
import (
+ "context"
"errors"
"fmt"
"net"
@@ -80,13 +81,13 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
app := &gtsmodel.Application{
ClientID: clientID,
}
- if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
+ if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)})
return
}
- user, err := m.parseUserFromClaims(claims, net.IP(c.ClientIP()), app.ID)
+ user, err := m.parseUserFromClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID)
if err != nil {
m.clearSession(s)
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
@@ -103,14 +104,14 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
c.Redirect(http.StatusFound, OauthAuthorizePath)
}
-func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) {
+func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) {
if claims.Email == "" {
return nil, errors.New("no email returned in claims")
}
// see if we already have a user for this email address
user := &gtsmodel.User{}
- err := m.db.GetWhere([]db.Where{{Key: "email", Value: claims.Email}}, user)
+ err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user)
if err == nil {
// we do! so we can just return it
return user, nil
@@ -122,7 +123,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
}
// maybe we have an unconfirmed user
- err = m.db.GetWhere([]db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user)
+ err = m.db.GetWhere(ctx, []db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user)
if err == nil {
// user is unconfirmed so return an error
return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email)
@@ -137,9 +138,13 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
// however, because we trust the OIDC provider, we should now create a user + account with the provided claims
// check if the email address is available for use; if it's not there's nothing we can so
- if err := m.db.IsEmailAvailable(claims.Email); err != nil {
+ emailAvailable, err := m.db.IsEmailAvailable(ctx, claims.Email)
+ if err != nil {
return nil, fmt.Errorf("email %s not available: %s", claims.Email, err)
}
+ if !emailAvailable {
+ return nil, fmt.Errorf("email %s in use", claims.Email)
+ }
// now we need a username
var username string
@@ -180,12 +185,11 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
// note that for the first iteration, iString is still "" when the check is made, so our first choice
// is still the raw username with no integer stuck on the end
for i := 1; !found; i = i + 1 {
- if err := m.db.IsUsernameAvailable(username + iString); err != nil {
- if strings.Contains(err.Error(), "db error") {
- // if there's an actual db error we should return
- return nil, fmt.Errorf("error checking username availability: %s", err)
- }
- } else {
+ usernameAvailable, err := m.db.IsUsernameAvailable(ctx, username+iString)
+ if err != nil {
+ return nil, err
+ }
+ if usernameAvailable {
// no error so we've found a username that works
found = true
username = username + iString
@@ -209,7 +213,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
password := uuid.NewString() + uuid.NewString()
// create the user! this will also create an account and store it in the database so we don't need to do that here
- user, err = m.db.NewSignup(username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin)
+ user, err = m.db.NewSignup(ctx, username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin)
if err != nil {
return nil, fmt.Errorf("error creating user: %s", err)
}
diff --git a/internal/api/client/auth/middleware.go b/internal/api/client/auth/middleware.go
index a734b2ceb..3599c7048 100644
--- a/internal/api/client/auth/middleware.go
+++ b/internal/api/client/auth/middleware.go
@@ -49,15 +49,15 @@ func (m *Module) OauthTokenMiddleware(c *gin.Context) {
// fetch user's and account for this user id
user := &gtsmodel.User{}
- if err := m.db.GetByID(uid, user); err != nil || user == nil {
+ if err := m.db.GetByID(c.Request.Context(), uid, user); err != nil || user == nil {
l.Warnf("no user found for validated uid %s", uid)
return
}
c.Set(oauth.SessionAuthorizedUser, user)
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user)
- acct := &gtsmodel.Account{}
- if err := m.db.GetByID(user.AccountID, acct); err != nil || acct == nil {
+ acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
+ if err != nil || acct == nil {
l.Warnf("no account found for validated user %s", uid)
return
}
@@ -69,7 +69,7 @@ func (m *Module) OauthTokenMiddleware(c *gin.Context) {
if cid := ti.GetClientID(); cid != "" {
l.Tracef("authenticated client %s with bearer token, scope is %s", cid, ti.GetScope())
app := &gtsmodel.Application{}
- if err := m.db.GetWhere([]db.Where{{Key: "client_id", Value: cid}}, app); err != nil {
+ if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: "client_id", Value: cid}}, app); err != nil {
l.Tracef("no app found for client %s", cid)
}
c.Set(oauth.SessionAuthorizedApplication, app)
diff --git a/internal/api/client/auth/signin.go b/internal/api/client/auth/signin.go
index 543505cbd..6b8bb93db 100644
--- a/internal/api/client/auth/signin.go
+++ b/internal/api/client/auth/signin.go
@@ -19,6 +19,7 @@
package auth
import (
+ "context"
"errors"
"net/http"
@@ -74,7 +75,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) {
}
l.Tracef("parsed form: %+v", form)
- userid, err := m.ValidatePassword(form.Email, form.Password)
+ userid, err := m.ValidatePassword(c.Request.Context(), form.Email, form.Password)
if err != nil {
c.String(http.StatusForbidden, err.Error())
m.clearSession(s)
@@ -96,7 +97,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) {
// The goal is to authenticate the password against the one for that email
// address stored in the database. If OK, we return the userid (a ulid) for that user,
// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
-func (m *Module) ValidatePassword(email string, password string) (userid string, err error) {
+func (m *Module) ValidatePassword(ctx context.Context, email string, password string) (userid string, err error) {
l := m.log.WithField("func", "ValidatePassword")
// make sure an email/password was provided and bail if not
@@ -108,7 +109,7 @@ func (m *Module) ValidatePassword(email string, password string) (userid string,
// first we select the user from the database based on email address, bail if no user found for that email
gtsUser := &gtsmodel.User{}
- if err := m.db.GetWhere([]db.Where{{Key: "email", Value: email}}, gtsUser); err != nil {
+ if err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: email}}, gtsUser); err != nil {
l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
return incorrectPassword()
}