diff options
Diffstat (limited to 'internal/api')
61 files changed, 108 insertions, 143 deletions
diff --git a/internal/api/client/account/accountcreate.go b/internal/api/client/account/accountcreate.go index 50e72655e..a9d672f80 100644 --- a/internal/api/client/account/accountcreate.go +++ b/internal/api/client/account/accountcreate.go @@ -101,7 +101,7 @@ func (m *Module) AccountCreatePOSTHandler(c *gin.Context) { form.IP = signUpIP - ti, err := m.processor.AccountCreate(authed, form) + ti, err := m.processor.AccountCreate(c.Request.Context(), authed, form) if err != nil { l.Errorf("internal server error while creating new account: %s", err) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) diff --git a/internal/api/client/account/accountget.go b/internal/api/client/account/accountget.go index a7f9d8c70..8bac1360b 100644 --- a/internal/api/client/account/accountget.go +++ b/internal/api/client/account/accountget.go @@ -70,7 +70,7 @@ func (m *Module) AccountGETHandler(c *gin.Context) { return } - acctInfo, err := m.processor.AccountGet(authed, targetAcctID) + acctInfo, err := m.processor.AccountGet(c.Request.Context(), authed, targetAcctID) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "not found"}) return diff --git a/internal/api/client/account/accountupdate.go b/internal/api/client/account/accountupdate.go index f55f45f59..282d172ed 100644 --- a/internal/api/client/account/accountupdate.go +++ b/internal/api/client/account/accountupdate.go @@ -122,7 +122,7 @@ func (m *Module) AccountUpdateCredentialsPATCHHandler(c *gin.Context) { return } - acctSensitive, err := m.processor.AccountUpdate(authed, form) + acctSensitive, err := m.processor.AccountUpdate(c.Request.Context(), authed, form) if err != nil { l.Debugf("could not update account: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) diff --git a/internal/api/client/account/accountupdate_test.go b/internal/api/client/account/accountupdate_test.go index 349429625..8fc31171b 100644 --- a/internal/api/client/account/accountupdate_test.go +++ b/internal/api/client/account/accountupdate_test.go @@ -79,7 +79,7 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandler() recorder := httptest.NewRecorder() ctx, _ := gin.CreateTestContext(recorder) ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) - ctx.Set(oauth.SessionAuthorizedToken, oauth.TokenToOauthToken(suite.testTokens["local_account_1"])) + ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) ctx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:8080/%s", account.UpdateCredentialsPath), bytes.NewReader(requestBody.Bytes())) // the endpoint we're hitting diff --git a/internal/api/client/account/accountverify.go b/internal/api/client/account/accountverify.go index 4c77f3fa6..c5c40a03d 100644 --- a/internal/api/client/account/accountverify.go +++ b/internal/api/client/account/accountverify.go @@ -59,7 +59,7 @@ func (m *Module) AccountVerifyGETHandler(c *gin.Context) { return } - acctSensitive, err := m.processor.AccountGet(authed, authed.Account.ID) + acctSensitive, err := m.processor.AccountGet(c.Request.Context(), authed, authed.Account.ID) if err != nil { l.Debugf("error getting account from processor: %s", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"}) diff --git a/internal/api/client/account/block.go b/internal/api/client/account/block.go index 0d9d6c51b..243f90c5e 100644 --- a/internal/api/client/account/block.go +++ b/internal/api/client/account/block.go @@ -72,7 +72,7 @@ func (m *Module) AccountBlockPOSTHandler(c *gin.Context) { return } - relationship, errWithCode := m.processor.AccountBlockCreate(authed, targetAcctID) + relationship, errWithCode := m.processor.AccountBlockCreate(c.Request.Context(), authed, targetAcctID) if errWithCode != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/account/follow.go b/internal/api/client/account/follow.go index 985a5f821..8f3f7ad0a 100644 --- a/internal/api/client/account/follow.go +++ b/internal/api/client/account/follow.go @@ -99,7 +99,7 @@ func (m *Module) AccountFollowPOSTHandler(c *gin.Context) { } form.ID = targetAcctID - relationship, errWithCode := m.processor.AccountFollowCreate(authed, form) + relationship, errWithCode := m.processor.AccountFollowCreate(c.Request.Context(), authed, form) if errWithCode != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/account/followers.go b/internal/api/client/account/followers.go index 7e93544b8..4f30e9939 100644 --- a/internal/api/client/account/followers.go +++ b/internal/api/client/account/followers.go @@ -74,7 +74,7 @@ func (m *Module) AccountFollowersGETHandler(c *gin.Context) { return } - followers, errWithCode := m.processor.AccountFollowersGet(authed, targetAcctID) + followers, errWithCode := m.processor.AccountFollowersGet(c.Request.Context(), authed, targetAcctID) if errWithCode != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/account/following.go b/internal/api/client/account/following.go index e70265eb5..baac2c9d3 100644 --- a/internal/api/client/account/following.go +++ b/internal/api/client/account/following.go @@ -74,7 +74,7 @@ func (m *Module) AccountFollowingGETHandler(c *gin.Context) { return } - following, errWithCode := m.processor.AccountFollowingGet(authed, targetAcctID) + following, errWithCode := m.processor.AccountFollowingGet(c.Request.Context(), authed, targetAcctID) if errWithCode != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/account/relationships.go b/internal/api/client/account/relationships.go index 9dbc8c4bb..d350209af 100644 --- a/internal/api/client/account/relationships.go +++ b/internal/api/client/account/relationships.go @@ -71,7 +71,7 @@ func (m *Module) AccountRelationshipsGETHandler(c *gin.Context) { relationships := []model.Relationship{} for _, targetAccountID := range targetAccountIDs { - r, errWithCode := m.processor.AccountRelationshipGet(authed, targetAccountID) + r, errWithCode := m.processor.AccountRelationshipGet(c.Request.Context(), authed, targetAccountID) if err != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/account/statuses.go b/internal/api/client/account/statuses.go index 097ccc3cc..4841d86df 100644 --- a/internal/api/client/account/statuses.go +++ b/internal/api/client/account/statuses.go @@ -166,7 +166,7 @@ func (m *Module) AccountStatusesGETHandler(c *gin.Context) { mediaOnly = i } - statuses, errWithCode := m.processor.AccountStatusesGet(authed, targetAcctID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly) + statuses, errWithCode := m.processor.AccountStatusesGet(c.Request.Context(), authed, targetAcctID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly) if errWithCode != nil { l.Debugf("error from processor account statuses get: %s", errWithCode) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/account/unblock.go b/internal/api/client/account/unblock.go index d9a2f2881..7b16ac887 100644 --- a/internal/api/client/account/unblock.go +++ b/internal/api/client/account/unblock.go @@ -72,7 +72,7 @@ func (m *Module) AccountUnblockPOSTHandler(c *gin.Context) { return } - relationship, errWithCode := m.processor.AccountBlockRemove(authed, targetAcctID) + relationship, errWithCode := m.processor.AccountBlockRemove(c.Request.Context(), authed, targetAcctID) if errWithCode != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/account/unfollow.go b/internal/api/client/account/unfollow.go index 84a558c65..7ac9697d5 100644 --- a/internal/api/client/account/unfollow.go +++ b/internal/api/client/account/unfollow.go @@ -75,7 +75,7 @@ func (m *Module) AccountUnfollowPOSTHandler(c *gin.Context) { return } - relationship, errWithCode := m.processor.AccountFollowRemove(authed, targetAcctID) + relationship, errWithCode := m.processor.AccountFollowRemove(c.Request.Context(), authed, targetAcctID) if errWithCode != nil { l.Debug(errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/admin/domainblockcreate.go b/internal/api/client/admin/domainblockcreate.go index d48c70408..9ef4c6f92 100644 --- a/internal/api/client/admin/domainblockcreate.go +++ b/internal/api/client/admin/domainblockcreate.go @@ -141,7 +141,7 @@ func (m *Module) DomainBlocksPOSTHandler(c *gin.Context) { if imp { // we're importing multiple blocks - domainBlocks, err := m.processor.AdminDomainBlocksImport(authed, form) + domainBlocks, err := m.processor.AdminDomainBlocksImport(c.Request.Context(), authed, form) if err != nil { l.Debugf("error importing domain blocks: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -150,7 +150,7 @@ func (m *Module) DomainBlocksPOSTHandler(c *gin.Context) { c.JSON(http.StatusOK, domainBlocks) } else { // we're just creating one block - domainBlock, err := m.processor.AdminDomainBlockCreate(authed, form) + domainBlock, err := m.processor.AdminDomainBlockCreate(c.Request.Context(), authed, form) if err != nil { l.Debugf("error creating domain block: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) diff --git a/internal/api/client/admin/domainblockdelete.go b/internal/api/client/admin/domainblockdelete.go index 9cd2fd711..64e4ef6de 100644 --- a/internal/api/client/admin/domainblockdelete.go +++ b/internal/api/client/admin/domainblockdelete.go @@ -68,7 +68,7 @@ func (m *Module) DomainBlockDELETEHandler(c *gin.Context) { return } - domainBlock, errWithCode := m.processor.AdminDomainBlockDelete(authed, domainBlockID) + domainBlock, errWithCode := m.processor.AdminDomainBlockDelete(c.Request.Context(), authed, domainBlockID) if errWithCode != nil { l.Debugf("error deleting domain block: %s", errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/admin/domainblockget.go b/internal/api/client/admin/domainblockget.go index 86923d705..d23f99a8c 100644 --- a/internal/api/client/admin/domainblockget.go +++ b/internal/api/client/admin/domainblockget.go @@ -81,7 +81,7 @@ func (m *Module) DomainBlockGETHandler(c *gin.Context) { export = i } - domainBlock, err := m.processor.AdminDomainBlockGet(authed, domainBlockID, export) + domainBlock, err := m.processor.AdminDomainBlockGet(c.Request.Context(), authed, domainBlockID, export) if err != nil { l.Debugf("error getting domain block: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) diff --git a/internal/api/client/admin/domainblocksget.go b/internal/api/client/admin/domainblocksget.go index 77a287387..dad8250e0 100644 --- a/internal/api/client/admin/domainblocksget.go +++ b/internal/api/client/admin/domainblocksget.go @@ -81,7 +81,7 @@ func (m *Module) DomainBlocksGETHandler(c *gin.Context) { export = i } - domainBlocks, err := m.processor.AdminDomainBlocksGet(authed, export) + domainBlocks, err := m.processor.AdminDomainBlocksGet(c.Request.Context(), authed, export) if err != nil { l.Debugf("error getting domain blocks: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) diff --git a/internal/api/client/admin/emojicreate.go b/internal/api/client/admin/emojicreate.go index bfdf28249..859933b16 100644 --- a/internal/api/client/admin/emojicreate.go +++ b/internal/api/client/admin/emojicreate.go @@ -111,7 +111,7 @@ func (m *Module) emojiCreatePOSTHandler(c *gin.Context) { return } - mastoEmoji, err := m.processor.AdminEmojiCreate(authed, form) + mastoEmoji, err := m.processor.AdminEmojiCreate(c.Request.Context(), authed, form) if err != nil { l.Debugf("error creating emoji: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) diff --git a/internal/api/client/app/appcreate.go b/internal/api/client/app/appcreate.go index 31072f9c8..d233841b0 100644 --- a/internal/api/client/app/appcreate.go +++ b/internal/api/client/app/appcreate.go @@ -101,7 +101,7 @@ func (m *Module) AppsPOSTHandler(c *gin.Context) { return } - mastoApp, err := m.processor.AppCreate(authed, form) + mastoApp, err := m.processor.AppCreate(c.Request.Context(), authed, form) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return diff --git a/internal/api/client/auth/auth_test.go b/internal/api/client/auth/auth_test.go index 48d2a2508..3d5170f31 100644 --- a/internal/api/client/auth/auth_test.go +++ b/internal/api/client/auth/auth_test.go @@ -28,7 +28,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/db/pg" + "github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/oauth" "golang.org/x/crypto/bcrypt" @@ -104,7 +104,7 @@ func (suite *AuthTestSuite) SetupTest() { log := logrus.New() log.SetLevel(logrus.TraceLevel) - db, err := pg.NewPostgresService(context.Background(), suite.config, log) + db, err := bundb.NewBunDBService(context.Background(), suite.config, log) if err != nil { logrus.Panicf("error creating database connection: %s", err) } @@ -120,23 +120,23 @@ func (suite *AuthTestSuite) SetupTest() { } for _, m := range models { - if err := suite.db.CreateTable(m); err != nil { + if err := suite.db.CreateTable(context.Background(), m); err != nil { logrus.Panicf("db connection error: %s", err) } } suite.oauthServer = oauth.New(suite.db, log) - if err := suite.db.Put(suite.testAccount); err != nil { + if err := suite.db.Put(context.Background(), suite.testAccount); err != nil { logrus.Panicf("could not insert test account into db: %s", err) } - if err := suite.db.Put(suite.testUser); err != nil { + if err := suite.db.Put(context.Background(), suite.testUser); err != nil { logrus.Panicf("could not insert test user into db: %s", err) } - if err := suite.db.Put(suite.testClient); err != nil { + if err := suite.db.Put(context.Background(), suite.testClient); err != nil { logrus.Panicf("could not insert test client into db: %s", err) } - if err := suite.db.Put(suite.testApplication); err != nil { + if err := suite.db.Put(context.Background(), suite.testApplication); err != nil { logrus.Panicf("could not insert test application into db: %s", err) } @@ -152,7 +152,7 @@ func (suite *AuthTestSuite) TearDownTest() { >smodel.Application{}, } for _, m := range models { - if err := suite.db.DropTable(m); err != nil { + if err := suite.db.DropTable(context.Background(), m); err != nil { logrus.Panicf("error dropping table: %s", err) } } diff --git a/internal/api/client/auth/authorize.go b/internal/api/client/auth/authorize.go index a10408723..0328f3b21 100644 --- a/internal/api/client/auth/authorize.go +++ b/internal/api/client/auth/authorize.go @@ -70,30 +70,23 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"}) return } - app := >smodel.Application{ - ClientID: clientID, - } - if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil { + app := >smodel.Application{} + if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil { m.clearSession(s) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)}) return } // we can also use the userid of the user to fetch their username from the db to greet them nicely <3 - user := >smodel.User{ - ID: userID, - } - if err := m.db.GetByID(user.ID, user); err != nil { + user := >smodel.User{} + if err := m.db.GetByID(c.Request.Context(), user.ID, user); err != nil { m.clearSession(s) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - acct := >smodel.Account{ - ID: user.AccountID, - } - - if err := m.db.GetByID(acct.ID, acct); err != nil { + acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) + if err != nil { m.clearSession(s) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go index a26838aa3..cbb429352 100644 --- a/internal/api/client/auth/callback.go +++ b/internal/api/client/auth/callback.go @@ -19,6 +19,7 @@ package auth import ( + "context" "errors" "fmt" "net" @@ -80,13 +81,13 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { app := >smodel.Application{ ClientID: clientID, } - if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil { + if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil { m.clearSession(s) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)}) return } - user, err := m.parseUserFromClaims(claims, net.IP(c.ClientIP()), app.ID) + user, err := m.parseUserFromClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID) if err != nil { m.clearSession(s) c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) @@ -103,14 +104,14 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { c.Redirect(http.StatusFound, OauthAuthorizePath) } -func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) { +func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) { if claims.Email == "" { return nil, errors.New("no email returned in claims") } // see if we already have a user for this email address user := >smodel.User{} - err := m.db.GetWhere([]db.Where{{Key: "email", Value: claims.Email}}, user) + err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user) if err == nil { // we do! so we can just return it return user, nil @@ -122,7 +123,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin } // maybe we have an unconfirmed user - err = m.db.GetWhere([]db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user) + err = m.db.GetWhere(ctx, []db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user) if err == nil { // user is unconfirmed so return an error return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email) @@ -137,9 +138,13 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin // however, because we trust the OIDC provider, we should now create a user + account with the provided claims // check if the email address is available for use; if it's not there's nothing we can so - if err := m.db.IsEmailAvailable(claims.Email); err != nil { + emailAvailable, err := m.db.IsEmailAvailable(ctx, claims.Email) + if err != nil { return nil, fmt.Errorf("email %s not available: %s", claims.Email, err) } + if !emailAvailable { + return nil, fmt.Errorf("email %s in use", claims.Email) + } // now we need a username var username string @@ -180,12 +185,11 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin // note that for the first iteration, iString is still "" when the check is made, so our first choice // is still the raw username with no integer stuck on the end for i := 1; !found; i = i + 1 { - if err := m.db.IsUsernameAvailable(username + iString); err != nil { - if strings.Contains(err.Error(), "db error") { - // if there's an actual db error we should return - return nil, fmt.Errorf("error checking username availability: %s", err) - } - } else { + usernameAvailable, err := m.db.IsUsernameAvailable(ctx, username+iString) + if err != nil { + return nil, err + } + if usernameAvailable { // no error so we've found a username that works found = true username = username + iString @@ -209,7 +213,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin password := uuid.NewString() + uuid.NewString() // create the user! this will also create an account and store it in the database so we don't need to do that here - user, err = m.db.NewSignup(username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin) + user, err = m.db.NewSignup(ctx, username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin) if err != nil { return nil, fmt.Errorf("error creating user: %s", err) } diff --git a/internal/api/client/auth/middleware.go b/internal/api/client/auth/middleware.go index a734b2ceb..3599c7048 100644 --- a/internal/api/client/auth/middleware.go +++ b/internal/api/client/auth/middleware.go @@ -49,15 +49,15 @@ func (m *Module) OauthTokenMiddleware(c *gin.Context) { // fetch user's and account for this user id user := >smodel.User{} - if err := m.db.GetByID(uid, user); err != nil || user == nil { + if err := m.db.GetByID(c.Request.Context(), uid, user); err != nil || user == nil { l.Warnf("no user found for validated uid %s", uid) return } c.Set(oauth.SessionAuthorizedUser, user) l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user) - acct := >smodel.Account{} - if err := m.db.GetByID(user.AccountID, acct); err != nil || acct == nil { + acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) + if err != nil || acct == nil { l.Warnf("no account found for validated user %s", uid) return } @@ -69,7 +69,7 @@ func (m *Module) OauthTokenMiddleware(c *gin.Context) { if cid := ti.GetClientID(); cid != "" { l.Tracef("authenticated client %s with bearer token, scope is %s", cid, ti.GetScope()) app := >smodel.Application{} - if err := m.db.GetWhere([]db.Where{{Key: "client_id", Value: cid}}, app); err != nil { + if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: "client_id", Value: cid}}, app); err != nil { l.Tracef("no app found for client %s", cid) } c.Set(oauth.SessionAuthorizedApplication, app) diff --git a/internal/api/client/auth/signin.go b/internal/api/client/auth/signin.go index 543505cbd..6b8bb93db 100644 --- a/internal/api/client/auth/signin.go +++ b/internal/api/client/auth/signin.go @@ -19,6 +19,7 @@ package auth import ( + "context" "errors" "net/http" @@ -74,7 +75,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) { } l.Tracef("parsed form: %+v", form) - userid, err := m.ValidatePassword(form.Email, form.Password) + userid, err := m.ValidatePassword(c.Request.Context(), form.Email, form.Password) if err != nil { c.String(http.StatusForbidden, err.Error()) m.clearSession(s) @@ -96,7 +97,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) { // The goal is to authenticate the password against the one for that email // address stored in the database. If OK, we return the userid (a ulid) for that user, // so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db. -func (m *Module) ValidatePassword(email string, password string) (userid string, err error) { +func (m *Module) ValidatePassword(ctx context.Context, email string, password string) (userid string, err error) { l := m.log.WithField("func", "ValidatePassword") // make sure an email/password was provided and bail if not @@ -108,7 +109,7 @@ func (m *Module) ValidatePassword(email string, password string) (userid string, // first we select the user from the database based on email address, bail if no user found for that email gtsUser := >smodel.User{} - if err := m.db.GetWhere([]db.Where{{Key: "email", Value: email}}, gtsUser); err != nil { + if err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: email}}, gtsUser); err != nil { l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err) return incorrectPassword() } diff --git a/internal/api/client/blocks/blocksget.go b/internal/api/client/blocks/blocksget.go index 65c11ea1a..b6c9c39e1 100644 --- a/internal/api/client/blocks/blocksget.go +++ b/internal/api/client/blocks/blocksget.go @@ -117,7 +117,7 @@ func (m *Module) BlocksGETHandler(c *gin.Context) { limit = int(i) } - resp, errWithCode := m.processor.BlocksGet(authed, maxID, sinceID, limit) + resp, errWithCode := m.processor.BlocksGet(c.Request.Context(), authed, maxID, sinceID, limit) if errWithCode != nil { l.Debugf("error from processor BlocksGet: %s", errWithCode) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/favourites/favouritesget.go b/internal/api/client/favourites/favouritesget.go index 76eb921e0..6984ea754 100644 --- a/internal/api/client/favourites/favouritesget.go +++ b/internal/api/client/favourites/favouritesget.go @@ -43,7 +43,7 @@ func (m *Module) FavouritesGETHandler(c *gin.Context) { limit = int(i) } - resp, errWithCode := m.processor.FavedTimelineGet(authed, maxID, minID, limit) + resp, errWithCode := m.processor.FavedTimelineGet(c.Request.Context(), authed, maxID, minID, limit) if errWithCode != nil { l.Debugf("error from processor FavedTimelineGet: %s", errWithCode) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/fileserver/fileserver.go b/internal/api/client/fileserver/fileserver.go index 08e6abb62..61286c17a 100644 --- a/internal/api/client/fileserver/fileserver.go +++ b/internal/api/client/fileserver/fileserver.go @@ -25,8 +25,6 @@ import ( "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/api" "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/router" ) @@ -66,17 +64,3 @@ func (m *FileServer) Route(s router.Router) error { s.AttachHandler(http.MethodGet, fmt.Sprintf("%s/:%s/:%s/:%s/:%s", m.storageBase, AccountIDKey, MediaTypeKey, MediaSizeKey, FileNameKey), m.ServeFile) return nil } - -// CreateTables populates necessary tables in the given DB -func (m *FileServer) CreateTables(db db.DB) error { - models := []interface{}{ - >smodel.MediaAttachment{}, - } - - for _, m := range models { - if err := db.CreateTable(m); err != nil { - return fmt.Errorf("error creating table: %s", err) - } - } - return nil -} diff --git a/internal/api/client/fileserver/servefile.go b/internal/api/client/fileserver/servefile.go index 1339fbac3..130a16c4f 100644 --- a/internal/api/client/fileserver/servefile.go +++ b/internal/api/client/fileserver/servefile.go @@ -78,7 +78,7 @@ func (m *FileServer) ServeFile(c *gin.Context) { return } - content, err := m.processor.FileGet(authed, &model.GetContentRequestForm{ + content, err := m.processor.FileGet(c.Request.Context(), authed, &model.GetContentRequestForm{ AccountID: accountID, MediaType: mediaType, MediaSize: mediaSize, diff --git a/internal/api/client/followrequest/accept.go b/internal/api/client/followrequest/accept.go index bb2910c8f..3dba7673f 100644 --- a/internal/api/client/followrequest/accept.go +++ b/internal/api/client/followrequest/accept.go @@ -48,7 +48,7 @@ func (m *Module) FollowRequestAcceptPOSTHandler(c *gin.Context) { return } - r, errWithCode := m.processor.FollowRequestAccept(authed, originAccountID) + r, errWithCode := m.processor.FollowRequestAccept(c.Request.Context(), authed, originAccountID) if errWithCode != nil { l.Debug(errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/followrequest/get.go b/internal/api/client/followrequest/get.go index 3f02ee02a..47e1d74ba 100644 --- a/internal/api/client/followrequest/get.go +++ b/internal/api/client/followrequest/get.go @@ -41,7 +41,7 @@ func (m *Module) FollowRequestGETHandler(c *gin.Context) { return } - accts, errWithCode := m.processor.FollowRequestsGet(authed) + accts, errWithCode := m.processor.FollowRequestsGet(c.Request.Context(), authed) if errWithCode != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/instance/instanceget.go b/internal/api/client/instance/instanceget.go index 0d53edadb..0a6d17153 100644 --- a/internal/api/client/instance/instanceget.go +++ b/internal/api/client/instance/instanceget.go @@ -31,7 +31,7 @@ import ( func (m *Module) InstanceInformationGETHandler(c *gin.Context) { l := m.log.WithField("func", "InstanceInformationGETHandler") - instance, err := m.processor.InstanceGet(m.config.Host) + instance, err := m.processor.InstanceGet(c.Request.Context(), m.config.Host) if err != nil { l.Debugf("error getting instance from processor: %s", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"}) diff --git a/internal/api/client/instance/instancepatch.go b/internal/api/client/instance/instancepatch.go index 3620f6044..fa37ccd8e 100644 --- a/internal/api/client/instance/instancepatch.go +++ b/internal/api/client/instance/instancepatch.go @@ -116,7 +116,7 @@ func (m *Module) InstanceUpdatePATCHHandler(c *gin.Context) { return } - i, errWithCode := m.processor.InstancePatch(form) + i, errWithCode := m.processor.InstancePatch(c.Request.Context(), form) if errWithCode != nil { l.Debugf("error with instance patch request: %s", errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/media/media.go b/internal/api/client/media/media.go index 05058e2e2..1e9e8fdaa 100644 --- a/internal/api/client/media/media.go +++ b/internal/api/client/media/media.go @@ -19,14 +19,11 @@ package media import ( - "fmt" "net/http" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/api" "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/router" ) @@ -63,17 +60,3 @@ func (m *Module) Route(s router.Router) error { s.AttachHandler(http.MethodPut, BasePathWithID, m.MediaPUTHandler) return nil } - -// CreateTables populates necessary tables in the given DB -func (m *Module) CreateTables(db db.DB) error { - models := []interface{}{ - >smodel.MediaAttachment{}, - } - - for _, m := range models { - if err := db.CreateTable(m); err != nil { - return fmt.Errorf("error creating table: %s", err) - } - } - return nil -} diff --git a/internal/api/client/media/mediacreate.go b/internal/api/client/media/mediacreate.go index f41d4568f..58d076ea6 100644 --- a/internal/api/client/media/mediacreate.go +++ b/internal/api/client/media/mediacreate.go @@ -108,7 +108,7 @@ func (m *Module) MediaCreatePOSTHandler(c *gin.Context) { } l.Debug("calling processor media create func") - mastoAttachment, err := m.processor.MediaCreate(authed, form) + mastoAttachment, err := m.processor.MediaCreate(c.Request.Context(), authed, form) if err != nil { l.Debugf("error creating attachment: %s", err) c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()}) diff --git a/internal/api/client/media/mediacreate_test.go b/internal/api/client/media/mediacreate_test.go index 5c48a4381..8433786e4 100644 --- a/internal/api/client/media/mediacreate_test.go +++ b/internal/api/client/media/mediacreate_test.go @@ -121,7 +121,7 @@ func (suite *MediaCreateTestSuite) TestStatusCreatePOSTImageHandlerSuccessful() // set up the context for the request t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) recorder := httptest.NewRecorder() ctx, _ := gin.CreateTestContext(recorder) ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) diff --git a/internal/api/client/media/mediaget.go b/internal/api/client/media/mediaget.go index 17c5a090b..5fd7856e9 100644 --- a/internal/api/client/media/mediaget.go +++ b/internal/api/client/media/mediaget.go @@ -75,7 +75,7 @@ func (m *Module) MediaGETHandler(c *gin.Context) { return } - attachment, errWithCode := m.processor.MediaGet(authed, attachmentID) + attachment, errWithCode := m.processor.MediaGet(c.Request.Context(), authed, attachmentID) if errWithCode != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/media/mediaupdate.go b/internal/api/client/media/mediaupdate.go index 0ceb01f82..3af19297f 100644 --- a/internal/api/client/media/mediaupdate.go +++ b/internal/api/client/media/mediaupdate.go @@ -122,7 +122,7 @@ func (m *Module) MediaPUTHandler(c *gin.Context) { return } - attachment, errWithCode := m.processor.MediaUpdate(authed, attachmentID, &form) + attachment, errWithCode := m.processor.MediaUpdate(c.Request.Context(), authed, attachmentID, &form) if errWithCode != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/notification/notificationsget.go b/internal/api/client/notification/notificationsget.go index a30674750..81e8a6890 100644 --- a/internal/api/client/notification/notificationsget.go +++ b/internal/api/client/notification/notificationsget.go @@ -68,7 +68,7 @@ func (m *Module) NotificationsGETHandler(c *gin.Context) { sinceID = sinceIDString } - notifs, errWithCode := m.processor.NotificationsGet(authed, limit, maxID, sinceID) + notifs, errWithCode := m.processor.NotificationsGet(c.Request.Context(), authed, limit, maxID, sinceID) if errWithCode != nil { l.Debugf("error processing notifications get: %s", errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/search/searchget.go b/internal/api/client/search/searchget.go index faa227719..848915274 100644 --- a/internal/api/client/search/searchget.go +++ b/internal/api/client/search/searchget.go @@ -164,7 +164,7 @@ func (m *Module) SearchGETHandler(c *gin.Context) { Following: following, } - results, errWithCode := m.processor.SearchGet(authed, searchQuery) + results, errWithCode := m.processor.SearchGet(c.Request.Context(), authed, searchQuery) if errWithCode != nil { l.Debugf("error searching: %s", errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/status/statusboost.go b/internal/api/client/status/statusboost.go index 5aa7989bc..094e56ac0 100644 --- a/internal/api/client/status/statusboost.go +++ b/internal/api/client/status/statusboost.go @@ -87,7 +87,7 @@ func (m *Module) StatusBoostPOSTHandler(c *gin.Context) { return } - mastoStatus, errWithCode := m.processor.StatusBoost(authed, targetStatusID) + mastoStatus, errWithCode := m.processor.StatusBoost(c.Request.Context(), authed, targetStatusID) if errWithCode != nil { l.Debugf("error processing status boost: %s", errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/status/statusboost_test.go b/internal/api/client/status/statusboost_test.go index fbe267fac..4157bde38 100644 --- a/internal/api/client/status/statusboost_test.go +++ b/internal/api/client/status/statusboost_test.go @@ -67,7 +67,7 @@ func (suite *StatusBoostTestSuite) TearDownTest() { func (suite *StatusBoostTestSuite) TestPostBoost() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["admin_account_status_1"] @@ -133,7 +133,7 @@ func (suite *StatusBoostTestSuite) TestPostBoost() { func (suite *StatusBoostTestSuite) TestPostUnboostable() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["local_account_2_status_4"] @@ -171,7 +171,7 @@ func (suite *StatusBoostTestSuite) TestPostUnboostable() { func (suite *StatusBoostTestSuite) TestPostNotVisible() { t := suite.testTokens["local_account_2"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["local_account_1_status_3"] // this is a mutual only status and these accounts aren't mutuals diff --git a/internal/api/client/status/statusboostedby.go b/internal/api/client/status/statusboostedby.go index 260e21642..908c3ff10 100644 --- a/internal/api/client/status/statusboostedby.go +++ b/internal/api/client/status/statusboostedby.go @@ -84,7 +84,7 @@ func (m *Module) StatusBoostedByGETHandler(c *gin.Context) { return } - mastoAccounts, err := m.processor.StatusBoostedBy(authed, targetStatusID) + mastoAccounts, err := m.processor.StatusBoostedBy(c.Request.Context(), authed, targetStatusID) if err != nil { l.Debugf("error processing status boosted by request: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"}) diff --git a/internal/api/client/status/statuscontext.go b/internal/api/client/status/statuscontext.go index 6e28b004e..90fcb9608 100644 --- a/internal/api/client/status/statuscontext.go +++ b/internal/api/client/status/statuscontext.go @@ -86,7 +86,7 @@ func (m *Module) StatusContextGETHandler(c *gin.Context) { return } - statusContext, errWithCode := m.processor.StatusGetContext(authed, targetStatusID) + statusContext, errWithCode := m.processor.StatusGetContext(c.Request.Context(), authed, targetStatusID) if errWithCode != nil { l.Debugf("error getting status context: %s", errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/status/statuscreate.go b/internal/api/client/status/statuscreate.go index 2007ba80f..09fc47b5b 100644 --- a/internal/api/client/status/statuscreate.go +++ b/internal/api/client/status/statuscreate.go @@ -101,7 +101,7 @@ func (m *Module) StatusCreatePOSTHandler(c *gin.Context) { return } - mastoStatus, err := m.processor.StatusCreate(authed, form) + mastoStatus, err := m.processor.StatusCreate(c.Request.Context(), authed, form) if err != nil { l.Debugf("error processing status create: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"}) diff --git a/internal/api/client/status/statuscreate_test.go b/internal/api/client/status/statuscreate_test.go index 33912397e..060d96dad 100644 --- a/internal/api/client/status/statuscreate_test.go +++ b/internal/api/client/status/statuscreate_test.go @@ -19,6 +19,7 @@ package status_test import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -82,7 +83,7 @@ https://docs.gotosocial.org/en/latest/user_guide/posts/#links func (suite *StatusCreateTestSuite) TestPostNewStatus() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // setup recorder := httptest.NewRecorder() @@ -128,7 +129,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatus() { }, statusReply.Tags[0]) gtsTag := >smodel.Tag{} - err = suite.db.GetWhere([]db.Where{{Key: "name", Value: "helloworld"}}, gtsTag) + err = suite.db.GetWhere(context.Background(), []db.Where{{Key: "name", Value: "helloworld"}}, gtsTag) assert.NoError(suite.T(), err) assert.Equal(suite.T(), statusReply.Account.ID, gtsTag.FirstSeenFromAccountID) } @@ -136,7 +137,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatus() { func (suite *StatusCreateTestSuite) TestPostAnotherNewStatus() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // setup recorder := httptest.NewRecorder() @@ -171,7 +172,7 @@ func (suite *StatusCreateTestSuite) TestPostAnotherNewStatus() { func (suite *StatusCreateTestSuite) TestPostNewStatusWithEmoji() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // setup recorder := httptest.NewRecorder() @@ -212,7 +213,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatusWithEmoji() { // Try to reply to a status that doesn't exist func (suite *StatusCreateTestSuite) TestReplyToNonexistentStatus() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // setup recorder := httptest.NewRecorder() @@ -243,7 +244,7 @@ func (suite *StatusCreateTestSuite) TestReplyToNonexistentStatus() { // Post a reply to the status of a local user that allows replies. func (suite *StatusCreateTestSuite) TestReplyToLocalStatus() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // setup recorder := httptest.NewRecorder() @@ -283,7 +284,7 @@ func (suite *StatusCreateTestSuite) TestReplyToLocalStatus() { // Take a media file which is currently not associated with a status, and attach it to a new status. func (suite *StatusCreateTestSuite) TestAttachNewMediaSuccess() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) attachment := suite.testAttachments["local_account_1_unattached_1"] @@ -322,12 +323,11 @@ func (suite *StatusCreateTestSuite) TestAttachNewMediaSuccess() { assert.Len(suite.T(), statusResponse.MediaAttachments, 1) // get the updated media attachment from the database - gtsAttachment := >smodel.MediaAttachment{} - err = suite.db.GetByID(statusResponse.MediaAttachments[0].ID, gtsAttachment) + gtsAttachment, err := suite.db.GetAttachmentByID(context.Background(), statusResponse.MediaAttachments[0].ID) assert.NoError(suite.T(), err) // convert it to a masto attachment - gtsAttachmentAsMasto, err := suite.tc.AttachmentToMasto(gtsAttachment) + gtsAttachmentAsMasto, err := suite.tc.AttachmentToMasto(context.Background(), gtsAttachment) assert.NoError(suite.T(), err) // compare it with what we have now diff --git a/internal/api/client/status/statusdelete.go b/internal/api/client/status/statusdelete.go index 257280ce0..9a67c45ba 100644 --- a/internal/api/client/status/statusdelete.go +++ b/internal/api/client/status/statusdelete.go @@ -86,7 +86,7 @@ func (m *Module) StatusDELETEHandler(c *gin.Context) { return } - mastoStatus, err := m.processor.StatusDelete(authed, targetStatusID) + mastoStatus, err := m.processor.StatusDelete(c.Request.Context(), authed, targetStatusID) if err != nil { l.Debugf("error processing status delete: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"}) diff --git a/internal/api/client/status/statusfave.go b/internal/api/client/status/statusfave.go index a76acf3d9..94a8f9380 100644 --- a/internal/api/client/status/statusfave.go +++ b/internal/api/client/status/statusfave.go @@ -83,7 +83,7 @@ func (m *Module) StatusFavePOSTHandler(c *gin.Context) { return } - mastoStatus, err := m.processor.StatusFave(authed, targetStatusID) + mastoStatus, err := m.processor.StatusFave(c.Request.Context(), authed, targetStatusID) if err != nil { l.Debugf("error processing status fave: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"}) diff --git a/internal/api/client/status/statusfave_test.go b/internal/api/client/status/statusfave_test.go index 0f44b5e90..2f7a2c596 100644 --- a/internal/api/client/status/statusfave_test.go +++ b/internal/api/client/status/statusfave_test.go @@ -71,7 +71,7 @@ func (suite *StatusFaveTestSuite) TearDownTest() { func (suite *StatusFaveTestSuite) TestPostFave() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["admin_account_status_2"] @@ -119,7 +119,7 @@ func (suite *StatusFaveTestSuite) TestPostFave() { func (suite *StatusFaveTestSuite) TestPostUnfaveable() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["local_account_2_status_3"] // this one is unlikeable and unreplyable diff --git a/internal/api/client/status/statusfavedby.go b/internal/api/client/status/statusfavedby.go index a5d6e7c58..7b8e19e20 100644 --- a/internal/api/client/status/statusfavedby.go +++ b/internal/api/client/status/statusfavedby.go @@ -84,7 +84,7 @@ func (m *Module) StatusFavedByGETHandler(c *gin.Context) { return } - mastoAccounts, err := m.processor.StatusFavedBy(authed, targetStatusID) + mastoAccounts, err := m.processor.StatusFavedBy(c.Request.Context(), authed, targetStatusID) if err != nil { l.Debugf("error processing status faved by request: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"}) diff --git a/internal/api/client/status/statusfavedby_test.go b/internal/api/client/status/statusfavedby_test.go index 22a549b30..7475f1e69 100644 --- a/internal/api/client/status/statusfavedby_test.go +++ b/internal/api/client/status/statusfavedby_test.go @@ -69,7 +69,7 @@ func (suite *StatusFavedByTestSuite) TearDownTest() { func (suite *StatusFavedByTestSuite) TestGetFavedBy() { t := suite.testTokens["local_account_2"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["admin_account_status_1"] // this status is faved by local_account_1 diff --git a/internal/api/client/status/statusget.go b/internal/api/client/status/statusget.go index bcca010f5..39668288f 100644 --- a/internal/api/client/status/statusget.go +++ b/internal/api/client/status/statusget.go @@ -83,7 +83,7 @@ func (m *Module) StatusGETHandler(c *gin.Context) { return } - mastoStatus, err := m.processor.StatusGet(authed, targetStatusID) + mastoStatus, err := m.processor.StatusGet(c.Request.Context(), authed, targetStatusID) if err != nil { l.Debugf("error processing status get: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"}) diff --git a/internal/api/client/status/statusunboost.go b/internal/api/client/status/statusunboost.go index dc42e3b62..8bb28fa50 100644 --- a/internal/api/client/status/statusunboost.go +++ b/internal/api/client/status/statusunboost.go @@ -84,7 +84,7 @@ func (m *Module) StatusUnboostPOSTHandler(c *gin.Context) { return } - mastoStatus, errWithCode := m.processor.StatusUnboost(authed, targetStatusID) + mastoStatus, errWithCode := m.processor.StatusUnboost(c.Request.Context(), authed, targetStatusID) if errWithCode != nil { l.Debugf("error processing status unboost: %s", errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/status/statusunfave.go b/internal/api/client/status/statusunfave.go index 80eb87acf..5a1074ca2 100644 --- a/internal/api/client/status/statusunfave.go +++ b/internal/api/client/status/statusunfave.go @@ -83,7 +83,7 @@ func (m *Module) StatusUnfavePOSTHandler(c *gin.Context) { return } - mastoStatus, err := m.processor.StatusUnfave(authed, targetStatusID) + mastoStatus, err := m.processor.StatusUnfave(c.Request.Context(), authed, targetStatusID) if err != nil { l.Debugf("error processing status unfave: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"}) diff --git a/internal/api/client/status/statusunfave_test.go b/internal/api/client/status/statusunfave_test.go index a5f267f4c..9e7ea8f82 100644 --- a/internal/api/client/status/statusunfave_test.go +++ b/internal/api/client/status/statusunfave_test.go @@ -71,7 +71,7 @@ func (suite *StatusUnfaveTestSuite) TearDownTest() { func (suite *StatusUnfaveTestSuite) TestPostUnfave() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // this is the status we wanna unfave: in the testrig it's already faved by this account targetStatus := suite.testStatuses["admin_account_status_1"] @@ -120,7 +120,7 @@ func (suite *StatusUnfaveTestSuite) TestPostUnfave() { func (suite *StatusUnfaveTestSuite) TestPostAlreadyNotFaved() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // this is the status we wanna unfave: in the testrig it's not faved by this account targetStatus := suite.testStatuses["admin_account_status_2"] diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index 626f1ff41..fa210e8d8 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -122,7 +122,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) { } // make sure a valid token has been provided and obtain the associated account - account, err := m.processor.AuthorizeStreamingRequest(accessToken) + account, err := m.processor.AuthorizeStreamingRequest(c.Request.Context(), accessToken) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": "could not authorize with given token"}) return @@ -147,7 +147,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) { defer conn.Close() // whatever happens, when we leave this function we want to close the websocket connection // inform the processor that we have a new connection and want a stream for it - stream, errWithCode := m.processor.OpenStreamForAccount(account, streamType) + stream, errWithCode := m.processor.OpenStreamForAccount(c.Request.Context(), account, streamType) if errWithCode != nil { c.JSON(errWithCode.Code(), errWithCode.Safe()) return diff --git a/internal/api/client/timeline/home.go b/internal/api/client/timeline/home.go index a6e64f384..6df4b29d0 100644 --- a/internal/api/client/timeline/home.go +++ b/internal/api/client/timeline/home.go @@ -153,7 +153,7 @@ func (m *Module) HomeTimelineGETHandler(c *gin.Context) { local = i } - resp, errWithCode := m.processor.HomeTimelineGet(authed, maxID, sinceID, minID, limit, local) + resp, errWithCode := m.processor.HomeTimelineGet(c.Request.Context(), authed, maxID, sinceID, minID, limit, local) if errWithCode != nil { l.Debugf("error from processor HomeTimelineGet: %s", errWithCode) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/timeline/public.go b/internal/api/client/timeline/public.go index 178fcd7f1..8c8c9f120 100644 --- a/internal/api/client/timeline/public.go +++ b/internal/api/client/timeline/public.go @@ -153,7 +153,7 @@ func (m *Module) PublicTimelineGETHandler(c *gin.Context) { local = i } - resp, errWithCode := m.processor.PublicTimelineGet(authed, maxID, sinceID, minID, limit, local) + resp, errWithCode := m.processor.PublicTimelineGet(c.Request.Context(), authed, maxID, sinceID, minID, limit, local) if errWithCode != nil { l.Debugf("error from processor PublicTimelineGet: %s", errWithCode) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/s2s/nodeinfo/nodeinfoget.go b/internal/api/s2s/nodeinfo/nodeinfoget.go index a54c8b190..c362e1d2e 100644 --- a/internal/api/s2s/nodeinfo/nodeinfoget.go +++ b/internal/api/s2s/nodeinfo/nodeinfoget.go @@ -33,7 +33,7 @@ func (m *Module) NodeInfoGETHandler(c *gin.Context) { "user-agent": c.Request.UserAgent(), }) - ni, err := m.processor.GetNodeInfo(c.Request) + ni, err := m.processor.GetNodeInfo(c.Request.Context(), c.Request) if err != nil { l.Debugf("error with get node info request: %s", err) c.JSON(err.Code(), err.Safe()) diff --git a/internal/api/s2s/nodeinfo/wellknownget.go b/internal/api/s2s/nodeinfo/wellknownget.go index 614d2a9c6..fd2c84408 100644 --- a/internal/api/s2s/nodeinfo/wellknownget.go +++ b/internal/api/s2s/nodeinfo/wellknownget.go @@ -33,7 +33,7 @@ func (m *Module) NodeInfoWellKnownGETHandler(c *gin.Context) { "user-agent": c.Request.UserAgent(), }) - niRel, err := m.processor.GetNodeInfoRel(c.Request) + niRel, err := m.processor.GetNodeInfoRel(c.Request.Context(), c.Request) if err != nil { l.Debugf("error with get node info rel request: %s", err) c.JSON(err.Code(), err.Safe()) diff --git a/internal/api/s2s/user/userget_test.go b/internal/api/s2s/user/userget_test.go index ab0015c57..29cc0e0d8 100644 --- a/internal/api/s2s/user/userget_test.go +++ b/internal/api/s2s/user/userget_test.go @@ -105,7 +105,7 @@ func (suite *UserGetTestSuite) TestGetUser() { // convert person to account // since this account is already known, we should get a pretty full model of it from the conversion - a, err := suite.tc.ASRepresentationToAccount(person, false) + a, err := suite.tc.ASRepresentationToAccount(context.Background(), person, false) assert.NoError(suite.T(), err) assert.EqualValues(suite.T(), targetAccount.Username, a.Username) } diff --git a/internal/api/security/signaturecheck.go b/internal/api/security/signaturecheck.go index 88b0b4dff..71e539e96 100644 --- a/internal/api/security/signaturecheck.go +++ b/internal/api/security/signaturecheck.go @@ -31,7 +31,7 @@ func (m *Module) SignatureCheck(c *gin.Context) { // we managed to parse the url! // if the domain is blocked we want to bail as early as possible - blocked, err := m.db.IsURIBlocked(requestingPublicKeyID) + blocked, err := m.db.IsURIBlocked(c.Request.Context(), requestingPublicKeyID) if err != nil { l.Errorf("could not tell if domain %s was blocked or not: %s", requestingPublicKeyID.Host, err) c.AbortWithStatus(http.StatusInternalServerError) |