diff options
author | 2021-08-25 15:34:33 +0200 | |
---|---|---|
committer | 2021-08-25 15:34:33 +0200 | |
commit | 2dc9fc1626507bb54417fc4a1920b847cafb27a2 (patch) | |
tree | 4ddeac479b923db38090aac8bd9209f3646851c1 /internal | |
parent | Manually approves followers (#146) (diff) | |
download | gotosocial-2dc9fc1626507bb54417fc4a1920b847cafb27a2.tar.xz |
Pg to bun (#148)
* start moving to bun
* changing more stuff
* more
* and yet more
* tests passing
* seems stable now
* more big changes
* small fix
* little fixes
Diffstat (limited to 'internal')
249 files changed, 3778 insertions, 2944 deletions
diff --git a/internal/api/client/account/accountcreate.go b/internal/api/client/account/accountcreate.go index 50e72655e..a9d672f80 100644 --- a/internal/api/client/account/accountcreate.go +++ b/internal/api/client/account/accountcreate.go @@ -101,7 +101,7 @@ func (m *Module) AccountCreatePOSTHandler(c *gin.Context) { form.IP = signUpIP - ti, err := m.processor.AccountCreate(authed, form) + ti, err := m.processor.AccountCreate(c.Request.Context(), authed, form) if err != nil { l.Errorf("internal server error while creating new account: %s", err) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) diff --git a/internal/api/client/account/accountget.go b/internal/api/client/account/accountget.go index a7f9d8c70..8bac1360b 100644 --- a/internal/api/client/account/accountget.go +++ b/internal/api/client/account/accountget.go @@ -70,7 +70,7 @@ func (m *Module) AccountGETHandler(c *gin.Context) { return } - acctInfo, err := m.processor.AccountGet(authed, targetAcctID) + acctInfo, err := m.processor.AccountGet(c.Request.Context(), authed, targetAcctID) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "not found"}) return diff --git a/internal/api/client/account/accountupdate.go b/internal/api/client/account/accountupdate.go index f55f45f59..282d172ed 100644 --- a/internal/api/client/account/accountupdate.go +++ b/internal/api/client/account/accountupdate.go @@ -122,7 +122,7 @@ func (m *Module) AccountUpdateCredentialsPATCHHandler(c *gin.Context) { return } - acctSensitive, err := m.processor.AccountUpdate(authed, form) + acctSensitive, err := m.processor.AccountUpdate(c.Request.Context(), authed, form) if err != nil { l.Debugf("could not update account: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) diff --git a/internal/api/client/account/accountupdate_test.go b/internal/api/client/account/accountupdate_test.go index 349429625..8fc31171b 100644 --- a/internal/api/client/account/accountupdate_test.go +++ b/internal/api/client/account/accountupdate_test.go @@ -79,7 +79,7 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandler() recorder := httptest.NewRecorder() ctx, _ := gin.CreateTestContext(recorder) ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) - ctx.Set(oauth.SessionAuthorizedToken, oauth.TokenToOauthToken(suite.testTokens["local_account_1"])) + ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) ctx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:8080/%s", account.UpdateCredentialsPath), bytes.NewReader(requestBody.Bytes())) // the endpoint we're hitting diff --git a/internal/api/client/account/accountverify.go b/internal/api/client/account/accountverify.go index 4c77f3fa6..c5c40a03d 100644 --- a/internal/api/client/account/accountverify.go +++ b/internal/api/client/account/accountverify.go @@ -59,7 +59,7 @@ func (m *Module) AccountVerifyGETHandler(c *gin.Context) { return } - acctSensitive, err := m.processor.AccountGet(authed, authed.Account.ID) + acctSensitive, err := m.processor.AccountGet(c.Request.Context(), authed, authed.Account.ID) if err != nil { l.Debugf("error getting account from processor: %s", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"}) diff --git a/internal/api/client/account/block.go b/internal/api/client/account/block.go index 0d9d6c51b..243f90c5e 100644 --- a/internal/api/client/account/block.go +++ b/internal/api/client/account/block.go @@ -72,7 +72,7 @@ func (m *Module) AccountBlockPOSTHandler(c *gin.Context) { return } - relationship, errWithCode := m.processor.AccountBlockCreate(authed, targetAcctID) + relationship, errWithCode := m.processor.AccountBlockCreate(c.Request.Context(), authed, targetAcctID) if errWithCode != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/account/follow.go b/internal/api/client/account/follow.go index 985a5f821..8f3f7ad0a 100644 --- a/internal/api/client/account/follow.go +++ b/internal/api/client/account/follow.go @@ -99,7 +99,7 @@ func (m *Module) AccountFollowPOSTHandler(c *gin.Context) { } form.ID = targetAcctID - relationship, errWithCode := m.processor.AccountFollowCreate(authed, form) + relationship, errWithCode := m.processor.AccountFollowCreate(c.Request.Context(), authed, form) if errWithCode != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/account/followers.go b/internal/api/client/account/followers.go index 7e93544b8..4f30e9939 100644 --- a/internal/api/client/account/followers.go +++ b/internal/api/client/account/followers.go @@ -74,7 +74,7 @@ func (m *Module) AccountFollowersGETHandler(c *gin.Context) { return } - followers, errWithCode := m.processor.AccountFollowersGet(authed, targetAcctID) + followers, errWithCode := m.processor.AccountFollowersGet(c.Request.Context(), authed, targetAcctID) if errWithCode != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/account/following.go b/internal/api/client/account/following.go index e70265eb5..baac2c9d3 100644 --- a/internal/api/client/account/following.go +++ b/internal/api/client/account/following.go @@ -74,7 +74,7 @@ func (m *Module) AccountFollowingGETHandler(c *gin.Context) { return } - following, errWithCode := m.processor.AccountFollowingGet(authed, targetAcctID) + following, errWithCode := m.processor.AccountFollowingGet(c.Request.Context(), authed, targetAcctID) if errWithCode != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/account/relationships.go b/internal/api/client/account/relationships.go index 9dbc8c4bb..d350209af 100644 --- a/internal/api/client/account/relationships.go +++ b/internal/api/client/account/relationships.go @@ -71,7 +71,7 @@ func (m *Module) AccountRelationshipsGETHandler(c *gin.Context) { relationships := []model.Relationship{} for _, targetAccountID := range targetAccountIDs { - r, errWithCode := m.processor.AccountRelationshipGet(authed, targetAccountID) + r, errWithCode := m.processor.AccountRelationshipGet(c.Request.Context(), authed, targetAccountID) if err != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/account/statuses.go b/internal/api/client/account/statuses.go index 097ccc3cc..4841d86df 100644 --- a/internal/api/client/account/statuses.go +++ b/internal/api/client/account/statuses.go @@ -166,7 +166,7 @@ func (m *Module) AccountStatusesGETHandler(c *gin.Context) { mediaOnly = i } - statuses, errWithCode := m.processor.AccountStatusesGet(authed, targetAcctID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly) + statuses, errWithCode := m.processor.AccountStatusesGet(c.Request.Context(), authed, targetAcctID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly) if errWithCode != nil { l.Debugf("error from processor account statuses get: %s", errWithCode) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/account/unblock.go b/internal/api/client/account/unblock.go index d9a2f2881..7b16ac887 100644 --- a/internal/api/client/account/unblock.go +++ b/internal/api/client/account/unblock.go @@ -72,7 +72,7 @@ func (m *Module) AccountUnblockPOSTHandler(c *gin.Context) { return } - relationship, errWithCode := m.processor.AccountBlockRemove(authed, targetAcctID) + relationship, errWithCode := m.processor.AccountBlockRemove(c.Request.Context(), authed, targetAcctID) if errWithCode != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/account/unfollow.go b/internal/api/client/account/unfollow.go index 84a558c65..7ac9697d5 100644 --- a/internal/api/client/account/unfollow.go +++ b/internal/api/client/account/unfollow.go @@ -75,7 +75,7 @@ func (m *Module) AccountUnfollowPOSTHandler(c *gin.Context) { return } - relationship, errWithCode := m.processor.AccountFollowRemove(authed, targetAcctID) + relationship, errWithCode := m.processor.AccountFollowRemove(c.Request.Context(), authed, targetAcctID) if errWithCode != nil { l.Debug(errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/admin/domainblockcreate.go b/internal/api/client/admin/domainblockcreate.go index d48c70408..9ef4c6f92 100644 --- a/internal/api/client/admin/domainblockcreate.go +++ b/internal/api/client/admin/domainblockcreate.go @@ -141,7 +141,7 @@ func (m *Module) DomainBlocksPOSTHandler(c *gin.Context) { if imp { // we're importing multiple blocks - domainBlocks, err := m.processor.AdminDomainBlocksImport(authed, form) + domainBlocks, err := m.processor.AdminDomainBlocksImport(c.Request.Context(), authed, form) if err != nil { l.Debugf("error importing domain blocks: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -150,7 +150,7 @@ func (m *Module) DomainBlocksPOSTHandler(c *gin.Context) { c.JSON(http.StatusOK, domainBlocks) } else { // we're just creating one block - domainBlock, err := m.processor.AdminDomainBlockCreate(authed, form) + domainBlock, err := m.processor.AdminDomainBlockCreate(c.Request.Context(), authed, form) if err != nil { l.Debugf("error creating domain block: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) diff --git a/internal/api/client/admin/domainblockdelete.go b/internal/api/client/admin/domainblockdelete.go index 9cd2fd711..64e4ef6de 100644 --- a/internal/api/client/admin/domainblockdelete.go +++ b/internal/api/client/admin/domainblockdelete.go @@ -68,7 +68,7 @@ func (m *Module) DomainBlockDELETEHandler(c *gin.Context) { return } - domainBlock, errWithCode := m.processor.AdminDomainBlockDelete(authed, domainBlockID) + domainBlock, errWithCode := m.processor.AdminDomainBlockDelete(c.Request.Context(), authed, domainBlockID) if errWithCode != nil { l.Debugf("error deleting domain block: %s", errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/admin/domainblockget.go b/internal/api/client/admin/domainblockget.go index 86923d705..d23f99a8c 100644 --- a/internal/api/client/admin/domainblockget.go +++ b/internal/api/client/admin/domainblockget.go @@ -81,7 +81,7 @@ func (m *Module) DomainBlockGETHandler(c *gin.Context) { export = i } - domainBlock, err := m.processor.AdminDomainBlockGet(authed, domainBlockID, export) + domainBlock, err := m.processor.AdminDomainBlockGet(c.Request.Context(), authed, domainBlockID, export) if err != nil { l.Debugf("error getting domain block: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) diff --git a/internal/api/client/admin/domainblocksget.go b/internal/api/client/admin/domainblocksget.go index 77a287387..dad8250e0 100644 --- a/internal/api/client/admin/domainblocksget.go +++ b/internal/api/client/admin/domainblocksget.go @@ -81,7 +81,7 @@ func (m *Module) DomainBlocksGETHandler(c *gin.Context) { export = i } - domainBlocks, err := m.processor.AdminDomainBlocksGet(authed, export) + domainBlocks, err := m.processor.AdminDomainBlocksGet(c.Request.Context(), authed, export) if err != nil { l.Debugf("error getting domain blocks: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) diff --git a/internal/api/client/admin/emojicreate.go b/internal/api/client/admin/emojicreate.go index bfdf28249..859933b16 100644 --- a/internal/api/client/admin/emojicreate.go +++ b/internal/api/client/admin/emojicreate.go @@ -111,7 +111,7 @@ func (m *Module) emojiCreatePOSTHandler(c *gin.Context) { return } - mastoEmoji, err := m.processor.AdminEmojiCreate(authed, form) + mastoEmoji, err := m.processor.AdminEmojiCreate(c.Request.Context(), authed, form) if err != nil { l.Debugf("error creating emoji: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) diff --git a/internal/api/client/app/appcreate.go b/internal/api/client/app/appcreate.go index 31072f9c8..d233841b0 100644 --- a/internal/api/client/app/appcreate.go +++ b/internal/api/client/app/appcreate.go @@ -101,7 +101,7 @@ func (m *Module) AppsPOSTHandler(c *gin.Context) { return } - mastoApp, err := m.processor.AppCreate(authed, form) + mastoApp, err := m.processor.AppCreate(c.Request.Context(), authed, form) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return diff --git a/internal/api/client/auth/auth_test.go b/internal/api/client/auth/auth_test.go index 48d2a2508..3d5170f31 100644 --- a/internal/api/client/auth/auth_test.go +++ b/internal/api/client/auth/auth_test.go @@ -28,7 +28,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/db/pg" + "github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/oauth" "golang.org/x/crypto/bcrypt" @@ -104,7 +104,7 @@ func (suite *AuthTestSuite) SetupTest() { log := logrus.New() log.SetLevel(logrus.TraceLevel) - db, err := pg.NewPostgresService(context.Background(), suite.config, log) + db, err := bundb.NewBunDBService(context.Background(), suite.config, log) if err != nil { logrus.Panicf("error creating database connection: %s", err) } @@ -120,23 +120,23 @@ func (suite *AuthTestSuite) SetupTest() { } for _, m := range models { - if err := suite.db.CreateTable(m); err != nil { + if err := suite.db.CreateTable(context.Background(), m); err != nil { logrus.Panicf("db connection error: %s", err) } } suite.oauthServer = oauth.New(suite.db, log) - if err := suite.db.Put(suite.testAccount); err != nil { + if err := suite.db.Put(context.Background(), suite.testAccount); err != nil { logrus.Panicf("could not insert test account into db: %s", err) } - if err := suite.db.Put(suite.testUser); err != nil { + if err := suite.db.Put(context.Background(), suite.testUser); err != nil { logrus.Panicf("could not insert test user into db: %s", err) } - if err := suite.db.Put(suite.testClient); err != nil { + if err := suite.db.Put(context.Background(), suite.testClient); err != nil { logrus.Panicf("could not insert test client into db: %s", err) } - if err := suite.db.Put(suite.testApplication); err != nil { + if err := suite.db.Put(context.Background(), suite.testApplication); err != nil { logrus.Panicf("could not insert test application into db: %s", err) } @@ -152,7 +152,7 @@ func (suite *AuthTestSuite) TearDownTest() { >smodel.Application{}, } for _, m := range models { - if err := suite.db.DropTable(m); err != nil { + if err := suite.db.DropTable(context.Background(), m); err != nil { logrus.Panicf("error dropping table: %s", err) } } diff --git a/internal/api/client/auth/authorize.go b/internal/api/client/auth/authorize.go index a10408723..0328f3b21 100644 --- a/internal/api/client/auth/authorize.go +++ b/internal/api/client/auth/authorize.go @@ -70,30 +70,23 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"}) return } - app := >smodel.Application{ - ClientID: clientID, - } - if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil { + app := >smodel.Application{} + if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil { m.clearSession(s) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)}) return } // we can also use the userid of the user to fetch their username from the db to greet them nicely <3 - user := >smodel.User{ - ID: userID, - } - if err := m.db.GetByID(user.ID, user); err != nil { + user := >smodel.User{} + if err := m.db.GetByID(c.Request.Context(), user.ID, user); err != nil { m.clearSession(s) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - acct := >smodel.Account{ - ID: user.AccountID, - } - - if err := m.db.GetByID(acct.ID, acct); err != nil { + acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) + if err != nil { m.clearSession(s) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go index a26838aa3..cbb429352 100644 --- a/internal/api/client/auth/callback.go +++ b/internal/api/client/auth/callback.go @@ -19,6 +19,7 @@ package auth import ( + "context" "errors" "fmt" "net" @@ -80,13 +81,13 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { app := >smodel.Application{ ClientID: clientID, } - if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil { + if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil { m.clearSession(s) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)}) return } - user, err := m.parseUserFromClaims(claims, net.IP(c.ClientIP()), app.ID) + user, err := m.parseUserFromClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID) if err != nil { m.clearSession(s) c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) @@ -103,14 +104,14 @@ func (m *Module) CallbackGETHandler(c *gin.Context) { c.Redirect(http.StatusFound, OauthAuthorizePath) } -func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) { +func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) { if claims.Email == "" { return nil, errors.New("no email returned in claims") } // see if we already have a user for this email address user := >smodel.User{} - err := m.db.GetWhere([]db.Where{{Key: "email", Value: claims.Email}}, user) + err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user) if err == nil { // we do! so we can just return it return user, nil @@ -122,7 +123,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin } // maybe we have an unconfirmed user - err = m.db.GetWhere([]db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user) + err = m.db.GetWhere(ctx, []db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user) if err == nil { // user is unconfirmed so return an error return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email) @@ -137,9 +138,13 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin // however, because we trust the OIDC provider, we should now create a user + account with the provided claims // check if the email address is available for use; if it's not there's nothing we can so - if err := m.db.IsEmailAvailable(claims.Email); err != nil { + emailAvailable, err := m.db.IsEmailAvailable(ctx, claims.Email) + if err != nil { return nil, fmt.Errorf("email %s not available: %s", claims.Email, err) } + if !emailAvailable { + return nil, fmt.Errorf("email %s in use", claims.Email) + } // now we need a username var username string @@ -180,12 +185,11 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin // note that for the first iteration, iString is still "" when the check is made, so our first choice // is still the raw username with no integer stuck on the end for i := 1; !found; i = i + 1 { - if err := m.db.IsUsernameAvailable(username + iString); err != nil { - if strings.Contains(err.Error(), "db error") { - // if there's an actual db error we should return - return nil, fmt.Errorf("error checking username availability: %s", err) - } - } else { + usernameAvailable, err := m.db.IsUsernameAvailable(ctx, username+iString) + if err != nil { + return nil, err + } + if usernameAvailable { // no error so we've found a username that works found = true username = username + iString @@ -209,7 +213,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin password := uuid.NewString() + uuid.NewString() // create the user! this will also create an account and store it in the database so we don't need to do that here - user, err = m.db.NewSignup(username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin) + user, err = m.db.NewSignup(ctx, username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin) if err != nil { return nil, fmt.Errorf("error creating user: %s", err) } diff --git a/internal/api/client/auth/middleware.go b/internal/api/client/auth/middleware.go index a734b2ceb..3599c7048 100644 --- a/internal/api/client/auth/middleware.go +++ b/internal/api/client/auth/middleware.go @@ -49,15 +49,15 @@ func (m *Module) OauthTokenMiddleware(c *gin.Context) { // fetch user's and account for this user id user := >smodel.User{} - if err := m.db.GetByID(uid, user); err != nil || user == nil { + if err := m.db.GetByID(c.Request.Context(), uid, user); err != nil || user == nil { l.Warnf("no user found for validated uid %s", uid) return } c.Set(oauth.SessionAuthorizedUser, user) l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user) - acct := >smodel.Account{} - if err := m.db.GetByID(user.AccountID, acct); err != nil || acct == nil { + acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) + if err != nil || acct == nil { l.Warnf("no account found for validated user %s", uid) return } @@ -69,7 +69,7 @@ func (m *Module) OauthTokenMiddleware(c *gin.Context) { if cid := ti.GetClientID(); cid != "" { l.Tracef("authenticated client %s with bearer token, scope is %s", cid, ti.GetScope()) app := >smodel.Application{} - if err := m.db.GetWhere([]db.Where{{Key: "client_id", Value: cid}}, app); err != nil { + if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: "client_id", Value: cid}}, app); err != nil { l.Tracef("no app found for client %s", cid) } c.Set(oauth.SessionAuthorizedApplication, app) diff --git a/internal/api/client/auth/signin.go b/internal/api/client/auth/signin.go index 543505cbd..6b8bb93db 100644 --- a/internal/api/client/auth/signin.go +++ b/internal/api/client/auth/signin.go @@ -19,6 +19,7 @@ package auth import ( + "context" "errors" "net/http" @@ -74,7 +75,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) { } l.Tracef("parsed form: %+v", form) - userid, err := m.ValidatePassword(form.Email, form.Password) + userid, err := m.ValidatePassword(c.Request.Context(), form.Email, form.Password) if err != nil { c.String(http.StatusForbidden, err.Error()) m.clearSession(s) @@ -96,7 +97,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) { // The goal is to authenticate the password against the one for that email // address stored in the database. If OK, we return the userid (a ulid) for that user, // so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db. -func (m *Module) ValidatePassword(email string, password string) (userid string, err error) { +func (m *Module) ValidatePassword(ctx context.Context, email string, password string) (userid string, err error) { l := m.log.WithField("func", "ValidatePassword") // make sure an email/password was provided and bail if not @@ -108,7 +109,7 @@ func (m *Module) ValidatePassword(email string, password string) (userid string, // first we select the user from the database based on email address, bail if no user found for that email gtsUser := >smodel.User{} - if err := m.db.GetWhere([]db.Where{{Key: "email", Value: email}}, gtsUser); err != nil { + if err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: email}}, gtsUser); err != nil { l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err) return incorrectPassword() } diff --git a/internal/api/client/blocks/blocksget.go b/internal/api/client/blocks/blocksget.go index 65c11ea1a..b6c9c39e1 100644 --- a/internal/api/client/blocks/blocksget.go +++ b/internal/api/client/blocks/blocksget.go @@ -117,7 +117,7 @@ func (m *Module) BlocksGETHandler(c *gin.Context) { limit = int(i) } - resp, errWithCode := m.processor.BlocksGet(authed, maxID, sinceID, limit) + resp, errWithCode := m.processor.BlocksGet(c.Request.Context(), authed, maxID, sinceID, limit) if errWithCode != nil { l.Debugf("error from processor BlocksGet: %s", errWithCode) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/favourites/favouritesget.go b/internal/api/client/favourites/favouritesget.go index 76eb921e0..6984ea754 100644 --- a/internal/api/client/favourites/favouritesget.go +++ b/internal/api/client/favourites/favouritesget.go @@ -43,7 +43,7 @@ func (m *Module) FavouritesGETHandler(c *gin.Context) { limit = int(i) } - resp, errWithCode := m.processor.FavedTimelineGet(authed, maxID, minID, limit) + resp, errWithCode := m.processor.FavedTimelineGet(c.Request.Context(), authed, maxID, minID, limit) if errWithCode != nil { l.Debugf("error from processor FavedTimelineGet: %s", errWithCode) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/fileserver/fileserver.go b/internal/api/client/fileserver/fileserver.go index 08e6abb62..61286c17a 100644 --- a/internal/api/client/fileserver/fileserver.go +++ b/internal/api/client/fileserver/fileserver.go @@ -25,8 +25,6 @@ import ( "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/api" "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/router" ) @@ -66,17 +64,3 @@ func (m *FileServer) Route(s router.Router) error { s.AttachHandler(http.MethodGet, fmt.Sprintf("%s/:%s/:%s/:%s/:%s", m.storageBase, AccountIDKey, MediaTypeKey, MediaSizeKey, FileNameKey), m.ServeFile) return nil } - -// CreateTables populates necessary tables in the given DB -func (m *FileServer) CreateTables(db db.DB) error { - models := []interface{}{ - >smodel.MediaAttachment{}, - } - - for _, m := range models { - if err := db.CreateTable(m); err != nil { - return fmt.Errorf("error creating table: %s", err) - } - } - return nil -} diff --git a/internal/api/client/fileserver/servefile.go b/internal/api/client/fileserver/servefile.go index 1339fbac3..130a16c4f 100644 --- a/internal/api/client/fileserver/servefile.go +++ b/internal/api/client/fileserver/servefile.go @@ -78,7 +78,7 @@ func (m *FileServer) ServeFile(c *gin.Context) { return } - content, err := m.processor.FileGet(authed, &model.GetContentRequestForm{ + content, err := m.processor.FileGet(c.Request.Context(), authed, &model.GetContentRequestForm{ AccountID: accountID, MediaType: mediaType, MediaSize: mediaSize, diff --git a/internal/api/client/followrequest/accept.go b/internal/api/client/followrequest/accept.go index bb2910c8f..3dba7673f 100644 --- a/internal/api/client/followrequest/accept.go +++ b/internal/api/client/followrequest/accept.go @@ -48,7 +48,7 @@ func (m *Module) FollowRequestAcceptPOSTHandler(c *gin.Context) { return } - r, errWithCode := m.processor.FollowRequestAccept(authed, originAccountID) + r, errWithCode := m.processor.FollowRequestAccept(c.Request.Context(), authed, originAccountID) if errWithCode != nil { l.Debug(errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/followrequest/get.go b/internal/api/client/followrequest/get.go index 3f02ee02a..47e1d74ba 100644 --- a/internal/api/client/followrequest/get.go +++ b/internal/api/client/followrequest/get.go @@ -41,7 +41,7 @@ func (m *Module) FollowRequestGETHandler(c *gin.Context) { return } - accts, errWithCode := m.processor.FollowRequestsGet(authed) + accts, errWithCode := m.processor.FollowRequestsGet(c.Request.Context(), authed) if errWithCode != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/instance/instanceget.go b/internal/api/client/instance/instanceget.go index 0d53edadb..0a6d17153 100644 --- a/internal/api/client/instance/instanceget.go +++ b/internal/api/client/instance/instanceget.go @@ -31,7 +31,7 @@ import ( func (m *Module) InstanceInformationGETHandler(c *gin.Context) { l := m.log.WithField("func", "InstanceInformationGETHandler") - instance, err := m.processor.InstanceGet(m.config.Host) + instance, err := m.processor.InstanceGet(c.Request.Context(), m.config.Host) if err != nil { l.Debugf("error getting instance from processor: %s", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"}) diff --git a/internal/api/client/instance/instancepatch.go b/internal/api/client/instance/instancepatch.go index 3620f6044..fa37ccd8e 100644 --- a/internal/api/client/instance/instancepatch.go +++ b/internal/api/client/instance/instancepatch.go @@ -116,7 +116,7 @@ func (m *Module) InstanceUpdatePATCHHandler(c *gin.Context) { return } - i, errWithCode := m.processor.InstancePatch(form) + i, errWithCode := m.processor.InstancePatch(c.Request.Context(), form) if errWithCode != nil { l.Debugf("error with instance patch request: %s", errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/media/media.go b/internal/api/client/media/media.go index 05058e2e2..1e9e8fdaa 100644 --- a/internal/api/client/media/media.go +++ b/internal/api/client/media/media.go @@ -19,14 +19,11 @@ package media import ( - "fmt" "net/http" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/api" "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/router" ) @@ -63,17 +60,3 @@ func (m *Module) Route(s router.Router) error { s.AttachHandler(http.MethodPut, BasePathWithID, m.MediaPUTHandler) return nil } - -// CreateTables populates necessary tables in the given DB -func (m *Module) CreateTables(db db.DB) error { - models := []interface{}{ - >smodel.MediaAttachment{}, - } - - for _, m := range models { - if err := db.CreateTable(m); err != nil { - return fmt.Errorf("error creating table: %s", err) - } - } - return nil -} diff --git a/internal/api/client/media/mediacreate.go b/internal/api/client/media/mediacreate.go index f41d4568f..58d076ea6 100644 --- a/internal/api/client/media/mediacreate.go +++ b/internal/api/client/media/mediacreate.go @@ -108,7 +108,7 @@ func (m *Module) MediaCreatePOSTHandler(c *gin.Context) { } l.Debug("calling processor media create func") - mastoAttachment, err := m.processor.MediaCreate(authed, form) + mastoAttachment, err := m.processor.MediaCreate(c.Request.Context(), authed, form) if err != nil { l.Debugf("error creating attachment: %s", err) c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()}) diff --git a/internal/api/client/media/mediacreate_test.go b/internal/api/client/media/mediacreate_test.go index 5c48a4381..8433786e4 100644 --- a/internal/api/client/media/mediacreate_test.go +++ b/internal/api/client/media/mediacreate_test.go @@ -121,7 +121,7 @@ func (suite *MediaCreateTestSuite) TestStatusCreatePOSTImageHandlerSuccessful() // set up the context for the request t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) recorder := httptest.NewRecorder() ctx, _ := gin.CreateTestContext(recorder) ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) diff --git a/internal/api/client/media/mediaget.go b/internal/api/client/media/mediaget.go index 17c5a090b..5fd7856e9 100644 --- a/internal/api/client/media/mediaget.go +++ b/internal/api/client/media/mediaget.go @@ -75,7 +75,7 @@ func (m *Module) MediaGETHandler(c *gin.Context) { return } - attachment, errWithCode := m.processor.MediaGet(authed, attachmentID) + attachment, errWithCode := m.processor.MediaGet(c.Request.Context(), authed, attachmentID) if errWithCode != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/media/mediaupdate.go b/internal/api/client/media/mediaupdate.go index 0ceb01f82..3af19297f 100644 --- a/internal/api/client/media/mediaupdate.go +++ b/internal/api/client/media/mediaupdate.go @@ -122,7 +122,7 @@ func (m *Module) MediaPUTHandler(c *gin.Context) { return } - attachment, errWithCode := m.processor.MediaUpdate(authed, attachmentID, &form) + attachment, errWithCode := m.processor.MediaUpdate(c.Request.Context(), authed, attachmentID, &form) if errWithCode != nil { c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) return diff --git a/internal/api/client/notification/notificationsget.go b/internal/api/client/notification/notificationsget.go index a30674750..81e8a6890 100644 --- a/internal/api/client/notification/notificationsget.go +++ b/internal/api/client/notification/notificationsget.go @@ -68,7 +68,7 @@ func (m *Module) NotificationsGETHandler(c *gin.Context) { sinceID = sinceIDString } - notifs, errWithCode := m.processor.NotificationsGet(authed, limit, maxID, sinceID) + notifs, errWithCode := m.processor.NotificationsGet(c.Request.Context(), authed, limit, maxID, sinceID) if errWithCode != nil { l.Debugf("error processing notifications get: %s", errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/search/searchget.go b/internal/api/client/search/searchget.go index faa227719..848915274 100644 --- a/internal/api/client/search/searchget.go +++ b/internal/api/client/search/searchget.go @@ -164,7 +164,7 @@ func (m *Module) SearchGETHandler(c *gin.Context) { Following: following, } - results, errWithCode := m.processor.SearchGet(authed, searchQuery) + results, errWithCode := m.processor.SearchGet(c.Request.Context(), authed, searchQuery) if errWithCode != nil { l.Debugf("error searching: %s", errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/status/statusboost.go b/internal/api/client/status/statusboost.go index 5aa7989bc..094e56ac0 100644 --- a/internal/api/client/status/statusboost.go +++ b/internal/api/client/status/statusboost.go @@ -87,7 +87,7 @@ func (m *Module) StatusBoostPOSTHandler(c *gin.Context) { return } - mastoStatus, errWithCode := m.processor.StatusBoost(authed, targetStatusID) + mastoStatus, errWithCode := m.processor.StatusBoost(c.Request.Context(), authed, targetStatusID) if errWithCode != nil { l.Debugf("error processing status boost: %s", errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/status/statusboost_test.go b/internal/api/client/status/statusboost_test.go index fbe267fac..4157bde38 100644 --- a/internal/api/client/status/statusboost_test.go +++ b/internal/api/client/status/statusboost_test.go @@ -67,7 +67,7 @@ func (suite *StatusBoostTestSuite) TearDownTest() { func (suite *StatusBoostTestSuite) TestPostBoost() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["admin_account_status_1"] @@ -133,7 +133,7 @@ func (suite *StatusBoostTestSuite) TestPostBoost() { func (suite *StatusBoostTestSuite) TestPostUnboostable() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["local_account_2_status_4"] @@ -171,7 +171,7 @@ func (suite *StatusBoostTestSuite) TestPostUnboostable() { func (suite *StatusBoostTestSuite) TestPostNotVisible() { t := suite.testTokens["local_account_2"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["local_account_1_status_3"] // this is a mutual only status and these accounts aren't mutuals diff --git a/internal/api/client/status/statusboostedby.go b/internal/api/client/status/statusboostedby.go index 260e21642..908c3ff10 100644 --- a/internal/api/client/status/statusboostedby.go +++ b/internal/api/client/status/statusboostedby.go @@ -84,7 +84,7 @@ func (m *Module) StatusBoostedByGETHandler(c *gin.Context) { return } - mastoAccounts, err := m.processor.StatusBoostedBy(authed, targetStatusID) + mastoAccounts, err := m.processor.StatusBoostedBy(c.Request.Context(), authed, targetStatusID) if err != nil { l.Debugf("error processing status boosted by request: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"}) diff --git a/internal/api/client/status/statuscontext.go b/internal/api/client/status/statuscontext.go index 6e28b004e..90fcb9608 100644 --- a/internal/api/client/status/statuscontext.go +++ b/internal/api/client/status/statuscontext.go @@ -86,7 +86,7 @@ func (m *Module) StatusContextGETHandler(c *gin.Context) { return } - statusContext, errWithCode := m.processor.StatusGetContext(authed, targetStatusID) + statusContext, errWithCode := m.processor.StatusGetContext(c.Request.Context(), authed, targetStatusID) if errWithCode != nil { l.Debugf("error getting status context: %s", errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/status/statuscreate.go b/internal/api/client/status/statuscreate.go index 2007ba80f..09fc47b5b 100644 --- a/internal/api/client/status/statuscreate.go +++ b/internal/api/client/status/statuscreate.go @@ -101,7 +101,7 @@ func (m *Module) StatusCreatePOSTHandler(c *gin.Context) { return } - mastoStatus, err := m.processor.StatusCreate(authed, form) + mastoStatus, err := m.processor.StatusCreate(c.Request.Context(), authed, form) if err != nil { l.Debugf("error processing status create: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"}) diff --git a/internal/api/client/status/statuscreate_test.go b/internal/api/client/status/statuscreate_test.go index 33912397e..060d96dad 100644 --- a/internal/api/client/status/statuscreate_test.go +++ b/internal/api/client/status/statuscreate_test.go @@ -19,6 +19,7 @@ package status_test import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -82,7 +83,7 @@ https://docs.gotosocial.org/en/latest/user_guide/posts/#links func (suite *StatusCreateTestSuite) TestPostNewStatus() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // setup recorder := httptest.NewRecorder() @@ -128,7 +129,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatus() { }, statusReply.Tags[0]) gtsTag := >smodel.Tag{} - err = suite.db.GetWhere([]db.Where{{Key: "name", Value: "helloworld"}}, gtsTag) + err = suite.db.GetWhere(context.Background(), []db.Where{{Key: "name", Value: "helloworld"}}, gtsTag) assert.NoError(suite.T(), err) assert.Equal(suite.T(), statusReply.Account.ID, gtsTag.FirstSeenFromAccountID) } @@ -136,7 +137,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatus() { func (suite *StatusCreateTestSuite) TestPostAnotherNewStatus() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // setup recorder := httptest.NewRecorder() @@ -171,7 +172,7 @@ func (suite *StatusCreateTestSuite) TestPostAnotherNewStatus() { func (suite *StatusCreateTestSuite) TestPostNewStatusWithEmoji() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // setup recorder := httptest.NewRecorder() @@ -212,7 +213,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatusWithEmoji() { // Try to reply to a status that doesn't exist func (suite *StatusCreateTestSuite) TestReplyToNonexistentStatus() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // setup recorder := httptest.NewRecorder() @@ -243,7 +244,7 @@ func (suite *StatusCreateTestSuite) TestReplyToNonexistentStatus() { // Post a reply to the status of a local user that allows replies. func (suite *StatusCreateTestSuite) TestReplyToLocalStatus() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // setup recorder := httptest.NewRecorder() @@ -283,7 +284,7 @@ func (suite *StatusCreateTestSuite) TestReplyToLocalStatus() { // Take a media file which is currently not associated with a status, and attach it to a new status. func (suite *StatusCreateTestSuite) TestAttachNewMediaSuccess() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) attachment := suite.testAttachments["local_account_1_unattached_1"] @@ -322,12 +323,11 @@ func (suite *StatusCreateTestSuite) TestAttachNewMediaSuccess() { assert.Len(suite.T(), statusResponse.MediaAttachments, 1) // get the updated media attachment from the database - gtsAttachment := >smodel.MediaAttachment{} - err = suite.db.GetByID(statusResponse.MediaAttachments[0].ID, gtsAttachment) + gtsAttachment, err := suite.db.GetAttachmentByID(context.Background(), statusResponse.MediaAttachments[0].ID) assert.NoError(suite.T(), err) // convert it to a masto attachment - gtsAttachmentAsMasto, err := suite.tc.AttachmentToMasto(gtsAttachment) + gtsAttachmentAsMasto, err := suite.tc.AttachmentToMasto(context.Background(), gtsAttachment) assert.NoError(suite.T(), err) // compare it with what we have now diff --git a/internal/api/client/status/statusdelete.go b/internal/api/client/status/statusdelete.go index 257280ce0..9a67c45ba 100644 --- a/internal/api/client/status/statusdelete.go +++ b/internal/api/client/status/statusdelete.go @@ -86,7 +86,7 @@ func (m *Module) StatusDELETEHandler(c *gin.Context) { return } - mastoStatus, err := m.processor.StatusDelete(authed, targetStatusID) + mastoStatus, err := m.processor.StatusDelete(c.Request.Context(), authed, targetStatusID) if err != nil { l.Debugf("error processing status delete: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"}) diff --git a/internal/api/client/status/statusfave.go b/internal/api/client/status/statusfave.go index a76acf3d9..94a8f9380 100644 --- a/internal/api/client/status/statusfave.go +++ b/internal/api/client/status/statusfave.go @@ -83,7 +83,7 @@ func (m *Module) StatusFavePOSTHandler(c *gin.Context) { return } - mastoStatus, err := m.processor.StatusFave(authed, targetStatusID) + mastoStatus, err := m.processor.StatusFave(c.Request.Context(), authed, targetStatusID) if err != nil { l.Debugf("error processing status fave: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"}) diff --git a/internal/api/client/status/statusfave_test.go b/internal/api/client/status/statusfave_test.go index 0f44b5e90..2f7a2c596 100644 --- a/internal/api/client/status/statusfave_test.go +++ b/internal/api/client/status/statusfave_test.go @@ -71,7 +71,7 @@ func (suite *StatusFaveTestSuite) TearDownTest() { func (suite *StatusFaveTestSuite) TestPostFave() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["admin_account_status_2"] @@ -119,7 +119,7 @@ func (suite *StatusFaveTestSuite) TestPostFave() { func (suite *StatusFaveTestSuite) TestPostUnfaveable() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["local_account_2_status_3"] // this one is unlikeable and unreplyable diff --git a/internal/api/client/status/statusfavedby.go b/internal/api/client/status/statusfavedby.go index a5d6e7c58..7b8e19e20 100644 --- a/internal/api/client/status/statusfavedby.go +++ b/internal/api/client/status/statusfavedby.go @@ -84,7 +84,7 @@ func (m *Module) StatusFavedByGETHandler(c *gin.Context) { return } - mastoAccounts, err := m.processor.StatusFavedBy(authed, targetStatusID) + mastoAccounts, err := m.processor.StatusFavedBy(c.Request.Context(), authed, targetStatusID) if err != nil { l.Debugf("error processing status faved by request: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"}) diff --git a/internal/api/client/status/statusfavedby_test.go b/internal/api/client/status/statusfavedby_test.go index 22a549b30..7475f1e69 100644 --- a/internal/api/client/status/statusfavedby_test.go +++ b/internal/api/client/status/statusfavedby_test.go @@ -69,7 +69,7 @@ func (suite *StatusFavedByTestSuite) TearDownTest() { func (suite *StatusFavedByTestSuite) TestGetFavedBy() { t := suite.testTokens["local_account_2"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["admin_account_status_1"] // this status is faved by local_account_1 diff --git a/internal/api/client/status/statusget.go b/internal/api/client/status/statusget.go index bcca010f5..39668288f 100644 --- a/internal/api/client/status/statusget.go +++ b/internal/api/client/status/statusget.go @@ -83,7 +83,7 @@ func (m *Module) StatusGETHandler(c *gin.Context) { return } - mastoStatus, err := m.processor.StatusGet(authed, targetStatusID) + mastoStatus, err := m.processor.StatusGet(c.Request.Context(), authed, targetStatusID) if err != nil { l.Debugf("error processing status get: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"}) diff --git a/internal/api/client/status/statusunboost.go b/internal/api/client/status/statusunboost.go index dc42e3b62..8bb28fa50 100644 --- a/internal/api/client/status/statusunboost.go +++ b/internal/api/client/status/statusunboost.go @@ -84,7 +84,7 @@ func (m *Module) StatusUnboostPOSTHandler(c *gin.Context) { return } - mastoStatus, errWithCode := m.processor.StatusUnboost(authed, targetStatusID) + mastoStatus, errWithCode := m.processor.StatusUnboost(c.Request.Context(), authed, targetStatusID) if errWithCode != nil { l.Debugf("error processing status unboost: %s", errWithCode.Error()) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/status/statusunfave.go b/internal/api/client/status/statusunfave.go index 80eb87acf..5a1074ca2 100644 --- a/internal/api/client/status/statusunfave.go +++ b/internal/api/client/status/statusunfave.go @@ -83,7 +83,7 @@ func (m *Module) StatusUnfavePOSTHandler(c *gin.Context) { return } - mastoStatus, err := m.processor.StatusUnfave(authed, targetStatusID) + mastoStatus, err := m.processor.StatusUnfave(c.Request.Context(), authed, targetStatusID) if err != nil { l.Debugf("error processing status unfave: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"}) diff --git a/internal/api/client/status/statusunfave_test.go b/internal/api/client/status/statusunfave_test.go index a5f267f4c..9e7ea8f82 100644 --- a/internal/api/client/status/statusunfave_test.go +++ b/internal/api/client/status/statusunfave_test.go @@ -71,7 +71,7 @@ func (suite *StatusUnfaveTestSuite) TearDownTest() { func (suite *StatusUnfaveTestSuite) TestPostUnfave() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // this is the status we wanna unfave: in the testrig it's already faved by this account targetStatus := suite.testStatuses["admin_account_status_1"] @@ -120,7 +120,7 @@ func (suite *StatusUnfaveTestSuite) TestPostUnfave() { func (suite *StatusUnfaveTestSuite) TestPostAlreadyNotFaved() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // this is the status we wanna unfave: in the testrig it's not faved by this account targetStatus := suite.testStatuses["admin_account_status_2"] diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index 626f1ff41..fa210e8d8 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -122,7 +122,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) { } // make sure a valid token has been provided and obtain the associated account - account, err := m.processor.AuthorizeStreamingRequest(accessToken) + account, err := m.processor.AuthorizeStreamingRequest(c.Request.Context(), accessToken) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": "could not authorize with given token"}) return @@ -147,7 +147,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) { defer conn.Close() // whatever happens, when we leave this function we want to close the websocket connection // inform the processor that we have a new connection and want a stream for it - stream, errWithCode := m.processor.OpenStreamForAccount(account, streamType) + stream, errWithCode := m.processor.OpenStreamForAccount(c.Request.Context(), account, streamType) if errWithCode != nil { c.JSON(errWithCode.Code(), errWithCode.Safe()) return diff --git a/internal/api/client/timeline/home.go b/internal/api/client/timeline/home.go index a6e64f384..6df4b29d0 100644 --- a/internal/api/client/timeline/home.go +++ b/internal/api/client/timeline/home.go @@ -153,7 +153,7 @@ func (m *Module) HomeTimelineGETHandler(c *gin.Context) { local = i } - resp, errWithCode := m.processor.HomeTimelineGet(authed, maxID, sinceID, minID, limit, local) + resp, errWithCode := m.processor.HomeTimelineGet(c.Request.Context(), authed, maxID, sinceID, minID, limit, local) if errWithCode != nil { l.Debugf("error from processor HomeTimelineGet: %s", errWithCode) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/client/timeline/public.go b/internal/api/client/timeline/public.go index 178fcd7f1..8c8c9f120 100644 --- a/internal/api/client/timeline/public.go +++ b/internal/api/client/timeline/public.go @@ -153,7 +153,7 @@ func (m *Module) PublicTimelineGETHandler(c *gin.Context) { local = i } - resp, errWithCode := m.processor.PublicTimelineGet(authed, maxID, sinceID, minID, limit, local) + resp, errWithCode := m.processor.PublicTimelineGet(c.Request.Context(), authed, maxID, sinceID, minID, limit, local) if errWithCode != nil { l.Debugf("error from processor PublicTimelineGet: %s", errWithCode) c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) diff --git a/internal/api/s2s/nodeinfo/nodeinfoget.go b/internal/api/s2s/nodeinfo/nodeinfoget.go index a54c8b190..c362e1d2e 100644 --- a/internal/api/s2s/nodeinfo/nodeinfoget.go +++ b/internal/api/s2s/nodeinfo/nodeinfoget.go @@ -33,7 +33,7 @@ func (m *Module) NodeInfoGETHandler(c *gin.Context) { "user-agent": c.Request.UserAgent(), }) - ni, err := m.processor.GetNodeInfo(c.Request) + ni, err := m.processor.GetNodeInfo(c.Request.Context(), c.Request) if err != nil { l.Debugf("error with get node info request: %s", err) c.JSON(err.Code(), err.Safe()) diff --git a/internal/api/s2s/nodeinfo/wellknownget.go b/internal/api/s2s/nodeinfo/wellknownget.go index 614d2a9c6..fd2c84408 100644 --- a/internal/api/s2s/nodeinfo/wellknownget.go +++ b/internal/api/s2s/nodeinfo/wellknownget.go @@ -33,7 +33,7 @@ func (m *Module) NodeInfoWellKnownGETHandler(c *gin.Context) { "user-agent": c.Request.UserAgent(), }) - niRel, err := m.processor.GetNodeInfoRel(c.Request) + niRel, err := m.processor.GetNodeInfoRel(c.Request.Context(), c.Request) if err != nil { l.Debugf("error with get node info rel request: %s", err) c.JSON(err.Code(), err.Safe()) diff --git a/internal/api/s2s/user/userget_test.go b/internal/api/s2s/user/userget_test.go index ab0015c57..29cc0e0d8 100644 --- a/internal/api/s2s/user/userget_test.go +++ b/internal/api/s2s/user/userget_test.go @@ -105,7 +105,7 @@ func (suite *UserGetTestSuite) TestGetUser() { // convert person to account // since this account is already known, we should get a pretty full model of it from the conversion - a, err := suite.tc.ASRepresentationToAccount(person, false) + a, err := suite.tc.ASRepresentationToAccount(context.Background(), person, false) assert.NoError(suite.T(), err) assert.EqualValues(suite.T(), targetAccount.Username, a.Username) } diff --git a/internal/api/security/signaturecheck.go b/internal/api/security/signaturecheck.go index 88b0b4dff..71e539e96 100644 --- a/internal/api/security/signaturecheck.go +++ b/internal/api/security/signaturecheck.go @@ -31,7 +31,7 @@ func (m *Module) SignatureCheck(c *gin.Context) { // we managed to parse the url! // if the domain is blocked we want to bail as early as possible - blocked, err := m.db.IsURIBlocked(requestingPublicKeyID) + blocked, err := m.db.IsURIBlocked(c.Request.Context(), requestingPublicKeyID) if err != nil { l.Errorf("could not tell if domain %s was blocked or not: %s", requestingPublicKeyID.Host, err) c.AbortWithStatus(http.StatusInternalServerError) diff --git a/internal/cache/cache.go b/internal/cache/cache.go index eb3744cfe..ce4aad04d 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -37,7 +37,7 @@ type cache struct { // New returns a new in-memory cache. func New() Cache { c := ttlcache.NewCache() - c.SetTTL(30 * time.Second) + c.SetTTL(5 * time.Minute) cache := &cache{ c: c, } diff --git a/internal/cliactions/admin/account/account.go b/internal/cliactions/admin/account/account.go index 0ae7f32de..46998ec6a 100644 --- a/internal/cliactions/admin/account/account.go +++ b/internal/cliactions/admin/account/account.go @@ -28,7 +28,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/cliactions" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/db/pg" + "github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/util" "golang.org/x/crypto/bcrypt" @@ -36,7 +36,7 @@ import ( // Create creates a new account in the database using the provided flags. var Create cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { - dbConn, err := pg.NewPostgresService(ctx, c, log) + dbConn, err := bundb.NewBunDBService(ctx, c, log) if err != nil { return fmt.Errorf("error creating dbservice: %s", err) } @@ -65,7 +65,7 @@ var Create cliactions.GTSAction = func(ctx context.Context, c *config.Config, lo return err } - _, err = dbConn.NewSignup(username, "", false, email, password, nil, "", "", false, false) + _, err = dbConn.NewSignup(ctx, username, "", false, email, password, nil, "", "", false, false) if err != nil { return err } @@ -75,7 +75,7 @@ var Create cliactions.GTSAction = func(ctx context.Context, c *config.Config, lo // Confirm sets a user to Approved, sets Email to the current UnconfirmedEmail value, and sets ConfirmedAt to now. var Confirm cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { - dbConn, err := pg.NewPostgresService(ctx, c, log) + dbConn, err := bundb.NewBunDBService(ctx, c, log) if err != nil { return fmt.Errorf("error creating dbservice: %s", err) } @@ -88,20 +88,20 @@ var Confirm cliactions.GTSAction = func(ctx context.Context, c *config.Config, l return err } - a, err := dbConn.GetLocalAccountByUsername(username) + a, err := dbConn.GetLocalAccountByUsername(ctx, username) if err != nil { return err } u := >smodel.User{} - if err := dbConn.GetWhere([]db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { + if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { return err } u.Approved = true u.Email = u.UnconfirmedEmail u.ConfirmedAt = time.Now() - if err := dbConn.UpdateByID(u.ID, u); err != nil { + if err := dbConn.UpdateByID(ctx, u.ID, u); err != nil { return err } @@ -110,7 +110,7 @@ var Confirm cliactions.GTSAction = func(ctx context.Context, c *config.Config, l // Promote sets a user to admin. var Promote cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { - dbConn, err := pg.NewPostgresService(ctx, c, log) + dbConn, err := bundb.NewBunDBService(ctx, c, log) if err != nil { return fmt.Errorf("error creating dbservice: %s", err) } @@ -123,17 +123,17 @@ var Promote cliactions.GTSAction = func(ctx context.Context, c *config.Config, l return err } - a, err := dbConn.GetLocalAccountByUsername(username) + a, err := dbConn.GetLocalAccountByUsername(ctx, username) if err != nil { return err } u := >smodel.User{} - if err := dbConn.GetWhere([]db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { + if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { return err } u.Admin = true - if err := dbConn.UpdateByID(u.ID, u); err != nil { + if err := dbConn.UpdateByID(ctx, u.ID, u); err != nil { return err } @@ -142,7 +142,7 @@ var Promote cliactions.GTSAction = func(ctx context.Context, c *config.Config, l // Demote sets admin on a user to false. var Demote cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { - dbConn, err := pg.NewPostgresService(ctx, c, log) + dbConn, err := bundb.NewBunDBService(ctx, c, log) if err != nil { return fmt.Errorf("error creating dbservice: %s", err) } @@ -155,17 +155,17 @@ var Demote cliactions.GTSAction = func(ctx context.Context, c *config.Config, lo return err } - a, err := dbConn.GetLocalAccountByUsername(username) + a, err := dbConn.GetLocalAccountByUsername(ctx, username) if err != nil { return err } u := >smodel.User{} - if err := dbConn.GetWhere([]db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { + if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { return err } u.Admin = false - if err := dbConn.UpdateByID(u.ID, u); err != nil { + if err := dbConn.UpdateByID(ctx, u.ID, u); err != nil { return err } @@ -174,7 +174,7 @@ var Demote cliactions.GTSAction = func(ctx context.Context, c *config.Config, lo // Disable sets Disabled to true on a user. var Disable cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { - dbConn, err := pg.NewPostgresService(ctx, c, log) + dbConn, err := bundb.NewBunDBService(ctx, c, log) if err != nil { return fmt.Errorf("error creating dbservice: %s", err) } @@ -187,17 +187,17 @@ var Disable cliactions.GTSAction = func(ctx context.Context, c *config.Config, l return err } - a, err := dbConn.GetLocalAccountByUsername(username) + a, err := dbConn.GetLocalAccountByUsername(ctx, username) if err != nil { return err } u := >smodel.User{} - if err := dbConn.GetWhere([]db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { + if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { return err } u.Disabled = true - if err := dbConn.UpdateByID(u.ID, u); err != nil { + if err := dbConn.UpdateByID(ctx, u.ID, u); err != nil { return err } @@ -212,7 +212,7 @@ var Suspend cliactions.GTSAction = func(ctx context.Context, c *config.Config, l // Password sets the password of target account. var Password cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { - dbConn, err := pg.NewPostgresService(ctx, c, log) + dbConn, err := bundb.NewBunDBService(ctx, c, log) if err != nil { return fmt.Errorf("error creating dbservice: %s", err) } @@ -233,13 +233,13 @@ var Password cliactions.GTSAction = func(ctx context.Context, c *config.Config, return err } - a, err := dbConn.GetLocalAccountByUsername(username) + a, err := dbConn.GetLocalAccountByUsername(ctx, username) if err != nil { return err } u := >smodel.User{} - if err := dbConn.GetWhere([]db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { + if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { return err } @@ -250,7 +250,7 @@ var Password cliactions.GTSAction = func(ctx context.Context, c *config.Config, u.EncryptedPassword = string(pw) - if err := dbConn.UpdateByID(u.ID, u); err != nil { + if err := dbConn.UpdateByID(ctx, u.ID, u); err != nil { return err } diff --git a/internal/cliactions/server/server.go b/internal/cliactions/server/server.go index 72c6cfadf..877a9d397 100644 --- a/internal/cliactions/server/server.go +++ b/internal/cliactions/server/server.go @@ -35,7 +35,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/blob" "github.com/superseriousbusiness/gotosocial/internal/cliactions" "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db/pg" + "github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb" "github.com/superseriousbusiness/gotosocial/internal/gotosocial" @@ -79,28 +79,28 @@ var models []interface{} = []interface{}{ // Start creates and starts a gotosocial server var Start cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { - dbService, err := pg.NewPostgresService(ctx, c, log) + dbService, err := bundb.NewBunDBService(ctx, c, log) if err != nil { return fmt.Errorf("error creating dbservice: %s", err) } for _, m := range models { - if err := dbService.CreateTable(m); err != nil { + if err := dbService.CreateTable(ctx, m); err != nil { return fmt.Errorf("table creation error: %s", err) } } - if err := dbService.CreateInstanceAccount(); err != nil { + if err := dbService.CreateInstanceAccount(ctx); err != nil { return fmt.Errorf("error creating instance account: %s", err) } - if err := dbService.CreateInstanceInstance(); err != nil { + if err := dbService.CreateInstanceInstance(ctx); err != nil { return fmt.Errorf("error creating instance instance: %s", err) } federatingDB := federatingdb.New(dbService, c, log) - router, err := router.New(c, dbService, log) + router, err := router.New(ctx, c, dbService, log) if err != nil { return fmt.Errorf("error creating router: %s", err) } @@ -120,7 +120,7 @@ var Start cliactions.GTSAction = func(ctx context.Context, c *config.Config, log transportController := transport.NewController(c, dbService, &federation.Clock{}, http.DefaultClient, log) federator := federation.NewFederator(dbService, federatingDB, transportController, c, log, typeConverter, mediaHandler) processor := processing.NewProcessor(c, typeConverter, federator, oauthServer, mediaHandler, storageBackend, timelineManager, dbService, log) - if err := processor.Start(); err != nil { + if err := processor.Start(ctx); err != nil { return fmt.Errorf("error starting processor: %s", err) } diff --git a/internal/cliactions/testrig/testrig.go b/internal/cliactions/testrig/testrig.go index a7032825c..7badca556 100644 --- a/internal/cliactions/testrig/testrig.go +++ b/internal/cliactions/testrig/testrig.go @@ -63,7 +63,7 @@ var Start cliactions.GTSAction = func(ctx context.Context, _ *config.Config, log federator := testrig.NewTestFederator(dbService, transportController, storageBackend) processor := testrig.NewTestProcessor(dbService, storageBackend, federator) - if err := processor.Start(); err != nil { + if err := processor.Start(ctx); err != nil { return fmt.Errorf("error starting processor: %s", err) } diff --git a/internal/db/account.go b/internal/db/account.go index 0e1575f9b..058a89859 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -19,6 +19,7 @@ package db import ( + "context" "time" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -27,40 +28,43 @@ import ( // Account contains functions related to account getting/setting/creation. type Account interface { // GetAccountByID returns one account with the given ID, or an error if something goes wrong. - GetAccountByID(id string) (*gtsmodel.Account, Error) + GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, Error) // GetAccountByURI returns one account with the given URI, or an error if something goes wrong. - GetAccountByURI(uri string) (*gtsmodel.Account, Error) + GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, Error) // GetAccountByURL returns one account with the given URL, or an error if something goes wrong. - GetAccountByURL(uri string) (*gtsmodel.Account, Error) + GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error) + + // UpdateAccount updates one account by ID. + UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) // GetLocalAccountByUsername returns an account on this instance by its username. - GetLocalAccountByUsername(username string) (*gtsmodel.Account, Error) + GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, Error) // GetAccountFaves fetches faves/likes created by the target accountID. - GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, Error) + GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, Error) // GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID. - CountAccountStatuses(accountID string) (int, Error) + CountAccountStatuses(ctx context.Context, accountID string) (int, Error) // GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided // then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can // be very memory intensive so you probably shouldn't do this! // In case of no entries, a 'no entries' error will be returned - GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error) + GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error) - GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error) + GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error) // GetAccountLastPosted simply gets the timestamp of the most recent post by the account. // // The returned time will be zero if account has never posted anything. - GetAccountLastPosted(accountID string) (time.Time, Error) + GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, Error) // SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment. - SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error + SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error // GetInstanceAccount returns the instance account for the given domain. // If domain is empty, this instance account will be returned. - GetInstanceAccount(domain string) (*gtsmodel.Account, Error) + GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, Error) } diff --git a/internal/db/admin.go b/internal/db/admin.go index aa2b22f47..24d628e84 100644 --- a/internal/db/admin.go +++ b/internal/db/admin.go @@ -19,6 +19,7 @@ package db import ( + "context" "net" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -28,26 +29,26 @@ import ( type Admin interface { // IsUsernameAvailable checks whether a given username is available on our domain. // Returns an error if the username is already taken, or something went wrong in the db. - IsUsernameAvailable(username string) Error + IsUsernameAvailable(ctx context.Context, username string) (bool, Error) // IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain. // Return an error if: // A) the email is already associated with an account // B) we block signups from this email domain // C) something went wrong in the db - IsEmailAvailable(email string) Error + IsEmailAvailable(ctx context.Context, email string) (bool, Error) // NewSignup creates a new user in the database with the given parameters. // By the time this function is called, it should be assumed that all the parameters have passed validation! - NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error) + NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error) // CreateInstanceAccount creates an account in the database with the same username as the instance host value. // Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'. // This is needed for things like serving files that belong to the instance and not an individual user/account. - CreateInstanceAccount() Error + CreateInstanceAccount(ctx context.Context) Error // CreateInstanceInstance creates an instance in the database with the same domain as the instance host value. // Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'. // This is needed for things like serving instance information through /api/v1/instance - CreateInstanceInstance() Error + CreateInstanceInstance(ctx context.Context) Error } diff --git a/internal/db/basic.go b/internal/db/basic.go index 729920bba..cf65ddc09 100644 --- a/internal/db/basic.go +++ b/internal/db/basic.go @@ -24,15 +24,11 @@ import "context" type Basic interface { // CreateTable creates a table for the given interface. // For implementations that don't use tables, this can just return nil. - CreateTable(i interface{}) Error + CreateTable(ctx context.Context, i interface{}) Error // DropTable drops the table for the given interface. // For implementations that don't use tables, this can just return nil. - DropTable(i interface{}) Error - - // RegisterTable registers a table for use in many2many relations. - // For implementations that don't use tables, or many2many relations, this can just return nil. - RegisterTable(i interface{}) Error + DropTable(ctx context.Context, i interface{}) Error // Stop should stop and close the database connection cleanly, returning an error if this is not possible. // If the database implementation doesn't need to be stopped, this can just return nil. @@ -45,43 +41,38 @@ type Basic interface { // for other implementations (for example, in-memory) it might just be the key of a map. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // In case of no entries, a 'no entries' error will be returned - GetByID(id string, i interface{}) Error + GetByID(ctx context.Context, id string, i interface{}) Error // GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the // name of the key to select from. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // In case of no entries, a 'no entries' error will be returned - GetWhere(where []Where, i interface{}) Error + GetWhere(ctx context.Context, where []Where, i interface{}) Error // GetAll will try to get all entries of type i. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // In case of no entries, a 'no entries' error will be returned - GetAll(i interface{}) Error + GetAll(ctx context.Context, i interface{}) Error // Put simply stores i. It is up to the implementation to figure out how to store it, and using what key. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - Put(i interface{}) Error - - // Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/ - // It is up to the implementation to figure out how to store it, and using what key. - // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - Upsert(i interface{}, conflictColumn string) Error + Put(ctx context.Context, i interface{}) Error // UpdateByID updates i with id id. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - UpdateByID(id string, i interface{}) Error + UpdateByID(ctx context.Context, id string, i interface{}) Error // UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value. - UpdateOneByID(id string, key string, value interface{}, i interface{}) Error + UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) Error // UpdateWhere updates column key of interface i with the given value, where the given parameters apply. - UpdateWhere(where []Where, key string, value interface{}, i interface{}) Error + UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error // DeleteByID removes i with id id. // If i didn't exist anyway, then no error should be returned. - DeleteByID(id string, i interface{}) Error + DeleteByID(ctx context.Context, id string, i interface{}) Error // DeleteWhere deletes i where key = value // If i didn't exist anyway, then no error should be returned. - DeleteWhere(where []Where, i interface{}) Error + DeleteWhere(ctx context.Context, where []Where, i interface{}) Error } diff --git a/internal/db/pg/account.go b/internal/db/bundb/account.go index 3889c6601..7ebb79a15 100644 --- a/internal/db/pg/account.go +++ b/internal/db/bundb/account.go @@ -16,70 +16,90 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" "errors" "fmt" + "strings" "time" - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type accountDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } -func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query { - return a.conn.Model(account). +func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { + return a.conn. + NewSelect(). + Model(account). Relation("AvatarMediaAttachment"). Relation("HeaderMediaAttachment") } -func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.Error) { - account := >smodel.Account{} +func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { + account := new(gtsmodel.Account) q := a.newAccountQ(account). Where("account.id = ?", id) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return account, err } -func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.Error) { - account := >smodel.Account{} +func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { + account := new(gtsmodel.Account) q := a.newAccountQ(account). Where("account.uri = ?", uri) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return account, err } -func (a *accountDB) GetAccountByURL(uri string) (*gtsmodel.Account, db.Error) { - account := >smodel.Account{} +func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { + account := new(gtsmodel.Account) q := a.newAccountQ(account). Where("account.url = ?", uri) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return account, err } -func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Error) { - account := >smodel.Account{} +func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { + if strings.TrimSpace(account.ID) == "" { + return nil, errors.New("account had no ID") + } + + account.UpdatedAt = time.Now() + + q := a.conn. + NewUpdate(). + Model(account). + WherePK() + + _, err := q.Exec(ctx) + + err = processErrorResponse(err) + + return account, err +} + +func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { + account := new(gtsmodel.Account) q := a.newAccountQ(account) @@ -90,29 +110,31 @@ func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Err } else { q = q. Where("account.username = ?", domain). - Where("? IS NULL", pg.Ident("domain")) + Where("? IS NULL", bun.Ident("domain")) } - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return account, err } -func (a *accountDB) GetAccountLastPosted(accountID string) (time.Time, db.Error) { - status := >smodel.Status{} +func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, db.Error) { + status := new(gtsmodel.Status) - q := a.conn.Model(status). + q := a.conn. + NewSelect(). + Model(status). Order("id DESC"). Limit(1). Where("account_id = ?", accountID). Column("created_at") - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return status.CreatedAt, err } -func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error { +func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error { if mediaAttachment.Avatar && mediaAttachment.Header { return errors.New("one media attachment cannot be both header and avatar") } @@ -127,51 +149,66 @@ func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAtta } // TODO: there are probably more side effects here that need to be handled - if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil { + if _, err := a.conn. + NewInsert(). + Model(mediaAttachment). + Exec(ctx); err != nil { return err } - if _, err := a.conn.Model(>smodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil { + if _, err := a.conn. + NewUpdate(). + Model(>smodel.Account{}). + Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID). + Where("id = ?", accountID). + Exec(ctx); err != nil { return err } return nil } -func (a *accountDB) GetLocalAccountByUsername(username string) (*gtsmodel.Account, db.Error) { - account := >smodel.Account{} +func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, db.Error) { + account := new(gtsmodel.Account) q := a.newAccountQ(account). Where("username = ?", username). - Where("? IS NULL", pg.Ident("domain")) + Where("? IS NULL", bun.Ident("domain")) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return account, err } -func (a *accountDB) GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, db.Error) { - faves := []*gtsmodel.StatusFave{} +func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) { + faves := new([]*gtsmodel.StatusFave) - if err := a.conn.Model(&faves). + if err := a.conn. + NewSelect(). + Model(faves). Where("account_id = ?", accountID). - Select(); err != nil { - if err == pg.ErrNoRows { - return faves, nil - } + Scan(ctx); err != nil { return nil, err } - return faves, nil + return *faves, nil } -func (a *accountDB) CountAccountStatuses(accountID string) (int, db.Error) { - return a.conn.Model(>smodel.Status{}).Where("account_id = ?", accountID).Count() +func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) { + return a.conn. + NewSelect(). + Model(>smodel.Status{}). + Where("account_id = ?", accountID). + Count(ctx) } -func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) { +func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) { a.log.Debugf("getting statuses for account %s", accountID) statuses := []*gtsmodel.Status{} - q := a.conn.Model(&statuses).Order("id DESC") + q := a.conn. + NewSelect(). + Model(&statuses). + Order("id DESC") + if accountID != "" { q = q.Where("account_id = ?", accountID) } @@ -181,27 +218,26 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli } if excludeReplies { - q = q.Where("? IS NULL", pg.Ident("in_reply_to_id")) + q = q.Where("? IS NULL", bun.Ident("in_reply_to_id")) } if pinnedOnly { q = q.Where("pinned = ?", true) } - if mediaOnly { - q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) { - return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil - }) - } - if maxID != "" { q = q.Where("id < ?", maxID) } - if err := q.Select(); err != nil { - if err == pg.ErrNoRows { - return nil, db.ErrNoEntries - } + if mediaOnly { + q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery { + return q. + WhereOr("? IS NOT NULL", bun.Ident("attachments")). + WhereOr("attachments != '{}'") + }) + } + + if err := q.Scan(ctx); err != nil { return nil, err } @@ -213,10 +249,12 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli return statuses, nil } -func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) { +func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) { blocks := []*gtsmodel.Block{} - fq := a.conn.Model(&blocks). + fq := a.conn. + NewSelect(). + Model(&blocks). Where("block.account_id = ?", accountID). Relation("TargetAccount"). Order("block.id DESC") @@ -233,11 +271,8 @@ func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID str fq = fq.Limit(limit) } - err := fq.Select() + err := fq.Scan(ctx) if err != nil { - if err == pg.ErrNoRows { - return nil, "", "", db.ErrNoEntries - } return nil, "", "", err } diff --git a/internal/db/pg/account_test.go b/internal/db/bundb/account_test.go index 7ea5ff39a..7174b781d 100644 --- a/internal/db/pg/account_test.go +++ b/internal/db/bundb/account_test.go @@ -16,17 +16,19 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg_test +package bundb_test import ( + "context" "testing" + "time" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/testrig" ) type AccountTestSuite struct { - PGStandardTestSuite + BunDBStandardTestSuite } func (suite *AccountTestSuite) SetupSuite() { @@ -54,7 +56,7 @@ func (suite *AccountTestSuite) TearDownTest() { } func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() { - account, err := suite.db.GetAccountByID(suite.testAccounts["local_account_1"].ID) + account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID) if err != nil { suite.FailNow(err.Error()) } @@ -65,6 +67,20 @@ func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() { suite.NotEmpty(account.HeaderMediaAttachment.URL) } +func (suite *AccountTestSuite) TestUpdateAccount() { + testAccount := suite.testAccounts["local_account_1"] + + testAccount.DisplayName = "new display name!" + + _, err := suite.db.UpdateAccount(context.Background(), testAccount) + suite.NoError(err) + + updated, err := suite.db.GetAccountByID(context.Background(), testAccount.ID) + suite.NoError(err) + suite.Equal("new display name!", updated.DisplayName) + suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second) +} + func TestAccountTestSuite(t *testing.T) { suite.Run(t, new(AccountTestSuite)) } diff --git a/internal/db/pg/admin.go b/internal/db/bundb/admin.go index 854f56ef0..67a1e8a0d 100644 --- a/internal/db/pg/admin.go +++ b/internal/db/bundb/admin.go @@ -16,76 +16,76 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" "crypto/rand" "crypto/rsa" + "database/sql" "fmt" "net" "net/mail" "strings" "time" - "github.com/go-pg/pg/v10" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/uptrace/bun" "golang.org/x/crypto/bcrypt" ) type adminDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } -func (a *adminDB) IsUsernameAvailable(username string) db.Error { - // if no error we fail because it means we found something - // if error but it's not pg.ErrNoRows then we fail - // if err is pg.ErrNoRows we're good, we found nothing so continue - if err := a.conn.Model(>smodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil { - return fmt.Errorf("username %s already in use", username) - } else if err != pg.ErrNoRows { - return fmt.Errorf("db error: %s", err) - } - return nil +func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { + q := a.conn. + NewSelect(). + Model(>smodel.Account{}). + Where("username = ?", username). + Where("domain = ?", nil) + + return notExists(ctx, q) } -func (a *adminDB) IsEmailAvailable(email string) db.Error { +func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) { // parse the domain from the email m, err := mail.ParseAddress(email) if err != nil { - return fmt.Errorf("error parsing email address %s: %s", email, err) + return false, fmt.Errorf("error parsing email address %s: %s", email, err) } domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @ // check if the email domain is blocked - if err := a.conn.Model(>smodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil { + if err := a.conn. + NewSelect(). + Model(>smodel.EmailDomainBlock{}). + Where("domain = ?", domain). + Scan(ctx); err == nil { // fail because we found something - return fmt.Errorf("email domain %s is blocked", domain) - } else if err != pg.ErrNoRows { - // fail because we got an unexpected error - return fmt.Errorf("db error: %s", err) + return false, fmt.Errorf("email domain %s is blocked", domain) + } else if err != sql.ErrNoRows { + return false, processErrorResponse(err) } // check if this email is associated with a user already - if err := a.conn.Model(>smodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil { - // fail because we found something - return fmt.Errorf("email %s already in use", email) - } else if err != pg.ErrNoRows { - // fail because we got an unexpected error - return fmt.Errorf("db error: %s", err) - } - return nil + q := a.conn. + NewSelect(). + Model(>smodel.User{}). + Where("email = ?", email). + WhereOr("unconfirmed_email = ?", email) + + return notExists(ctx, q) } -func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) { +func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) { key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { a.log.Errorf("error creating new rsa key: %s", err) @@ -94,13 +94,12 @@ func (a *adminDB) NewSignup(username string, reason string, requireApproval bool // if something went wrong while creating a user, we might already have an account, so check here first... acct := >smodel.Account{} - err = a.conn.Model(acct).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select() + err = a.conn.NewSelect(). + Model(acct). + Where("username = ?", username). + Where("? IS NULL", bun.Ident("domain")). + Scan(ctx) if err != nil { - // there's been an actual error - if err != pg.ErrNoRows { - return nil, fmt.Errorf("db error checking existence of account: %s", err) - } - // we just don't have an account yet create one newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host) newAccountID, err := id.NewRandomULID() @@ -125,7 +124,10 @@ func (a *adminDB) NewSignup(username string, reason string, requireApproval bool FollowingURI: newAccountURIs.FollowingURI, FeaturedCollectionURI: newAccountURIs.CollectionURI, } - if _, err = a.conn.Model(acct).Insert(); err != nil { + if _, err = a.conn. + NewInsert(). + Model(acct). + Exec(ctx); err != nil { return nil, err } } @@ -161,15 +163,33 @@ func (a *adminDB) NewSignup(username string, reason string, requireApproval bool u.Moderator = true } - if _, err = a.conn.Model(u).Insert(); err != nil { + if _, err = a.conn. + NewInsert(). + Model(u). + Exec(ctx); err != nil { return nil, err } return u, nil } -func (a *adminDB) CreateInstanceAccount() db.Error { +func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { username := a.config.Host + + // check if instance account already exists + existsQ := a.conn. + NewSelect(). + Model(>smodel.Account{}). + Where("username = ?", username). + Where("? IS NULL", bun.Ident("domain")) + count, err := existsQ.Count(ctx) + if err != nil && count == 1 { + a.log.Infof("instance account %s already exists", username) + return nil + } else if err != sql.ErrNoRows { + return processErrorResponse(err) + } + key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { a.log.Errorf("error creating new rsa key: %s", err) @@ -198,19 +218,36 @@ func (a *adminDB) CreateInstanceAccount() db.Error { FollowingURI: newAccountURIs.FollowingURI, FeaturedCollectionURI: newAccountURIs.CollectionURI, } - inserted, err := a.conn.Model(acct).Where("username = ?", username).SelectOrInsert() - if err != nil { + + insertQ := a.conn. + NewInsert(). + Model(acct) + + if _, err := insertQ.Exec(ctx); err != nil { return err } - if inserted { - a.log.Infof("created instance account %s with id %s", username, acct.ID) - } else { - a.log.Infof("instance account %s already exists with id %s", username, acct.ID) - } + + a.log.Infof("instance account %s CREATED with id %s", username, acct.ID) return nil } -func (a *adminDB) CreateInstanceInstance() db.Error { +func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error { + domain := a.config.Host + + // check if instance entry already exists + existsQ := a.conn. + NewSelect(). + Model(>smodel.Instance{}). + Where("domain = ?", domain) + + count, err := existsQ.Count(ctx) + if err != nil && count == 1 { + a.log.Infof("instance instance %s already exists", domain) + return nil + } else if err != sql.ErrNoRows { + return processErrorResponse(err) + } + iID, err := id.NewRandomULID() if err != nil { return err @@ -218,18 +255,18 @@ func (a *adminDB) CreateInstanceInstance() db.Error { i := >smodel.Instance{ ID: iID, - Domain: a.config.Host, - Title: a.config.Host, + Domain: domain, + Title: domain, URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host), } - inserted, err := a.conn.Model(i).Where("domain = ?", a.config.Host).SelectOrInsert() - if err != nil { + + insertQ := a.conn. + NewInsert(). + Model(i) + + if _, err := insertQ.Exec(ctx); err != nil { return err } - if inserted { - a.log.Infof("created instance instance %s with id %s", a.config.Host, i.ID) - } else { - a.log.Infof("instance instance %s already exists with id %s", a.config.Host, i.ID) - } + a.log.Infof("created instance instance %s with id %s", domain, i.ID) return nil } diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go new file mode 100644 index 000000000..983b6b810 --- /dev/null +++ b/internal/db/bundb/basic.go @@ -0,0 +1,179 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "errors" + "strings" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/uptrace/bun" +) + +type basicDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger +} + +func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error { + _, err := b.conn.NewInsert().Model(i).Exec(ctx) + if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") { + return db.ErrAlreadyExists + } + return err +} + +func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Error { + q := b.conn. + NewSelect(). + Model(i). + Where("id = ?", id) + + return processErrorResponse(q.Scan(ctx)) +} + +func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error { + if len(where) == 0 { + return errors.New("no queries provided") + } + + q := b.conn.NewSelect().Model(i) + for _, w := range where { + + if w.Value == nil { + q = q.Where("? IS NULL", bun.Ident(w.Key)) + } else { + if w.CaseInsensitive { + q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value) + } else { + q = q.Where("? = ?", bun.Safe(w.Key), w.Value) + } + } + } + + return processErrorResponse(q.Scan(ctx)) +} + +func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error { + q := b.conn. + NewSelect(). + Model(i) + + return processErrorResponse(q.Scan(ctx)) +} + +func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error { + q := b.conn. + NewDelete(). + Model(i). + Where("id = ?", id) + + _, err := q.Exec(ctx) + + return processErrorResponse(err) +} + +func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error { + if len(where) == 0 { + return errors.New("no queries provided") + } + + q := b.conn. + NewDelete(). + Model(i) + + for _, w := range where { + q = q.Where("? = ?", bun.Safe(w.Key), w.Value) + } + + _, err := q.Exec(ctx) + + return processErrorResponse(err) +} + +func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error { + q := b.conn. + NewUpdate(). + Model(i). + WherePK() + + _, err := q.Exec(ctx) + + return processErrorResponse(err) +} + +func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error { + q := b.conn.NewUpdate(). + Model(i). + Set("? = ?", bun.Safe(key), value). + WherePK() + + _, err := q.Exec(ctx) + + return processErrorResponse(err) +} + +func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error { + q := b.conn.NewUpdate().Model(i) + + for _, w := range where { + if w.Value == nil { + q = q.Where("? IS NULL", bun.Ident(w.Key)) + } else { + if w.CaseInsensitive { + q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value) + } else { + q = q.Where("? = ?", bun.Safe(w.Key), w.Value) + } + } + } + + q = q.Set("? = ?", bun.Safe(key), value) + + _, err := q.Exec(ctx) + + return processErrorResponse(err) +} + +func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error { + _, err := b.conn.NewCreateTable().Model(i).IfNotExists().Exec(ctx) + return err +} + +func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error { + _, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx) + return processErrorResponse(err) +} + +func (b *basicDB) IsHealthy(ctx context.Context) db.Error { + return b.conn.Ping() +} + +func (b *basicDB) Stop(ctx context.Context) db.Error { + b.log.Info("closing db connection") + if err := b.conn.Close(); err != nil { + // only cancel if there's a problem closing the db + return err + } + return nil +} diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go new file mode 100644 index 000000000..9189618c9 --- /dev/null +++ b/internal/db/bundb/basic_test.go @@ -0,0 +1,68 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type BasicTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *BasicTestSuite) SetupSuite() { + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() + suite.testAttachments = testrig.NewTestAttachments() + suite.testStatuses = testrig.NewTestStatuses() + suite.testTags = testrig.NewTestTags() + suite.testMentions = testrig.NewTestMentions() +} + +func (suite *BasicTestSuite) SetupTest() { + suite.config = testrig.NewTestConfig() + suite.db = testrig.NewTestDB() + suite.log = testrig.NewTestLog() + + testrig.StandardDBSetup(suite.db, suite.testAccounts) +} + +func (suite *BasicTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) +} + +func (suite *BasicTestSuite) TestGetAccountByID() { + testAccount := suite.testAccounts["local_account_1"] + + a := >smodel.Account{} + err := suite.db.GetByID(context.Background(), testAccount.ID, a) + suite.NoError(err) +} + +func TestBasicTestSuite(t *testing.T) { + suite.Run(t, new(BasicTestSuite)) +} diff --git a/internal/db/pg/pg.go b/internal/db/bundb/bundb.go index 0437baf02..49ed09cbd 100644 --- a/internal/db/pg/pg.go +++ b/internal/db/bundb/bundb.go @@ -16,12 +16,13 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" "crypto/tls" "crypto/x509" + "database/sql" "encoding/pem" "errors" "fmt" @@ -29,14 +30,20 @@ import ( "strings" "time" - "github.com/go-pg/pg/extra/pgdebug" - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/stdlib" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/pgdialect" +) + +const ( + dbTypePostgres = "postgres" + dbTypeSqlite = "sqlite" ) var registerTables []interface{} = []interface{}{ @@ -44,8 +51,8 @@ var registerTables []interface{} = []interface{}{ >smodel.StatusToTag{}, } -// postgresService satisfies the DB interface -type postgresService struct { +// bunDBService satisfies the DB interface +type bunDBService struct { db.Account db.Admin db.Basic @@ -55,130 +62,115 @@ type postgresService struct { db.Mention db.Notification db.Relationship + db.Session db.Status db.Timeline config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } -// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface. -// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection. -func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) { - for _, t := range registerTables { - // https://pg.uptrace.dev/orm/many-to-many-relation/ - orm.RegisterTable(t) - } - - opts, err := derivePGOptions(c) - if err != nil { - return nil, fmt.Errorf("could not create postgres service: %s", err) - } - log.Debugf("using pg options: %+v", opts) - - // create a connection - pgCtx, cancel := context.WithCancel(ctx) - conn := pg.Connect(opts).WithContext(pgCtx) - - // this will break the logfmt format we normally log in, - // since we can't choose where pg outputs to and it defaults to - // stdout. So use this option with care! - if log.GetLevel() >= logrus.TraceLevel { - conn.AddQueryHook(pgdebug.DebugHook{ - // Print all queries. - Verbose: true, - }) +// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface. +// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection. +func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) { + var sqldb *sql.DB + var conn *bun.DB + + // depending on the database type we're trying to create, we need to use a different driver... + switch strings.ToLower(c.DBConfig.Type) { + case dbTypePostgres: + // POSTGRES + opts, err := deriveBunDBPGOptions(c) + if err != nil { + return nil, fmt.Errorf("could not create bundb postgres options: %s", err) + } + sqldb = stdlib.OpenDB(*opts) + conn = bun.NewDB(sqldb, pgdialect.New()) + case dbTypeSqlite: + // SQLITE + // TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite + default: + return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type)) } // actually *begin* the connection so that we can tell if the db is there and listening - if err := conn.Ping(ctx); err != nil { - cancel() + if err := conn.Ping(); err != nil { return nil, fmt.Errorf("db connection error: %s", err) } + log.Info("connected to database") - // print out discovered postgres version - var version string - if _, err = conn.QueryOneContext(ctx, pg.Scan(&version), "SELECT version()"); err != nil { - cancel() - return nil, fmt.Errorf("db connection error: %s", err) + for _, t := range registerTables { + // https://bun.uptrace.dev/orm/many-to-many-relation/ + conn.RegisterModel(t) } - log.Infof("connected to postgres version: %s", version) - ps := &postgresService{ + ps := &bunDBService{ Account: &accountDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Admin: &adminDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Basic: &basicDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Domain: &domainDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Instance: &instanceDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Media: &mediaDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Mention: &mentionDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Notification: ¬ificationDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Relationship: &relationshipDB{ config: c, conn: conn, log: log, - cancel: cancel, + }, + Session: &sessionDB{ + config: c, + conn: conn, + log: log, }, Status: &statusDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Timeline: &timelineDB{ config: c, conn: conn, log: log, - cancel: cancel, }, config: c, conn: conn, log: log, - cancel: cancel, } - // we can confidently return this useable postgres service now + // we can confidently return this useable service now return ps, nil } @@ -186,9 +178,9 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge HANDY STUFF */ -// derivePGOptions takes an application config and returns either a ready-to-use *pg.Options +// deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options // with sensible defaults, or an error if it's not satisfied by the provided config. -func derivePGOptions(c *config.Config) (*pg.Options, error) { +func deriveBunDBPGOptions(c *config.Config) (*pgx.ConnConfig, error) { if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres { return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type) } @@ -266,18 +258,16 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) { tlsConfig.RootCAs = certPool } - // We can rely on the pg library we're using to set - // sensible defaults for everything we don't set here. - options := &pg.Options{ - Addr: fmt.Sprintf("%s:%d", c.DBConfig.Address, c.DBConfig.Port), - User: c.DBConfig.User, - Password: c.DBConfig.Password, - Database: c.DBConfig.Database, - ApplicationName: c.ApplicationName, - TLSConfig: tlsConfig, - } + cfg, _ := pgx.ParseConfig("") + cfg.Host = c.DBConfig.Address + cfg.Port = uint16(c.DBConfig.Port) + cfg.User = c.DBConfig.User + cfg.Password = c.DBConfig.Password + cfg.TLSConfig = tlsConfig + cfg.Database = c.DBConfig.Database + cfg.PreferSimpleProtocol = true - return options, nil + return cfg, nil } /* @@ -286,9 +276,9 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) { // TODO: move these to the type converter, it's bananas that they're here and not there -func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) { +func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) { ogAccount := >smodel.Account{} - if err := ps.conn.Model(ogAccount).Where("id = ?", originAccountID).Select(); err != nil { + if err := ps.conn.NewSelect().Model(ogAccount).Where("id = ?", originAccountID).Scan(ctx); err != nil { return nil, err } @@ -333,14 +323,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori // match username + account, case insensitive if local { // local user -- should have a null domain - err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("? IS NULL", pg.Ident("domain")).Select() + err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("? IS NULL", bun.Ident("domain")).Scan(ctx) } else { // remote user -- should have domain defined - err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("LOWER(?) = LOWER(?)", pg.Ident("domain"), domain).Select() + err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("LOWER(?) = LOWER(?)", bun.Ident("domain"), domain).Scan(ctx) } if err != nil { - if err == pg.ErrNoRows { + if err == sql.ErrNoRows { // no result found for this username/domain so just don't include it as a mencho and carry on about our business ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain) continue @@ -364,14 +354,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori return menchies, nil } -func (ps *postgresService) TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) { +func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) { newTags := []*gtsmodel.Tag{} for _, t := range tags { tag := >smodel.Tag{} // we can use selectorinsert here to create the new tag if it doesn't exist already // inserted will be true if this is a new tag we just created - if err := ps.conn.Model(tag).Where("LOWER(?) = LOWER(?)", pg.Ident("name"), t).Select(); err != nil { - if err == pg.ErrNoRows { + if err := ps.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil { + if err == sql.ErrNoRows { // tag doesn't exist yet so populate it newID, err := id.NewRandomULID() if err != nil { @@ -400,13 +390,13 @@ func (ps *postgresService) TagStringsToTags(tags []string, originAccountID strin return newTags, nil } -func (ps *postgresService) EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) { +func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) { newEmojis := []*gtsmodel.Emoji{} for _, e := range emojis { emoji := >smodel.Emoji{} - err := ps.conn.Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Select() + err := ps.conn.NewSelect().Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Scan(ctx) if err != nil { - if err == pg.ErrNoRows { + if err == sql.ErrNoRows { // no result found for this username/domain so just don't include it as an emoji and carry on about our business ps.log.Debugf("no emoji found with shortcode %s, skipping it", e) continue diff --git a/internal/db/pg/pg_test.go b/internal/db/bundb/bundb_test.go index c1e10abdf..b789375af 100644 --- a/internal/db/pg/pg_test.go +++ b/internal/db/bundb/bundb_test.go @@ -16,7 +16,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg_test +package bundb_test import ( "github.com/sirupsen/logrus" @@ -27,7 +27,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -type PGStandardTestSuite struct { +type BunDBStandardTestSuite struct { // standard suite interfaces suite.Suite config *config.Config diff --git a/internal/db/pg/domain.go b/internal/db/bundb/domain.go index 4e9b2ab48..6aa2b8ffe 100644 --- a/internal/db/pg/domain.go +++ b/internal/db/bundb/domain.go @@ -16,48 +16,46 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" "net/url" - "github.com/go-pg/pg/v10" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/uptrace/bun" ) type domainDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } -func (d *domainDB) IsDomainBlocked(domain string) (bool, db.Error) { +func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) { if domain == "" { return false, nil } - blocked, err := d.conn. + q := d.conn. + NewSelect(). Model(>smodel.DomainBlock{}). Where("LOWER(domain) = LOWER(?)", domain). - Exists() + Limit(1) - err = processErrorResponse(err) - - return blocked, err + return exists(ctx, q) } -func (d *domainDB) AreDomainsBlocked(domains []string) (bool, db.Error) { +func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) { // filter out any doubles uniqueDomains := util.UniqueStrings(domains) for _, domain := range uniqueDomains { - if blocked, err := d.IsDomainBlocked(domain); err != nil { + if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil { return false, err } else if blocked { return blocked, nil @@ -68,16 +66,16 @@ func (d *domainDB) AreDomainsBlocked(domains []string) (bool, db.Error) { return false, nil } -func (d *domainDB) IsURIBlocked(uri *url.URL) (bool, db.Error) { +func (d *domainDB) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, db.Error) { domain := uri.Hostname() - return d.IsDomainBlocked(domain) + return d.IsDomainBlocked(ctx, domain) } -func (d *domainDB) AreURIsBlocked(uris []*url.URL) (bool, db.Error) { +func (d *domainDB) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, db.Error) { domains := []string{} for _, uri := range uris { domains = append(domains, uri.Hostname()) } - return d.AreDomainsBlocked(domains) + return d.AreDomainsBlocked(ctx, domains) } diff --git a/internal/db/pg/instance.go b/internal/db/bundb/instance.go index 968832ca5..f9364346e 100644 --- a/internal/db/pg/instance.go +++ b/internal/db/bundb/instance.go @@ -16,43 +16,50 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" - "github.com/go-pg/pg/v10" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type instanceDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } -func (i *instanceDB) CountInstanceUsers(domain string) (int, db.Error) { - q := i.conn.Model(&[]*gtsmodel.Account{}) +func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) { + q := i.conn. + NewSelect(). + Model(&[]*gtsmodel.Account{}) if domain == i.config.Host { // if the domain is *this* domain, just count where the domain field is null - q = q.Where("? IS NULL", pg.Ident("domain")) + q = q.Where("? IS NULL", bun.Ident("domain")) } else { q = q.Where("domain = ?", domain) } // don't count the instance account or suspended users - q = q.Where("username != ?", domain).Where("? IS NULL", pg.Ident("suspended_at")) + q = q. + Where("username != ?", domain). + Where("? IS NULL", bun.Ident("suspended_at")) - return q.Count() + count, err := q.Count(ctx) + + return count, processErrorResponse(err) } -func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) { - q := i.conn.Model(&[]*gtsmodel.Status{}) +func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) { + q := i.conn. + NewSelect(). + Model(&[]*gtsmodel.Status{}) if domain == i.config.Host { // if the domain is *this* domain, just count where local is true @@ -63,30 +70,39 @@ func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) { Where("account.domain = ?", domain) } - return q.Count() + count, err := q.Count(ctx) + + return count, processErrorResponse(err) } -func (i *instanceDB) CountInstanceDomains(domain string) (int, db.Error) { - q := i.conn.Model(&[]*gtsmodel.Instance{}) +func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) { + q := i.conn. + NewSelect(). + Model(&[]*gtsmodel.Instance{}) if domain == i.config.Host { // if the domain is *this* domain, just count other instances it knows about // exclude domains that are blocked - q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at")) + q = q.Where("domain != ?", domain).Where("? IS NULL", bun.Ident("suspended_at")) } else { // TODO: implement federated domain counting properly for remote domains return 0, nil } - return q.Count() + count, err := q.Count(ctx) + + return count, processErrorResponse(err) } -func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { +func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { i.log.Debug("GetAccountsForInstance") accounts := []*gtsmodel.Account{} - q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC") + q := i.conn.NewSelect(). + Model(&accounts). + Where("domain = ?", domain). + Order("id DESC") if maxID != "" { q = q.Where("id < ?", maxID) @@ -96,17 +112,7 @@ func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int) q = q.Limit(limit) } - err := q.Select() - if err != nil { - if err == pg.ErrNoRows { - return nil, db.ErrNoEntries - } - return nil, err - } - - if len(accounts) == 0 { - return nil, db.ErrNoEntries - } + err := processErrorResponse(q.Scan(ctx)) - return accounts, nil + return accounts, err } diff --git a/internal/db/pg/media.go b/internal/db/bundb/media.go index 618030af3..04e55ca62 100644 --- a/internal/db/pg/media.go +++ b/internal/db/bundb/media.go @@ -16,38 +16,38 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type mediaDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } -func (m *mediaDB) newMediaQ(i interface{}) *orm.Query { - return m.conn.Model(i). +func (m *mediaDB) newMediaQ(i interface{}) *bun.SelectQuery { + return m.conn. + NewSelect(). + Model(i). Relation("Account") } -func (m *mediaDB) GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, db.Error) { +func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, db.Error) { attachment := >smodel.MediaAttachment{} q := m.newMediaQ(attachment). Where("media_attachment.id = ?", id) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return attachment, err } diff --git a/internal/db/pg/mention.go b/internal/db/bundb/mention.go index b31f07b67..a444f9b5f 100644 --- a/internal/db/pg/mention.go +++ b/internal/db/bundb/mention.go @@ -16,25 +16,23 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type mentionDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc cache cache.Cache } @@ -67,14 +65,16 @@ func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) { return mention, true } -func (m *mentionDB) newMentionQ(i interface{}) *orm.Query { - return m.conn.Model(i). +func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery { + return m.conn. + NewSelect(). + Model(i). Relation("Status"). Relation("OriginAccount"). Relation("TargetAccount") } -func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) { +func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { if mention, cached := m.mentionCached(id); cached { return mention, nil } @@ -84,7 +84,7 @@ func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) { q := m.newMentionQ(mention). Where("mention.id = ?", id) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) if err == nil && mention != nil { m.cacheMention(id, mention) @@ -93,11 +93,11 @@ func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) { return mention, err } -func (m *mentionDB) GetMentions(ids []string) ([]*gtsmodel.Mention, db.Error) { +func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) { mentions := []*gtsmodel.Mention{} for _, i := range ids { - mention, err := m.GetMention(i) + mention, err := m.GetMention(ctx, i) if err != nil { return nil, processErrorResponse(err) } diff --git a/internal/db/pg/notification.go b/internal/db/bundb/notification.go index 281a76d85..1c30837ec 100644 --- a/internal/db/pg/notification.go +++ b/internal/db/bundb/notification.go @@ -16,25 +16,23 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type notificationDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc cache cache.Cache } @@ -67,14 +65,16 @@ func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification, return notification, true } -func (n *notificationDB) newNotificationQ(i interface{}) *orm.Query { - return n.conn.Model(i). +func (n *notificationDB) newNotificationQ(i interface{}) *bun.SelectQuery { + return n.conn. + NewSelect(). + Model(i). Relation("OriginAccount"). Relation("TargetAccount"). Relation("Status") } -func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db.Error) { +func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) { if notification, cached := n.notificationCached(id); cached { return notification, nil } @@ -84,7 +84,7 @@ func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db. q := n.newNotificationQ(notification). Where("notification.id = ?", id) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) if err == nil && notification != nil { n.cacheNotification(id, notification) @@ -93,10 +93,11 @@ func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db. return notification, err } -func (n *notificationDB) GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) { +func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) { // begin by selecting just the IDs notifIDs := []*gtsmodel.Notification{} q := n.conn. + NewSelect(). Model(¬ifIDs). Column("id"). Where("target_account_id = ?", accountID). @@ -114,7 +115,7 @@ func (n *notificationDB) GetNotifications(accountID string, limit int, maxID str q = q.Limit(limit) } - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) if err != nil { return nil, err } @@ -123,7 +124,7 @@ func (n *notificationDB) GetNotifications(accountID string, limit int, maxID str // reason for this is that for each notif, we can instead get it from our cache if it's cached notifications := []*gtsmodel.Notification{} for _, notifID := range notifIDs { - notif, err := n.GetNotification(notifID.ID) + notif, err := n.GetNotification(ctx, notifID.ID) errP := processErrorResponse(err) if errP != nil { return nil, errP diff --git a/internal/db/pg/relationship.go b/internal/db/bundb/relationship.go index 76bd50c76..ccc604baf 100644 --- a/internal/db/pg/relationship.go +++ b/internal/db/bundb/relationship.go @@ -16,44 +16,49 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" + "database/sql" "fmt" - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type relationshipDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } -func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *orm.Query { - return r.conn.Model(block). +func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery { + return r.conn. + NewSelect(). + Model(block). Relation("Account"). Relation("TargetAccount") } -func (r *relationshipDB) newFollowQ(follow interface{}) *orm.Query { - return r.conn.Model(follow). +func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery { + return r.conn. + NewSelect(). + Model(follow). Relation("Account"). Relation("TargetAccount") } -func (r *relationshipDB) IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, db.Error) { +func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) { q := r.conn. + NewSelect(). Model(>smodel.Block{}). Where("account_id = ?", account1). - Where("target_account_id = ?", account2) + Where("target_account_id = ?", account2). + Limit(1) if eitherDirection { q = q. @@ -61,30 +66,36 @@ func (r *relationshipDB) IsBlocked(account1 string, account2 string, eitherDirec Where("account_id = ?", account2) } - return q.Exists() + return exists(ctx, q) } -func (r *relationshipDB) GetBlock(account1 string, account2 string) (*gtsmodel.Block, db.Error) { +func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) { block := >smodel.Block{} q := r.newBlockQ(block). Where("block.account_id = ?", account1). Where("block.target_account_id = ?", account2) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return block, err } -func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) { +func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) { rel := >smodel.Relationship{ ID: targetAccount, } // check if the requesting account follows the target account follow := >smodel.Follow{} - if err := r.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil { - if err != pg.ErrNoRows { + if err := r.conn. + NewSelect(). + Model(follow). + Where("account_id = ?", requestingAccount). + Where("target_account_id = ?", targetAccount). + Limit(1). + Scan(ctx); err != nil { + if err != sql.ErrNoRows { // a proper error return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err) } @@ -100,75 +111,101 @@ func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount } // check if the target account follows the requesting account - followedBy, err := r.conn.Model(>smodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists() + count, err := r.conn. + NewSelect(). + Model(>smodel.Follow{}). + Where("account_id = ?", targetAccount). + Where("target_account_id = ?", requestingAccount). + Limit(1). + Count(ctx) if err != nil { return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err) } - rel.FollowedBy = followedBy + rel.FollowedBy = count > 0 // check if the requesting account blocks the target account - blocking, err := r.conn.Model(>smodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists() + count, err = r.conn.NewSelect(). + Model(>smodel.Block{}). + Where("account_id = ?", requestingAccount). + Where("target_account_id = ?", targetAccount). + Limit(1). + Count(ctx) if err != nil { return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err) } - rel.Blocking = blocking + rel.Blocking = count > 0 // check if the target account blocks the requesting account - blockedBy, err := r.conn.Model(>smodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists() + count, err = r.conn. + NewSelect(). + Model(>smodel.Block{}). + Where("account_id = ?", targetAccount). + Where("target_account_id = ?", requestingAccount). + Limit(1). + Count(ctx) if err != nil { return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) } - rel.BlockedBy = blockedBy + rel.BlockedBy = count > 0 // check if there's a pending following request from requesting account to target account - requested, err := r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists() + count, err = r.conn. + NewSelect(). + Model(>smodel.FollowRequest{}). + Where("account_id = ?", requestingAccount). + Where("target_account_id = ?", targetAccount). + Limit(1). + Count(ctx) if err != nil { return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) } - rel.Requested = requested + rel.Requested = count > 0 return rel, nil } -func (r *relationshipDB) IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { +func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { if sourceAccount == nil || targetAccount == nil { return false, nil } q := r.conn. + NewSelect(). Model(>smodel.Follow{}). Where("account_id = ?", sourceAccount.ID). - Where("target_account_id = ?", targetAccount.ID) + Where("target_account_id = ?", targetAccount.ID). + Limit(1) - return q.Exists() + return exists(ctx, q) } -func (r *relationshipDB) IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { +func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { if sourceAccount == nil || targetAccount == nil { return false, nil } q := r.conn. + NewSelect(). Model(>smodel.FollowRequest{}). Where("account_id = ?", sourceAccount.ID). Where("target_account_id = ?", targetAccount.ID) - return q.Exists() + return exists(ctx, q) } -func (r *relationshipDB) IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) { +func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) { if account1 == nil || account2 == nil { return false, nil } // make sure account 1 follows account 2 - f1, err := r.IsFollowing(account1, account2) + f1, err := r.IsFollowing(ctx, account1, account2) if err != nil { return false, processErrorResponse(err) } // make sure account 2 follows account 1 - f2, err := r.IsFollowing(account2, account1) + f2, err := r.IsFollowing(ctx, account2, account1) if err != nil { return false, processErrorResponse(err) } @@ -176,14 +213,16 @@ func (r *relationshipDB) IsMutualFollowing(account1 *gtsmodel.Account, account2 return f1 && f2, nil } -func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { +func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { // make sure the original follow request exists fr := >smodel.FollowRequest{} - if err := r.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil { - if err == pg.ErrMultiRows { - return nil, db.ErrNoEntries - } - return nil, err + if err := r.conn. + NewSelect(). + Model(fr). + Where("account_id = ?", originAccountID). + Where("target_account_id = ?", targetAccountID). + Scan(ctx); err != nil { + return nil, processErrorResponse(err) } // create a new follow to 'replace' the request with @@ -195,82 +234,95 @@ func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccou } // if the follow already exists, just update the URI -- we don't need to do anything else - if _, err := r.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil { - return nil, err + if _, err := r.conn. + NewInsert(). + Model(follow). + On("CONFLICT CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI). + Exec(ctx); err != nil { + return nil, processErrorResponse(err) } // now remove the follow request - if _, err := r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil { - return nil, err + if _, err := r.conn. + NewDelete(). + Model(>smodel.FollowRequest{}). + Where("account_id = ?", originAccountID). + Where("target_account_id = ?", targetAccountID). + Exec(ctx); err != nil { + return nil, processErrorResponse(err) } return follow, nil } -func (r *relationshipDB) GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, db.Error) { +func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) { followRequests := []*gtsmodel.FollowRequest{} q := r.newFollowQ(&followRequests). Where("target_account_id = ?", accountID) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return followRequests, err } -func (r *relationshipDB) GetAccountFollows(accountID string) ([]*gtsmodel.Follow, db.Error) { +func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, db.Error) { follows := []*gtsmodel.Follow{} q := r.newFollowQ(&follows). Where("account_id = ?", accountID) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return follows, err } -func (r *relationshipDB) CountAccountFollows(accountID string, localOnly bool) (int, db.Error) { +func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { return r.conn. + NewSelect(). Model(&[]*gtsmodel.Follow{}). Where("account_id = ?", accountID). - Count() + Count(ctx) } -func (r *relationshipDB) GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { +func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { follows := []*gtsmodel.Follow{} - q := r.conn.Model(&follows) + q := r.conn. + NewSelect(). + Model(&follows) if localOnly { // for local accounts let's get where domain is null OR where domain is an empty string, just to be safe - whereGroup := func(q *pg.Query) (*pg.Query, error) { + whereGroup := func(q *bun.SelectQuery) *bun.SelectQuery { q = q. - WhereOr("? IS NULL", pg.Ident("a.domain")). + WhereOr("? IS NULL", bun.Ident("a.domain")). WhereOr("a.domain = ?", "") - return q, nil + return q } q = q.ColumnExpr("follow.*"). Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)"). Where("follow.target_account_id = ?", accountID). - WhereGroup(whereGroup) + WhereGroup(" AND ", whereGroup) } else { q = q.Where("target_account_id = ?", accountID) } - if err := q.Select(); err != nil { - if err == pg.ErrNoRows { + if err := q.Scan(ctx); err != nil { + if err == sql.ErrNoRows { return follows, nil } - return nil, err + return nil, processErrorResponse(err) } return follows, nil } -func (r *relationshipDB) CountAccountFollowedBy(accountID string, localOnly bool) (int, db.Error) { +func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { return r.conn. + NewSelect(). Model(&[]*gtsmodel.Follow{}). Where("target_account_id = ?", accountID). - Count() + Count(ctx) } diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go new file mode 100644 index 000000000..87e20673d --- /dev/null +++ b/internal/db/bundb/session.go @@ -0,0 +1,85 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "crypto/rand" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/uptrace/bun" +) + +type sessionDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger +} + +func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { + rs := new(gtsmodel.RouterSession) + + q := s.conn. + NewSelect(). + Model(rs). + Limit(1) + + _, err := q.Exec(ctx) + + err = processErrorResponse(err) + + return rs, err +} + +func (s *sessionDB) CreateSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { + auth := make([]byte, 32) + crypt := make([]byte, 32) + + if _, err := rand.Read(auth); err != nil { + return nil, err + } + if _, err := rand.Read(crypt); err != nil { + return nil, err + } + + rid, err := id.NewULID() + if err != nil { + return nil, err + } + + rs := >smodel.RouterSession{ + ID: rid, + Auth: auth, + Crypt: crypt, + } + + q := s.conn. + NewInsert(). + Model(rs) + + _, err = q.Exec(ctx) + + err = processErrorResponse(err) + + return rs, err +} diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go new file mode 100644 index 000000000..da8d8ca41 --- /dev/null +++ b/internal/db/bundb/status.go @@ -0,0 +1,375 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "container/list" + "context" + "errors" + "time" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/cache" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +type statusDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger + cache cache.Cache +} + +func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) { + if s.cache == nil { + s.cache = cache.New() + } + + if err := s.cache.Store(id, status); err != nil { + s.log.Panicf("statusDB: error storing in cache: %s", err) + } +} + +func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) { + if s.cache == nil { + s.cache = cache.New() + return nil, false + } + + sI, err := s.cache.Fetch(id) + if err != nil || sI == nil { + return nil, false + } + + status, ok := sI.(*gtsmodel.Status) + if !ok { + s.log.Panicf("statusDB: cached interface with key %s was not a status", id) + } + + return status, true +} + +func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { + return s.conn. + NewSelect(). + Model(status). + Relation("Attachments"). + Relation("Tags"). + Relation("Mentions"). + Relation("Emojis"). + Relation("Account"). + Relation("InReplyToAccount"). + Relation("BoostOfAccount"). + Relation("CreatedWithApplication") +} + +func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status { + if status.InReplyToID != "" && status.InReplyTo == nil { + if inReplyTo, cached := s.statusCached(status.InReplyToID); cached { + status.InReplyTo = inReplyTo + } else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil { + status.InReplyTo = inReplyTo + } + } + + if status.BoostOfID != "" && status.BoostOf == nil { + if boostOf, cached := s.statusCached(status.BoostOfID); cached { + status.BoostOf = boostOf + } else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil { + status.BoostOf = boostOf + } + } + + return status +} + +func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery { + return s.conn. + NewSelect(). + Model(faves). + Relation("Account"). + Relation("TargetAccount"). + Relation("Status") +} + +func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) { + if status, cached := s.statusCached(id); cached { + return status, nil + } + + status := new(gtsmodel.Status) + + q := s.newStatusQ(status). + Where("status.id = ?", id) + + err := processErrorResponse(q.Scan(ctx)) + + if err != nil { + return nil, err + } + + if status != nil { + s.cacheStatus(id, status) + } + + return s.getAttachedStatuses(ctx, status), err +} + +func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { + if status, cached := s.statusCached(uri); cached { + return status, nil + } + + status := >smodel.Status{} + + q := s.newStatusQ(status). + Where("LOWER(status.uri) = LOWER(?)", uri) + + err := processErrorResponse(q.Scan(ctx)) + + if err != nil { + return nil, err + } + + if status != nil { + s.cacheStatus(uri, status) + } + + return s.getAttachedStatuses(ctx, status), err +} + +func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { + if status, cached := s.statusCached(uri); cached { + return status, nil + } + + status := >smodel.Status{} + + q := s.newStatusQ(status). + Where("LOWER(status.url) = LOWER(?)", uri) + + err := processErrorResponse(q.Scan(ctx)) + + if err != nil { + return nil, err + } + + if status != nil { + s.cacheStatus(uri, status) + } + + return s.getAttachedStatuses(ctx, status), err +} + +func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { + transaction := func(ctx context.Context, tx bun.Tx) error { + // create links between this status and any emojis it uses + for _, i := range status.EmojiIDs { + if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{ + StatusID: status.ID, + EmojiID: i, + }).Exec(ctx); err != nil { + return err + } + } + + // create links between this status and any tags it uses + for _, i := range status.TagIDs { + if _, err := tx.NewInsert().Model(>smodel.StatusToTag{ + StatusID: status.ID, + TagID: i, + }).Exec(ctx); err != nil { + return err + } + } + + // change the status ID of the media attachments to the new status + for _, a := range status.Attachments { + a.StatusID = status.ID + a.UpdatedAt = time.Now() + if _, err := s.conn.NewUpdate().Model(a). + Where("id = ?", a.ID). + Exec(ctx); err != nil { + return err + } + } + + _, err := tx.NewInsert().Model(status).Exec(ctx) + return err + } + + return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction)) +} + +func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { + parents := []*gtsmodel.Status{} + s.statusParent(ctx, status, &parents, onlyDirect) + + return parents, nil +} + +func (s *statusDB) statusParent(ctx context.Context, status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) { + if status.InReplyToID == "" { + return + } + + parentStatus, err := s.GetStatusByID(ctx, status.InReplyToID) + if err == nil { + *foundStatuses = append(*foundStatuses, parentStatus) + } + + if onlyDirect { + return + } + + s.statusParent(ctx, parentStatus, foundStatuses, false) +} + +func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) { + foundStatuses := &list.List{} + foundStatuses.PushFront(status) + s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID) + + children := []*gtsmodel.Status{} + for e := foundStatuses.Front(); e != nil; e = e.Next() { + entry, ok := e.Value.(*gtsmodel.Status) + if !ok { + panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) + } + + // only append children, not the overall parent status + if entry.ID != status.ID { + children = append(children, entry) + } + } + + return children, nil +} + +func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { + immediateChildren := []*gtsmodel.Status{} + + q := s.conn. + NewSelect(). + Model(&immediateChildren). + Where("in_reply_to_id = ?", status.ID) + if minID != "" { + q = q.Where("status.id > ?", minID) + } + + if err := q.Scan(ctx); err != nil { + return + } + + for _, child := range immediateChildren { + insertLoop: + for e := foundStatuses.Front(); e != nil; e = e.Next() { + entry, ok := e.Value.(*gtsmodel.Status) + if !ok { + panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) + } + + if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID { + foundStatuses.InsertAfter(child, e) + break insertLoop + } + } + + // only do one loop if we only want direct children + if onlyDirect { + return + } + s.statusChildren(ctx, child, foundStatuses, false, minID) + } +} + +func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { + return s.conn.NewSelect().Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count(ctx) +} + +func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { + return s.conn.NewSelect().Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count(ctx) +} + +func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { + return s.conn.NewSelect().Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count(ctx) +} + +func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { + q := s.conn. + NewSelect(). + Model(>smodel.StatusFave{}). + Where("status_id = ?", status.ID). + Where("account_id = ?", accountID) + + return exists(ctx, q) +} + +func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { + q := s.conn. + NewSelect(). + Model(>smodel.Status{}). + Where("boost_of_id = ?", status.ID). + Where("account_id = ?", accountID) + + return exists(ctx, q) +} + +func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { + q := s.conn. + NewSelect(). + Model(>smodel.StatusMute{}). + Where("status_id = ?", status.ID). + Where("account_id = ?", accountID) + + return exists(ctx, q) +} + +func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { + q := s.conn. + NewSelect(). + Model(>smodel.StatusBookmark{}). + Where("status_id = ?", status.ID). + Where("account_id = ?", accountID) + + return exists(ctx, q) +} + +func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) { + faves := []*gtsmodel.StatusFave{} + + q := s.newFaveQ(&faves). + Where("status_id = ?", status.ID) + + err := processErrorResponse(q.Scan(ctx)) + return faves, err +} + +func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) { + reblogs := []*gtsmodel.Status{} + + q := s.newStatusQ(&reblogs). + Where("boost_of_id = ?", status.ID) + + err := processErrorResponse(q.Scan(ctx)) + return reblogs, err +} diff --git a/internal/db/pg/status_test.go b/internal/db/bundb/status_test.go index 8a185757c..513000577 100644 --- a/internal/db/pg/status_test.go +++ b/internal/db/bundb/status_test.go @@ -16,9 +16,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg_test +package bundb_test import ( + "context" "fmt" "testing" "time" @@ -28,7 +29,7 @@ import ( ) type StatusTestSuite struct { - PGStandardTestSuite + BunDBStandardTestSuite } func (suite *StatusTestSuite) SetupSuite() { @@ -56,8 +57,9 @@ func (suite *StatusTestSuite) TearDownTest() { } func (suite *StatusTestSuite) TestGetStatusByID() { - status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID) + status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_1_status_1"].ID) if err != nil { + fmt.Println(err.Error()) suite.FailNow(err.Error()) } suite.NotNil(status) @@ -70,7 +72,7 @@ func (suite *StatusTestSuite) TestGetStatusByID() { } func (suite *StatusTestSuite) TestGetStatusByURI() { - status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI) + status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI) if err != nil { suite.FailNow(err.Error()) } @@ -84,7 +86,7 @@ func (suite *StatusTestSuite) TestGetStatusByURI() { } func (suite *StatusTestSuite) TestGetStatusWithExtras() { - status, err := suite.db.GetStatusByID(suite.testStatuses["admin_account_status_1"].ID) + status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["admin_account_status_1"].ID) if err != nil { suite.FailNow(err.Error()) } @@ -97,7 +99,7 @@ func (suite *StatusTestSuite) TestGetStatusWithExtras() { } func (suite *StatusTestSuite) TestGetStatusWithMention() { - status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_2_status_5"].ID) + status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_2_status_5"].ID) if err != nil { suite.FailNow(err.Error()) } @@ -112,18 +114,18 @@ func (suite *StatusTestSuite) TestGetStatusWithMention() { func (suite *StatusTestSuite) TestGetStatusTwice() { before1 := time.Now() - _, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI) + _, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI) suite.NoError(err) after1 := time.Now() duration1 := after1.Sub(before1) - fmt.Println(duration1.Nanoseconds()) + fmt.Println(duration1.Milliseconds()) before2 := time.Now() - _, err = suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI) + _, err = suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI) suite.NoError(err) after2 := time.Now() duration2 := after2.Sub(before2) - fmt.Println(duration2.Nanoseconds()) + fmt.Println(duration2.Milliseconds()) // second retrieval should be several orders faster since it will be cached now suite.Less(duration2, duration1) diff --git a/internal/db/pg/timeline.go b/internal/db/bundb/timeline.go index fa8b07aab..b62ad4c50 100644 --- a/internal/db/pg/timeline.go +++ b/internal/db/bundb/timeline.go @@ -16,43 +16,35 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" + "database/sql" "sort" - "github.com/go-pg/pg/v10" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type timelineDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } -func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { +func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { statuses := []*gtsmodel.Status{} - q := t.conn.Model(&statuses) + q := t.conn. + NewSelect(). + Model(&statuses) q = q.ColumnExpr("status.*"). // Find out who accountID follows. Join("LEFT JOIN follows AS f ON f.target_account_id = status.account_id"). - // Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows, - // OR statuses posted by accountID itself (since a user should be able to see their own statuses). - // - // This is equivalent to something like WHERE ... AND (... OR ...) - // See: https://pg.uptrace.dev/queries/#select - WhereGroup(func(q *pg.Query) (*pg.Query, error) { - q = q.WhereOr("f.account_id = ?", accountID). - WhereOr("status.account_id = ?", accountID) - return q, nil - }). // Sort by highest ID (newest) to lowest ID (oldest) Order("status.id DESC") @@ -81,29 +73,32 @@ func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID str q = q.Limit(limit) } - err := q.Select() - if err != nil { - if err == pg.ErrNoRows { - return nil, db.ErrNoEntries - } - return nil, err + // Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows, + // OR statuses posted by accountID itself (since a user should be able to see their own statuses). + // + // This is equivalent to something like WHERE ... AND (... OR ...) + // See: https://bun.uptrace.dev/guide/queries.html#select + whereGroup := func(*bun.SelectQuery) *bun.SelectQuery { + return q. + WhereOr("f.account_id = ?", accountID). + WhereOr("status.account_id = ?", accountID) } - if len(statuses) == 0 { - return nil, db.ErrNoEntries - } + q = q.WhereGroup(" AND ", whereGroup) - return statuses, nil + return statuses, processErrorResponse(q.Scan(ctx)) } -func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { +func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { statuses := []*gtsmodel.Status{} - q := t.conn.Model(&statuses). + q := t.conn. + NewSelect(). + Model(&statuses). Where("visibility = ?", gtsmodel.VisibilityPublic). - Where("? IS NULL", pg.Ident("in_reply_to_id")). - Where("? IS NULL", pg.Ident("in_reply_to_uri")). - Where("? IS NULL", pg.Ident("boost_of_id")). + Where("? IS NULL", bun.Ident("in_reply_to_id")). + Where("? IS NULL", bun.Ident("in_reply_to_uri")). + Where("? IS NULL", bun.Ident("boost_of_id")). Order("status.id DESC") if maxID != "" { @@ -126,28 +121,18 @@ func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID s q = q.Limit(limit) } - err := q.Select() - if err != nil { - if err == pg.ErrNoRows { - return nil, db.ErrNoEntries - } - return nil, err - } - - if len(statuses) == 0 { - return nil, db.ErrNoEntries - } - - return statuses, nil + return statuses, processErrorResponse(q.Scan(ctx)) } // TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20! // It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds. -func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) { +func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) { faves := []*gtsmodel.StatusFave{} - fq := t.conn.Model(&faves). + fq := t.conn. + NewSelect(). + Model(&faves). Where("account_id = ?", accountID). Order("id DESC") @@ -163,9 +148,9 @@ func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID stri fq = fq.Limit(limit) } - err := fq.Select() + err := fq.Scan(ctx) if err != nil { - if err == pg.ErrNoRows { + if err == sql.ErrNoRows { return nil, "", "", db.ErrNoEntries } return nil, "", "", err @@ -185,9 +170,13 @@ func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID stri } statuses := []*gtsmodel.Status{} - err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select() + err = t.conn. + NewSelect(). + Model(&statuses). + Where("id IN (?)", bun.In(in)). + Scan(ctx) if err != nil { - if err == pg.ErrNoRows { + if err == sql.ErrNoRows { return nil, "", "", db.ErrNoEntries } return nil, "", "", err diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go new file mode 100644 index 000000000..115d18de2 --- /dev/null +++ b/internal/db/bundb/util.go @@ -0,0 +1,78 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "strings" + + "database/sql" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/uptrace/bun" +) + +// processErrorResponse parses the given error and returns an appropriate DBError. +func processErrorResponse(err error) db.Error { + switch err { + case nil: + return nil + case sql.ErrNoRows: + return db.ErrNoEntries + default: + if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { + return db.ErrAlreadyExists + } + return err + } +} + +func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) { + count, err := q.Count(ctx) + + exists := count != 0 + + err = processErrorResponse(err) + + if err != nil { + if err == db.ErrNoEntries { + return false, nil + } + return false, err + } + + return exists, nil +} + +func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) { + count, err := q.Count(ctx) + + notExists := count == 0 + + err = processErrorResponse(err) + + if err != nil { + if err == db.ErrNoEntries { + return true, nil + } + return false, err + } + + return notExists, nil +} diff --git a/internal/db/db.go b/internal/db/db.go index d6ac883e4..ec94fcfe7 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -19,6 +19,8 @@ package db import ( + "context" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) @@ -38,6 +40,7 @@ type DB interface { Mention Notification Relationship + Session Status Timeline @@ -52,7 +55,7 @@ type DB interface { // // Note: this func doesn't/shouldn't do any manipulation of the accounts in the DB, it's just for checking // if they exist in the db and conveniently returning them if they do. - MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) + MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) // TagStringsToTags takes a slice of deduplicated, lowercase tags in the form "somehashtag", which have been // used in a status. It takes the id of the account that wrote the status, and the id of the status itself, and then @@ -61,7 +64,7 @@ type DB interface { // // Note: this func doesn't/shouldn't do any manipulation of the tags in the DB, it's just for checking // if they exist in the db already, and conveniently returning them, or creating new tag structs. - TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) + TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) // EmojiStringsToEmojis takes a slice of deduplicated, lowercase emojis in the form ":emojiname:", which have been // used in a status. It takes the id of the account that wrote the status, and the id of the status itself, and then @@ -69,5 +72,5 @@ type DB interface { // // Note: this func doesn't/shouldn't do any manipulation of the emoji in the DB, it's just for checking // if they exist in the db and conveniently returning them if they do. - EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) + EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) } diff --git a/internal/db/domain.go b/internal/db/domain.go index a6583c80c..df50a6770 100644 --- a/internal/db/domain.go +++ b/internal/db/domain.go @@ -18,19 +18,22 @@ package db -import "net/url" +import ( + "context" + "net/url" +) // Domain contains DB functions related to domains and domain blocks. type Domain interface { // IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`). - IsDomainBlocked(domain string) (bool, Error) + IsDomainBlocked(ctx context.Context, domain string) (bool, Error) // AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found. - AreDomainsBlocked(domains []string) (bool, Error) + AreDomainsBlocked(ctx context.Context, domains []string) (bool, Error) // IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`). - IsURIBlocked(uri *url.URL) (bool, Error) + IsURIBlocked(ctx context.Context, uri *url.URL) (bool, Error) // AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found. - AreURIsBlocked(uris []*url.URL) (bool, Error) + AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, Error) } diff --git a/internal/db/instance.go b/internal/db/instance.go index 1f7c83e4f..dcd978a81 100644 --- a/internal/db/instance.go +++ b/internal/db/instance.go @@ -18,19 +18,23 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Instance contains functions for instance-level actions (counting instance users etc.). type Instance interface { // CountInstanceUsers returns the number of known accounts registered with the given domain. - CountInstanceUsers(domain string) (int, Error) + CountInstanceUsers(ctx context.Context, domain string) (int, Error) // CountInstanceStatuses returns the number of known statuses posted from the given domain. - CountInstanceStatuses(domain string) (int, Error) + CountInstanceStatuses(ctx context.Context, domain string) (int, Error) // CountInstanceDomains returns the number of known instances known that the given domain federates with. - CountInstanceDomains(domain string) (int, Error) + CountInstanceDomains(ctx context.Context, domain string) (int, Error) // GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID. - GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, Error) + GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, Error) } diff --git a/internal/db/media.go b/internal/db/media.go index db4db3411..b779dd276 100644 --- a/internal/db/media.go +++ b/internal/db/media.go @@ -18,10 +18,14 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Media contains functions related to creating/getting/removing media attachments. type Media interface { // GetAttachmentByID gets a single attachment by its ID - GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, Error) + GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, Error) } diff --git a/internal/db/mention.go b/internal/db/mention.go index cb1c56dc1..b9b45546a 100644 --- a/internal/db/mention.go +++ b/internal/db/mention.go @@ -18,13 +18,17 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Mention contains functions for getting/creating mentions in the database. type Mention interface { // GetMention gets a single mention by ID - GetMention(id string) (*gtsmodel.Mention, Error) + GetMention(ctx context.Context, id string) (*gtsmodel.Mention, Error) // GetMentions gets multiple mentions. - GetMentions(ids []string) ([]*gtsmodel.Mention, Error) + GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, Error) } diff --git a/internal/db/notification.go b/internal/db/notification.go index 326f0f149..09c17f031 100644 --- a/internal/db/notification.go +++ b/internal/db/notification.go @@ -18,14 +18,18 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Notification contains functions for creating and getting notifications. type Notification interface { // GetNotifications returns a slice of notifications that pertain to the given accountID. // // Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest). - GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error) + GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error) // GetNotification returns one notification according to its id. - GetNotification(id string) (*gtsmodel.Notification, Error) + GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, Error) } diff --git a/internal/db/pg/basic.go b/internal/db/pg/basic.go deleted file mode 100644 index 6e76b4450..000000000 --- a/internal/db/pg/basic.go +++ /dev/null @@ -1,205 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "context" - "errors" - "fmt" - "strings" - - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" -) - -type basicDB struct { - config *config.Config - conn *pg.DB - log *logrus.Logger - cancel context.CancelFunc -} - -func (b *basicDB) Put(i interface{}) db.Error { - _, err := b.conn.Model(i).Insert(i) - if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") { - return db.ErrAlreadyExists - } - return err -} - -func (b *basicDB) GetByID(id string, i interface{}) db.Error { - if err := b.conn.Model(i).Where("id = ?", id).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries - } - return err - - } - return nil -} - -func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.Error { - if len(where) == 0 { - return errors.New("no queries provided") - } - - q := b.conn.Model(i) - for _, w := range where { - - if w.Value == nil { - q = q.Where("? IS NULL", pg.Ident(w.Key)) - } else { - if w.CaseInsensitive { - q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value) - } else { - q = q.Where("? = ?", pg.Safe(w.Key), w.Value) - } - } - } - - if err := q.Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries - } - return err - } - return nil -} - -func (b *basicDB) GetAll(i interface{}) db.Error { - if err := b.conn.Model(i).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries - } - return err - } - return nil -} - -func (b *basicDB) DeleteByID(id string, i interface{}) db.Error { - if _, err := b.conn.Model(i).Where("id = ?", id).Delete(); err != nil { - // if there are no rows *anyway* then that's fine - // just return err if there's an actual error - if err != pg.ErrNoRows { - return err - } - } - return nil -} - -func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.Error { - if len(where) == 0 { - return errors.New("no queries provided") - } - - q := b.conn.Model(i) - for _, w := range where { - q = q.Where("? = ?", pg.Safe(w.Key), w.Value) - } - - if _, err := q.Delete(); err != nil { - // if there are no rows *anyway* then that's fine - // just return err if there's an actual error - if err != pg.ErrNoRows { - return err - } - } - return nil -} - -func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.Error { - if _, err := b.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries - } - return err - } - return nil -} - -func (b *basicDB) UpdateByID(id string, i interface{}) db.Error { - if _, err := b.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries - } - return err - } - return nil -} - -func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.Error { - _, err := b.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update() - return err -} - -func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.Error { - q := b.conn.Model(i) - - for _, w := range where { - if w.Value == nil { - q = q.Where("? IS NULL", pg.Ident(w.Key)) - } else { - if w.CaseInsensitive { - q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value) - } else { - q = q.Where("? = ?", pg.Safe(w.Key), w.Value) - } - } - } - - q = q.Set("? = ?", pg.Safe(key), value) - - _, err := q.Update() - - return err -} - -func (b *basicDB) CreateTable(i interface{}) db.Error { - return b.conn.Model(i).CreateTable(&orm.CreateTableOptions{ - IfNotExists: true, - }) -} - -func (b *basicDB) DropTable(i interface{}) db.Error { - return b.conn.Model(i).DropTable(&orm.DropTableOptions{ - IfExists: true, - }) -} - -func (b *basicDB) RegisterTable(i interface{}) db.Error { - orm.RegisterTable(i) - return nil -} - -func (b *basicDB) IsHealthy(ctx context.Context) db.Error { - return b.conn.Ping(ctx) -} - -func (b *basicDB) Stop(ctx context.Context) db.Error { - b.log.Info("closing db connection") - if err := b.conn.Close(); err != nil { - // only cancel if there's a problem closing the db - b.cancel() - return err - } - return nil -} diff --git a/internal/db/pg/status.go b/internal/db/pg/status.go deleted file mode 100644 index 99790428e..000000000 --- a/internal/db/pg/status.go +++ /dev/null @@ -1,318 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "container/list" - "context" - "errors" - "time" - - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/cache" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -type statusDB struct { - config *config.Config - conn *pg.DB - log *logrus.Logger - cancel context.CancelFunc - cache cache.Cache -} - -func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) { - if s.cache == nil { - s.cache = cache.New() - } - - if err := s.cache.Store(id, status); err != nil { - s.log.Panicf("statusDB: error storing in cache: %s", err) - } -} - -func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) { - if s.cache == nil { - s.cache = cache.New() - return nil, false - } - - sI, err := s.cache.Fetch(id) - if err != nil || sI == nil { - return nil, false - } - - status, ok := sI.(*gtsmodel.Status) - if !ok { - s.log.Panicf("statusDB: cached interface with key %s was not a status", id) - } - - return status, true -} - -func (s *statusDB) newStatusQ(status interface{}) *orm.Query { - return s.conn.Model(status). - Relation("Attachments"). - Relation("Tags"). - Relation("Mentions"). - Relation("Emojis"). - Relation("Account"). - Relation("InReplyTo"). - Relation("InReplyToAccount"). - Relation("BoostOf"). - Relation("BoostOfAccount"). - Relation("CreatedWithApplication") -} - -func (s *statusDB) newFaveQ(faves interface{}) *orm.Query { - return s.conn.Model(faves). - Relation("Account"). - Relation("TargetAccount"). - Relation("Status") -} - -func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) { - if status, cached := s.statusCached(id); cached { - return status, nil - } - - status := >smodel.Status{} - - q := s.newStatusQ(status). - Where("status.id = ?", id) - - err := processErrorResponse(q.Select()) - - if err == nil && status != nil { - s.cacheStatus(id, status) - } - - return status, err -} - -func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) { - if status, cached := s.statusCached(uri); cached { - return status, nil - } - - status := >smodel.Status{} - - q := s.newStatusQ(status). - Where("LOWER(status.uri) = LOWER(?)", uri) - - err := processErrorResponse(q.Select()) - - if err == nil && status != nil { - s.cacheStatus(uri, status) - } - - return status, err -} - -func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) { - if status, cached := s.statusCached(uri); cached { - return status, nil - } - - status := >smodel.Status{} - - q := s.newStatusQ(status). - Where("LOWER(status.url) = LOWER(?)", uri) - - err := processErrorResponse(q.Select()) - - if err == nil && status != nil { - s.cacheStatus(uri, status) - } - - return status, err -} - -func (s *statusDB) PutStatus(status *gtsmodel.Status) db.Error { - transaction := func(tx *pg.Tx) error { - // create links between this status and any emojis it uses - for _, i := range status.EmojiIDs { - if _, err := tx.Model(>smodel.StatusToEmoji{ - StatusID: status.ID, - EmojiID: i, - }).Insert(); err != nil { - return err - } - } - - // create links between this status and any tags it uses - for _, i := range status.TagIDs { - if _, err := tx.Model(>smodel.StatusToTag{ - StatusID: status.ID, - TagID: i, - }).Insert(); err != nil { - return err - } - } - - // change the status ID of the media attachments to the new status - for _, a := range status.Attachments { - a.StatusID = status.ID - a.UpdatedAt = time.Now() - if _, err := s.conn.Model(a). - Where("id = ?", a.ID). - Update(); err != nil { - return err - } - } - - _, err := tx.Model(status).Insert() - return err - } - - return processErrorResponse(s.conn.RunInTransaction(context.Background(), transaction)) -} - -func (s *statusDB) GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { - parents := []*gtsmodel.Status{} - s.statusParent(status, &parents, onlyDirect) - - return parents, nil -} - -func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) { - if status.InReplyToID == "" { - return - } - - parentStatus, err := s.GetStatusByID(status.InReplyToID) - if err == nil { - *foundStatuses = append(*foundStatuses, parentStatus) - } - - if onlyDirect { - return - } - - s.statusParent(parentStatus, foundStatuses, false) -} - -func (s *statusDB) GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) { - foundStatuses := &list.List{} - foundStatuses.PushFront(status) - s.statusChildren(status, foundStatuses, onlyDirect, minID) - - children := []*gtsmodel.Status{} - for e := foundStatuses.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*gtsmodel.Status) - if !ok { - panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) - } - - // only append children, not the overall parent status - if entry.ID != status.ID { - children = append(children, entry) - } - } - - return children, nil -} - -func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { - immediateChildren := []*gtsmodel.Status{} - - q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID) - if minID != "" { - q = q.Where("status.id > ?", minID) - } - - if err := q.Select(); err != nil { - return - } - - for _, child := range immediateChildren { - insertLoop: - for e := foundStatuses.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*gtsmodel.Status) - if !ok { - panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) - } - - if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID { - foundStatuses.InsertAfter(child, e) - break insertLoop - } - } - - // only do one loop if we only want direct children - if onlyDirect { - return - } - s.statusChildren(child, foundStatuses, false, minID) - } -} - -func (s *statusDB) CountStatusReplies(status *gtsmodel.Status) (int, db.Error) { - return s.conn.Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count() -} - -func (s *statusDB) CountStatusReblogs(status *gtsmodel.Status) (int, db.Error) { - return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count() -} - -func (s *statusDB) CountStatusFaves(status *gtsmodel.Status) (int, db.Error) { - return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count() -} - -func (s *statusDB) IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) { - return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() -} - -func (s *statusDB) IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) { - return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists() -} - -func (s *statusDB) IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) { - return s.conn.Model(>smodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() -} - -func (s *statusDB) IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) { - return s.conn.Model(>smodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() -} - -func (s *statusDB) GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) { - faves := []*gtsmodel.StatusFave{} - - q := s.newFaveQ(&faves). - Where("status_id = ?", status.ID) - - err := processErrorResponse(q.Select()) - - return faves, err -} - -func (s *statusDB) GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) { - reblogs := []*gtsmodel.Status{} - - q := s.newStatusQ(&reblogs). - Where("boost_of_id = ?", status.ID) - - err := processErrorResponse(q.Select()) - - return reblogs, err -} diff --git a/internal/db/pg/util.go b/internal/db/pg/util.go deleted file mode 100644 index 17c09b720..000000000 --- a/internal/db/pg/util.go +++ /dev/null @@ -1,25 +0,0 @@ -package pg - -import ( - "strings" - - "github.com/go-pg/pg/v10" - "github.com/superseriousbusiness/gotosocial/internal/db" -) - -// processErrorResponse parses the given error and returns an appropriate DBError. -func processErrorResponse(err error) db.Error { - switch err { - case nil: - return nil - case pg.ErrNoRows: - return db.ErrNoEntries - case pg.ErrMultiRows: - return db.ErrMultipleEntries - default: - if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { - return db.ErrAlreadyExists - } - return err - } -} diff --git a/internal/db/relationship.go b/internal/db/relationship.go index 85f64d72b..804526425 100644 --- a/internal/db/relationship.go +++ b/internal/db/relationship.go @@ -18,54 +18,58 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Relationship contains functions for getting or modifying the relationship between two accounts. type Relationship interface { // IsBlocked checks whether account 1 has a block in place against block2. // If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1. - IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, Error) + IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, Error) // GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't. // // Because this is slower than Blocked, only use it if you need the actual Block struct for some reason, // not if you're just checking for the existence of a block. - GetBlock(account1 string, account2 string) (*gtsmodel.Block, Error) + GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, Error) // GetRelationship retrieves the relationship of the targetAccount to the requestingAccount. - GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error) + GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error) // IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out. - IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error) + IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error) // IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out. - IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error) + IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error) // IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out. - IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error) + IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error) // AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table. // In other words, it should create the follow, and delete the existing follow request. // // It will return the newly created follow for further processing. - AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error) + AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error) // GetAccountFollowRequests returns all follow requests targeting the given account. - GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, Error) + GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, Error) // GetAccountFollows returns a slice of follows owned by the given accountID. - GetAccountFollows(accountID string) ([]*gtsmodel.Follow, Error) + GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, Error) // CountAccountFollows returns the amount of accounts that the given accountID is following. // // If localOnly is set to true, then only follows from *this instance* will be returned. - CountAccountFollows(accountID string, localOnly bool) (int, Error) + CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, Error) // GetAccountFollowedBy fetches follows that target given accountID. // // If localOnly is set to true, then only follows from *this instance* will be returned. - GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, Error) + GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, Error) // CountAccountFollowedBy returns the amounts that the given ID is followed by. - CountAccountFollowedBy(accountID string, localOnly bool) (int, Error) + CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, Error) } diff --git a/internal/federation/dereferencing/blocked.go b/internal/db/session.go index c8a4c6ade..ae13dccce 100644 --- a/internal/federation/dereferencing/blocked.go +++ b/internal/db/session.go @@ -16,26 +16,16 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package dereferencing +package db import ( - "github.com/superseriousbusiness/gotosocial/internal/db" + "context" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (d *deref) blockedDomain(host string) (bool, error) { - b := >smodel.DomainBlock{} - err := d.db.GetWhere([]db.Where{{Key: "domain", Value: host, CaseInsensitive: true}}, b) - if err == nil { - // block exists - return true, nil - } - - if err == db.ErrNoEntries { - // there are no entries so there's no block - return false, nil - } - - // there's an actual error - return false, err +// Session handles getting/creation of router sessions. +type Session interface { + GetSession(ctx context.Context) (*gtsmodel.RouterSession, Error) + CreateSession(ctx context.Context) (*gtsmodel.RouterSession, Error) } diff --git a/internal/db/status.go b/internal/db/status.go index 9d206c198..7430433c4 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -18,58 +18,62 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses. type Status interface { // GetStatusByID returns one status from the database, with all rel fields populated (if possible). - GetStatusByID(id string) (*gtsmodel.Status, Error) + GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error) // GetStatusByURI returns one status from the database, with all rel fields populated (if possible). - GetStatusByURI(uri string) (*gtsmodel.Status, Error) + GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error) // GetStatusByURL returns one status from the database, with all rel fields populated (if possible). - GetStatusByURL(uri string) (*gtsmodel.Status, Error) + GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error) // PutStatus stores one status in the database. - PutStatus(status *gtsmodel.Status) Error + PutStatus(ctx context.Context, status *gtsmodel.Status) Error // CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong - CountStatusReplies(status *gtsmodel.Status) (int, Error) + CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, Error) // CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong - CountStatusReblogs(status *gtsmodel.Status) (int, Error) + CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, Error) // CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong - CountStatusFaves(status *gtsmodel.Status) (int, Error) + CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, Error) // GetStatusParents gets the parent statuses of a given status. // // If onlyDirect is true, only the immediate parent will be returned. - GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error) + GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error) // GetStatusChildren gets the child statuses of a given status. // // If onlyDirect is true, only the immediate children will be returned. - GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error) + GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error) // IsStatusFavedBy checks if a given status has been faved by a given account ID - IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, Error) + IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) // IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID - IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, Error) + IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) // IsStatusMutedBy checks if a given status has been muted by a given account ID - IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, Error) + IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) // IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID - IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, Error) + IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) // GetStatusFaves returns a slice of faves/likes of the given status. // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. - GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error) + GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error) // GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status. // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. - GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, Error) + GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, Error) } diff --git a/internal/db/timeline.go b/internal/db/timeline.go index 74aa5c781..83fb3a959 100644 --- a/internal/db/timeline.go +++ b/internal/db/timeline.go @@ -18,20 +18,24 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Timeline contains functionality for retrieving home/public/faved etc timelines for an account. type Timeline interface { // GetHomeTimeline returns a slice of statuses from accounts that are followed by the given account id. // // Statuses should be returned in descending order of when they were created (newest first). - GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) + GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) // GetPublicTimeline fetches the account's PUBLIC timeline -- ie., posts and replies that are public. // It will use the given filters and try to return as many statuses as possible up to the limit. // // Statuses should be returned in descending order of when they were created (newest first). - GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) + GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) // GetFavedTimeline fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved. // It will use the given filters and try to return as many statuses as possible up to the limit. @@ -40,5 +44,5 @@ type Timeline interface { // In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created. // // Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers. - GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error) + GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error) } diff --git a/internal/federation/authenticate.go b/internal/federation/authenticate.go index 699691ca6..81ac84544 100644 --- a/internal/federation/authenticate.go +++ b/internal/federation/authenticate.go @@ -148,7 +148,7 @@ func (f *federator) AuthenticateFederatedRequest(ctx context.Context, requestedU // LOCAL ACCOUNT REQUEST // the request is coming from INSIDE THE HOUSE so skip the remote dereferencing l.Tracef("proceeding without dereference for local public key %s", requestingPublicKeyID) - if err := f.db.GetWhere([]db.Where{{Key: "public_key_uri", Value: requestingPublicKeyID.String()}}, requestingLocalAccount); err != nil { + if err := f.db.GetWhere(ctx, []db.Where{{Key: "public_key_uri", Value: requestingPublicKeyID.String()}}, requestingLocalAccount); err != nil { return nil, false, fmt.Errorf("couldn't get local account with public key uri %s from the database: %s", requestingPublicKeyID.String(), err) } publicKey = requestingLocalAccount.PublicKey @@ -156,7 +156,7 @@ func (f *federator) AuthenticateFederatedRequest(ctx context.Context, requestedU if err != nil { return nil, false, fmt.Errorf("error parsing url %s: %s", requestingLocalAccount.URI, err) } - } else if err := f.db.GetWhere([]db.Where{{Key: "public_key_uri", Value: requestingPublicKeyID.String()}}, requestingRemoteAccount); err == nil { + } else if err := f.db.GetWhere(ctx, []db.Where{{Key: "public_key_uri", Value: requestingPublicKeyID.String()}}, requestingRemoteAccount); err == nil { // REMOTE ACCOUNT REQUEST WITH KEY CACHED LOCALLY // this is a remote account and we already have the public key for it so use that l.Tracef("proceeding without dereference for cached public key %s", requestingPublicKeyID) @@ -170,7 +170,7 @@ func (f *federator) AuthenticateFederatedRequest(ctx context.Context, requestedU // the request is remote and we don't have the public key yet, // so we need to authenticate the request properly by dereferencing the remote key l.Tracef("proceeding with dereference for uncached public key %s", requestingPublicKeyID) - transport, err := f.transportController.NewTransportForUsername(requestedUsername) + transport, err := f.transportController.NewTransportForUsername(ctx, requestedUsername) if err != nil { return nil, false, fmt.Errorf("transport err: %s", err) } diff --git a/internal/federation/dereference.go b/internal/federation/dereference.go index 96a662e32..a09f0f84b 100644 --- a/internal/federation/dereference.go +++ b/internal/federation/dereference.go @@ -19,36 +19,37 @@ package federation import ( + "context" "net/url" "github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (f *federator) GetRemoteAccount(username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) { - return f.dereferencer.GetRemoteAccount(username, remoteAccountID, refresh) +func (f *federator) GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) { + return f.dereferencer.GetRemoteAccount(ctx, username, remoteAccountID, refresh) } -func (f *federator) EnrichRemoteAccount(username string, account *gtsmodel.Account) (*gtsmodel.Account, error) { - return f.dereferencer.EnrichRemoteAccount(username, account) +func (f *federator) EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) { + return f.dereferencer.EnrichRemoteAccount(ctx, username, account) } -func (f *federator) GetRemoteStatus(username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) { - return f.dereferencer.GetRemoteStatus(username, remoteStatusID, refresh) +func (f *federator) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) { + return f.dereferencer.GetRemoteStatus(ctx, username, remoteStatusID, refresh) } -func (f *federator) EnrichRemoteStatus(username string, status *gtsmodel.Status) (*gtsmodel.Status, error) { - return f.dereferencer.EnrichRemoteStatus(username, status) +func (f *federator) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) { + return f.dereferencer.EnrichRemoteStatus(ctx, username, status) } -func (f *federator) DereferenceRemoteThread(username string, statusIRI *url.URL) error { - return f.dereferencer.DereferenceThread(username, statusIRI) +func (f *federator) DereferenceRemoteThread(ctx context.Context, username string, statusIRI *url.URL) error { + return f.dereferencer.DereferenceThread(ctx, username, statusIRI) } -func (f *federator) GetRemoteInstance(username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) { - return f.dereferencer.GetRemoteInstance(username, remoteInstanceURI) +func (f *federator) GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) { + return f.dereferencer.GetRemoteInstance(ctx, username, remoteInstanceURI) } -func (f *federator) DereferenceAnnounce(announce *gtsmodel.Status, requestingUsername string) error { - return f.dereferencer.DereferenceAnnounce(announce, requestingUsername) +func (f *federator) DereferenceAnnounce(ctx context.Context, announce *gtsmodel.Status, requestingUsername string) error { + return f.dereferencer.DereferenceAnnounce(ctx, announce, requestingUsername) } diff --git a/internal/federation/dereferencing/account.go b/internal/federation/dereferencing/account.go index ba6766061..2eee0645d 100644 --- a/internal/federation/dereferencing/account.go +++ b/internal/federation/dereferencing/account.go @@ -24,6 +24,7 @@ import ( "errors" "fmt" "net/url" + "strings" "github.com/go-fed/activity/streams" "github.com/go-fed/activity/streams/vocab" @@ -34,18 +35,33 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/transport" ) +func instanceAccount(account *gtsmodel.Account) bool { + return strings.EqualFold(account.Username, account.Domain) || + account.FollowersURI == "" || + account.FollowingURI == "" || + (account.Username == "internal.fetch" && strings.Contains(account.Note, "internal service actor")) +} + // EnrichRemoteAccount takes an account that's already been inserted into the database in a minimal form, // and populates it with additional fields, media, etc. // // EnrichRemoteAccount is mostly useful for calling after an account has been initially created by // the federatingDB's Create function, or during the federated authorization flow. -func (d *deref) EnrichRemoteAccount(username string, account *gtsmodel.Account) (*gtsmodel.Account, error) { - if err := d.PopulateAccountFields(account, username, false); err != nil { +func (d *deref) EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) { + + // if we're dealing with an instance account, we don't need to update anything + if instanceAccount(account) { + return account, nil + } + + if err := d.PopulateAccountFields(ctx, account, username, false); err != nil { return nil, err } - if err := d.db.UpdateByID(account.ID, account); err != nil { - return nil, fmt.Errorf("EnrichRemoteAccount: error updating account: %s", err) + var err error + account, err = d.db.UpdateAccount(ctx, account) + if err != nil { + d.log.Errorf("EnrichRemoteAccount: error updating account: %s", err) } return account, nil @@ -60,27 +76,27 @@ func (d *deref) EnrichRemoteAccount(username string, account *gtsmodel.Account) // the remote instance again. // // SIDE EFFECTS: remote account will be stored in the database, or updated if it already exists (and refresh is true). -func (d *deref) GetRemoteAccount(username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) { +func (d *deref) GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) { new := true // check if we already have the account in our db - maybeAccount, err := d.db.GetAccountByURI(remoteAccountID.String()) + maybeAccount, err := d.db.GetAccountByURI(ctx, remoteAccountID.String()) if err == nil { // we've seen this account before so it's not new new = false if !refresh { // we're not being asked to refresh, but just in case we don't have the avatar/header cached yet.... - maybeAccount, err = d.EnrichRemoteAccount(username, maybeAccount) + maybeAccount, err = d.EnrichRemoteAccount(ctx, username, maybeAccount) return maybeAccount, new, err } } - accountable, err := d.dereferenceAccountable(username, remoteAccountID) + accountable, err := d.dereferenceAccountable(ctx, username, remoteAccountID) if err != nil { return nil, new, fmt.Errorf("FullyDereferenceAccount: error dereferencing accountable: %s", err) } - gtsAccount, err := d.typeConverter.ASRepresentationToAccount(accountable, refresh) + gtsAccount, err := d.typeConverter.ASRepresentationToAccount(ctx, accountable, refresh) if err != nil { return nil, new, fmt.Errorf("FullyDereferenceAccount: error converting accountable to account: %s", err) } @@ -93,23 +109,24 @@ func (d *deref) GetRemoteAccount(username string, remoteAccountID *url.URL, refr } gtsAccount.ID = ulid - if err := d.PopulateAccountFields(gtsAccount, username, refresh); err != nil { + if err := d.PopulateAccountFields(ctx, gtsAccount, username, refresh); err != nil { return nil, new, fmt.Errorf("FullyDereferenceAccount: error populating further account fields: %s", err) } - if err := d.db.Put(gtsAccount); err != nil { + if err := d.db.Put(ctx, gtsAccount); err != nil { return nil, new, fmt.Errorf("FullyDereferenceAccount: error putting new account: %s", err) } } else { // take the id we already have and do an update gtsAccount.ID = maybeAccount.ID - if err := d.PopulateAccountFields(gtsAccount, username, refresh); err != nil { + if err := d.PopulateAccountFields(ctx, gtsAccount, username, refresh); err != nil { return nil, new, fmt.Errorf("FullyDereferenceAccount: error populating further account fields: %s", err) } - if err := d.db.UpdateByID(gtsAccount.ID, gtsAccount); err != nil { - return nil, new, fmt.Errorf("FullyDereferenceAccount: error updating existing account: %s", err) + gtsAccount, err = d.db.UpdateAccount(ctx, gtsAccount) + if err != nil { + return nil, false, fmt.Errorf("EnrichRemoteAccount: error updating account: %s", err) } } @@ -120,15 +137,15 @@ func (d *deref) GetRemoteAccount(username string, remoteAccountID *url.URL, refr // it finds as something that an account model can be constructed out of. // // Will work for Person, Application, or Service models. -func (d *deref) dereferenceAccountable(username string, remoteAccountID *url.URL) (ap.Accountable, error) { +func (d *deref) dereferenceAccountable(ctx context.Context, username string, remoteAccountID *url.URL) (ap.Accountable, error) { d.startHandshake(username, remoteAccountID) defer d.stopHandshake(username, remoteAccountID) - if blocked, err := d.blockedDomain(remoteAccountID.Host); blocked || err != nil { + if blocked, err := d.db.IsDomainBlocked(ctx, remoteAccountID.Host); blocked || err != nil { return nil, fmt.Errorf("DereferenceAccountable: domain %s is blocked", remoteAccountID.Host) } - transport, err := d.transportController.NewTransportForUsername(username) + transport, err := d.transportController.NewTransportForUsername(ctx, username) if err != nil { return nil, fmt.Errorf("DereferenceAccountable: transport err: %s", err) } @@ -174,7 +191,7 @@ func (d *deref) dereferenceAccountable(username string, remoteAccountID *url.URL // PopulateAccountFields populates any fields on the given account that weren't populated by the initial // dereferencing. This includes things like header and avatar etc. -func (d *deref) PopulateAccountFields(account *gtsmodel.Account, requestingUsername string, refresh bool) error { +func (d *deref) PopulateAccountFields(ctx context.Context, account *gtsmodel.Account, requestingUsername string, refresh bool) error { l := d.log.WithFields(logrus.Fields{ "func": "PopulateAccountFields", "requestingUsername": requestingUsername, @@ -184,17 +201,17 @@ func (d *deref) PopulateAccountFields(account *gtsmodel.Account, requestingUsern if err != nil { return fmt.Errorf("PopulateAccountFields: couldn't parse account URI %s: %s", account.URI, err) } - if blocked, err := d.blockedDomain(accountURI.Host); blocked || err != nil { + if blocked, err := d.db.IsDomainBlocked(ctx, accountURI.Host); blocked || err != nil { return fmt.Errorf("PopulateAccountFields: domain %s is blocked", accountURI.Host) } - t, err := d.transportController.NewTransportForUsername(requestingUsername) + t, err := d.transportController.NewTransportForUsername(ctx, requestingUsername) if err != nil { return fmt.Errorf("PopulateAccountFields: error getting transport for user: %s", err) } // fetch the header and avatar - if err := d.fetchHeaderAndAviForAccount(account, t, refresh); err != nil { + if err := d.fetchHeaderAndAviForAccount(ctx, account, t, refresh); err != nil { // if this doesn't work, just skip it -- we can do it later l.Debugf("error fetching header/avi for account: %s", err) } @@ -208,17 +225,17 @@ func (d *deref) PopulateAccountFields(account *gtsmodel.Account, requestingUsern // targetAccount's AvatarMediaAttachmentID and HeaderMediaAttachmentID will be updated as necessary. // // SIDE EFFECTS: remote header and avatar will be stored in local storage. -func (d *deref) fetchHeaderAndAviForAccount(targetAccount *gtsmodel.Account, t transport.Transport, refresh bool) error { +func (d *deref) fetchHeaderAndAviForAccount(ctx context.Context, targetAccount *gtsmodel.Account, t transport.Transport, refresh bool) error { accountURI, err := url.Parse(targetAccount.URI) if err != nil { return fmt.Errorf("fetchHeaderAndAviForAccount: couldn't parse account URI %s: %s", targetAccount.URI, err) } - if blocked, err := d.blockedDomain(accountURI.Host); blocked || err != nil { + if blocked, err := d.db.IsDomainBlocked(ctx, accountURI.Host); blocked || err != nil { return fmt.Errorf("fetchHeaderAndAviForAccount: domain %s is blocked", accountURI.Host) } if targetAccount.AvatarRemoteURL != "" && (targetAccount.AvatarMediaAttachmentID == "" || refresh) { - a, err := d.mediaHandler.ProcessRemoteHeaderOrAvatar(t, >smodel.MediaAttachment{ + a, err := d.mediaHandler.ProcessRemoteHeaderOrAvatar(ctx, t, >smodel.MediaAttachment{ RemoteURL: targetAccount.AvatarRemoteURL, Avatar: true, }, targetAccount.ID) @@ -229,7 +246,7 @@ func (d *deref) fetchHeaderAndAviForAccount(targetAccount *gtsmodel.Account, t t } if targetAccount.HeaderRemoteURL != "" && (targetAccount.HeaderMediaAttachmentID == "" || refresh) { - a, err := d.mediaHandler.ProcessRemoteHeaderOrAvatar(t, >smodel.MediaAttachment{ + a, err := d.mediaHandler.ProcessRemoteHeaderOrAvatar(ctx, t, >smodel.MediaAttachment{ RemoteURL: targetAccount.HeaderRemoteURL, Header: true, }, targetAccount.ID) diff --git a/internal/federation/dereferencing/announce.go b/internal/federation/dereferencing/announce.go index 6773db425..33af74ebe 100644 --- a/internal/federation/dereferencing/announce.go +++ b/internal/federation/dereferencing/announce.go @@ -19,6 +19,7 @@ package dereferencing import ( + "context" "errors" "fmt" "net/url" @@ -26,7 +27,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (d *deref) DereferenceAnnounce(announce *gtsmodel.Status, requestingUsername string) error { +func (d *deref) DereferenceAnnounce(ctx context.Context, announce *gtsmodel.Status, requestingUsername string) error { if announce.BoostOf == nil || announce.BoostOf.URI == "" { // we can't do anything unfortunately return errors.New("DereferenceAnnounce: no URI to dereference") @@ -36,16 +37,16 @@ func (d *deref) DereferenceAnnounce(announce *gtsmodel.Status, requestingUsernam if err != nil { return fmt.Errorf("DereferenceAnnounce: couldn't parse boosted status URI %s: %s", announce.BoostOf.URI, err) } - if blocked, err := d.blockedDomain(boostedStatusURI.Host); blocked || err != nil { + if blocked, err := d.db.IsDomainBlocked(ctx, boostedStatusURI.Host); blocked || err != nil { return fmt.Errorf("DereferenceAnnounce: domain %s is blocked", boostedStatusURI.Host) } // dereference statuses in the thread of the boosted status - if err := d.DereferenceThread(requestingUsername, boostedStatusURI); err != nil { + if err := d.DereferenceThread(ctx, requestingUsername, boostedStatusURI); err != nil { return fmt.Errorf("DereferenceAnnounce: error dereferencing thread of boosted status: %s", err) } - boostedStatus, _, _, err := d.GetRemoteStatus(requestingUsername, boostedStatusURI, false) + boostedStatus, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, boostedStatusURI, false) if err != nil { return fmt.Errorf("DereferenceAnnounce: error dereferencing remote status with id %s: %s", announce.BoostOf.URI, err) } diff --git a/internal/federation/dereferencing/collectionpage.go b/internal/federation/dereferencing/collectionpage.go index 5feadc1ad..6f0beeaf6 100644 --- a/internal/federation/dereferencing/collectionpage.go +++ b/internal/federation/dereferencing/collectionpage.go @@ -32,12 +32,12 @@ import ( ) // DereferenceCollectionPage returns the activitystreams CollectionPage at the specified IRI, or an error if something goes wrong. -func (d *deref) DereferenceCollectionPage(username string, pageIRI *url.URL) (ap.CollectionPageable, error) { - if blocked, err := d.blockedDomain(pageIRI.Host); blocked || err != nil { +func (d *deref) DereferenceCollectionPage(ctx context.Context, username string, pageIRI *url.URL) (ap.CollectionPageable, error) { + if blocked, err := d.db.IsDomainBlocked(ctx, pageIRI.Host); blocked || err != nil { return nil, fmt.Errorf("DereferenceCollectionPage: domain %s is blocked", pageIRI.Host) } - transport, err := d.transportController.NewTransportForUsername(username) + transport, err := d.transportController.NewTransportForUsername(ctx, username) if err != nil { return nil, fmt.Errorf("DereferenceCollectionPage: error creating transport: %s", err) } diff --git a/internal/federation/dereferencing/dereferencer.go b/internal/federation/dereferencing/dereferencer.go index 03b90569a..71625ed88 100644 --- a/internal/federation/dereferencing/dereferencer.go +++ b/internal/federation/dereferencing/dereferencer.go @@ -19,6 +19,7 @@ package dereferencing import ( + "context" "net/url" "sync" @@ -34,18 +35,18 @@ import ( // Dereferencer wraps logic and functionality for doing dereferencing of remote accounts, statuses, etc, from federated instances. type Dereferencer interface { - GetRemoteAccount(username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) - EnrichRemoteAccount(username string, account *gtsmodel.Account) (*gtsmodel.Account, error) + GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) + EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) - GetRemoteStatus(username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) - EnrichRemoteStatus(username string, status *gtsmodel.Status) (*gtsmodel.Status, error) + GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) + EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) - GetRemoteInstance(username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) + GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) - DereferenceAnnounce(announce *gtsmodel.Status, requestingUsername string) error - DereferenceThread(username string, statusIRI *url.URL) error + DereferenceAnnounce(ctx context.Context, announce *gtsmodel.Status, requestingUsername string) error + DereferenceThread(ctx context.Context, username string, statusIRI *url.URL) error - Handshaking(username string, remoteAccountID *url.URL) bool + Handshaking(ctx context.Context, username string, remoteAccountID *url.URL) bool } type deref struct { diff --git a/internal/federation/dereferencing/handshake.go b/internal/federation/dereferencing/handshake.go index cda8eafd0..17003be84 100644 --- a/internal/federation/dereferencing/handshake.go +++ b/internal/federation/dereferencing/handshake.go @@ -18,9 +18,12 @@ package dereferencing -import "net/url" +import ( + "context" + "net/url" +) -func (d *deref) Handshaking(username string, remoteAccountID *url.URL) bool { +func (d *deref) Handshaking(ctx context.Context, username string, remoteAccountID *url.URL) bool { d.handshakeSync.Lock() defer d.handshakeSync.Unlock() diff --git a/internal/federation/dereferencing/instance.go b/internal/federation/dereferencing/instance.go index 80f626662..ec3c3f13d 100644 --- a/internal/federation/dereferencing/instance.go +++ b/internal/federation/dereferencing/instance.go @@ -26,12 +26,12 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (d *deref) GetRemoteInstance(username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) { - if blocked, err := d.blockedDomain(remoteInstanceURI.Host); blocked || err != nil { +func (d *deref) GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) { + if blocked, err := d.db.IsDomainBlocked(ctx, remoteInstanceURI.Host); blocked || err != nil { return nil, fmt.Errorf("GetRemoteInstance: domain %s is blocked", remoteInstanceURI.Host) } - transport, err := d.transportController.NewTransportForUsername(username) + transport, err := d.transportController.NewTransportForUsername(ctx, username) if err != nil { return nil, fmt.Errorf("transport err: %s", err) } diff --git a/internal/federation/dereferencing/status.go b/internal/federation/dereferencing/status.go index 68693c021..93ead6523 100644 --- a/internal/federation/dereferencing/status.go +++ b/internal/federation/dereferencing/status.go @@ -39,12 +39,12 @@ import ( // // EnrichRemoteStatus is mostly useful for calling after a status has been initially created by // the federatingDB's Create function, but additional dereferencing is needed on it. -func (d *deref) EnrichRemoteStatus(username string, status *gtsmodel.Status) (*gtsmodel.Status, error) { - if err := d.populateStatusFields(status, username); err != nil { +func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) { + if err := d.populateStatusFields(ctx, status, username); err != nil { return nil, err } - if err := d.db.UpdateByID(status.ID, status); err != nil { + if err := d.db.UpdateByID(ctx, status.ID, status); err != nil { return nil, fmt.Errorf("EnrichRemoteStatus: error updating status: %s", err) } @@ -62,11 +62,11 @@ func (d *deref) EnrichRemoteStatus(username string, status *gtsmodel.Status) (*g // If a dereference was performed, then the function also returns the ap.Statusable representation for further processing. // // SIDE EFFECTS: remote status will be stored in the database, and the remote status owner will also be stored. -func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) { +func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) { new := true // check if we already have the status in our db - maybeStatus, err := d.db.GetStatusByURI(remoteStatusID.String()) + maybeStatus, err := d.db.GetStatusByURI(ctx, remoteStatusID.String()) if err == nil { // we've seen this status before so it's not new new = false @@ -77,7 +77,7 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres } } - statusable, err := d.dereferenceStatusable(username, remoteStatusID) + statusable, err := d.dereferenceStatusable(ctx, username, remoteStatusID) if err != nil { return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error dereferencing statusable: %s", err) } @@ -88,12 +88,12 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres } // do this so we know we have the remote account of the status in the db - _, _, err = d.GetRemoteAccount(username, accountURI, false) + _, _, err = d.GetRemoteAccount(ctx, username, accountURI, false) if err != nil { return nil, statusable, new, fmt.Errorf("GetRemoteStatus: couldn't derive status author: %s", err) } - gtsStatus, err := d.typeConverter.ASStatusToStatus(statusable) + gtsStatus, err := d.typeConverter.ASStatusToStatus(ctx, statusable) if err != nil { return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error converting statusable to status: %s", err) } @@ -105,21 +105,21 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres } gtsStatus.ID = ulid - if err := d.populateStatusFields(gtsStatus, username); err != nil { + if err := d.populateStatusFields(ctx, gtsStatus, username); err != nil { return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err) } - if err := d.db.PutStatus(gtsStatus); err != nil { + if err := d.db.PutStatus(ctx, gtsStatus); err != nil { return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error putting new status: %s", err) } } else { gtsStatus.ID = maybeStatus.ID - if err := d.populateStatusFields(gtsStatus, username); err != nil { + if err := d.populateStatusFields(ctx, gtsStatus, username); err != nil { return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err) } - if err := d.db.UpdateByID(gtsStatus.ID, gtsStatus); err != nil { + if err := d.db.UpdateByID(ctx, gtsStatus.ID, gtsStatus); err != nil { return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error updating status: %s", err) } } @@ -127,12 +127,12 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres return gtsStatus, statusable, new, nil } -func (d *deref) dereferenceStatusable(username string, remoteStatusID *url.URL) (ap.Statusable, error) { - if blocked, err := d.blockedDomain(remoteStatusID.Host); blocked || err != nil { +func (d *deref) dereferenceStatusable(ctx context.Context, username string, remoteStatusID *url.URL) (ap.Statusable, error) { + if blocked, err := d.db.IsDomainBlocked(ctx, remoteStatusID.Host); blocked || err != nil { return nil, fmt.Errorf("DereferenceStatusable: domain %s is blocked", remoteStatusID.Host) } - transport, err := d.transportController.NewTransportForUsername(username) + transport, err := d.transportController.NewTransportForUsername(ctx, username) if err != nil { return nil, fmt.Errorf("DereferenceStatusable: transport err: %s", err) } @@ -236,7 +236,7 @@ func (d *deref) dereferenceStatusable(username string, remoteStatusID *url.URL) // This function will deference all of the above, insert them in the database as necessary, // and attach them to the status. The status itself will not be added to the database yet, // that's up the caller to do. -func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername string) error { +func (d *deref) populateStatusFields(ctx context.Context, status *gtsmodel.Status, requestingUsername string) error { l := d.log.WithFields(logrus.Fields{ "func": "dereferenceStatusFields", "status": fmt.Sprintf("%+v", status), @@ -248,12 +248,12 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername if err != nil { return fmt.Errorf("DereferenceStatusFields: couldn't parse status URI %s: %s", status.URI, err) } - if blocked, err := d.blockedDomain(statusURI.Host); blocked || err != nil { + if blocked, err := d.db.IsDomainBlocked(ctx, statusURI.Host); blocked || err != nil { return fmt.Errorf("DereferenceStatusFields: domain %s is blocked", statusURI.Host) } // we can continue -- create a new transport here because we'll probably need it - t, err := d.transportController.NewTransportForUsername(requestingUsername) + t, err := d.transportController.NewTransportForUsername(ctx, requestingUsername) if err != nil { return fmt.Errorf("error creating transport: %s", err) } @@ -281,7 +281,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername // it might have been processed elsewhere so check first if it's already in the database or not maybeAttachment := >smodel.MediaAttachment{} - err := d.db.GetWhere([]db.Where{{Key: "remote_url", Value: a.RemoteURL}}, maybeAttachment) + err := d.db.GetWhere(ctx, []db.Where{{Key: "remote_url", Value: a.RemoteURL}}, maybeAttachment) if err == nil { // we already have it in the db, dereferenced, no need to do it again l.Tracef("attachment already exists with id %s", maybeAttachment.ID) @@ -294,7 +294,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername } // it just doesn't exist yet so carry on l.Debug("attachment doesn't exist yet, calling ProcessRemoteAttachment", a) - deferencedAttachment, err := d.mediaHandler.ProcessRemoteAttachment(t, a, status.AccountID) + deferencedAttachment, err := d.mediaHandler.ProcessRemoteAttachment(ctx, t, a, status.AccountID) if err != nil { l.Errorf("error dereferencing status attachment: %s", err) continue @@ -302,7 +302,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername l.Debugf("dereferenced attachment: %+v", deferencedAttachment) deferencedAttachment.StatusID = status.ID deferencedAttachment.Description = a.Description - if err := d.db.Put(deferencedAttachment); err != nil { + if err := d.db.Put(ctx, deferencedAttachment); err != nil { return fmt.Errorf("error inserting dereferenced attachment with remote url %s: %s", a.RemoteURL, err) } attachmentIDs = append(attachmentIDs, deferencedAttachment.ID) @@ -338,9 +338,9 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername } var targetAccount *gtsmodel.Account - if a, err := d.db.GetAccountByURL(targetAccountURI.String()); err == nil { + if a, err := d.db.GetAccountByURL(ctx, targetAccountURI.String()); err == nil { targetAccount = a - } else if a, _, err := d.GetRemoteAccount(requestingUsername, targetAccountURI, false); err == nil { + } else if a, _, err := d.GetRemoteAccount(ctx, requestingUsername, targetAccountURI, false); err == nil { targetAccount = a } else { // we can't find the target account so bail @@ -369,7 +369,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername TargetAccountURL: targetAccount.URL, } - if err := d.db.Put(m); err != nil { + if err := d.db.Put(ctx, m); err != nil { return fmt.Errorf("error creating mention: %s", err) } mentionIDs = append(mentionIDs, m.ID) @@ -382,13 +382,13 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername if err != nil { return err } - if replyToStatus, err := d.db.GetStatusByURI(status.InReplyToURI); err == nil { + if replyToStatus, err := d.db.GetStatusByURI(ctx, status.InReplyToURI); err == nil { // we have the status status.InReplyToID = replyToStatus.ID status.InReplyTo = replyToStatus status.InReplyToAccountID = replyToStatus.AccountID status.InReplyToAccount = replyToStatus.Account - } else if replyToStatus, _, _, err := d.GetRemoteStatus(requestingUsername, statusURI, false); err == nil { + } else if replyToStatus, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, statusURI, false); err == nil { // we got the status status.InReplyToID = replyToStatus.ID status.InReplyTo = replyToStatus diff --git a/internal/federation/dereferencing/thread.go b/internal/federation/dereferencing/thread.go index 2a407f923..f9dd9aa09 100644 --- a/internal/federation/dereferencing/thread.go +++ b/internal/federation/dereferencing/thread.go @@ -19,12 +19,12 @@ package dereferencing import ( + "context" "fmt" "net/url" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/ap" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/util" ) @@ -34,7 +34,7 @@ import ( // This process involves working up and down the chain of replies, and parsing through the collections of IDs // presented by remote instances as part of their replies collections, and will likely involve making several calls to // multiple different hosts. -func (d *deref) DereferenceThread(username string, statusIRI *url.URL) error { +func (d *deref) DereferenceThread(ctx context.Context, username string, statusIRI *url.URL) error { l := d.log.WithFields(logrus.Fields{ "func": "DereferenceThread", "username": username, @@ -49,18 +49,18 @@ func (d *deref) DereferenceThread(username string, statusIRI *url.URL) error { } // first make sure we have this status in our db - _, statusable, _, err := d.GetRemoteStatus(username, statusIRI, true) + _, statusable, _, err := d.GetRemoteStatus(ctx, username, statusIRI, true) if err != nil { return fmt.Errorf("DereferenceThread: error getting status with id %s: %s", statusIRI.String(), err) } // first iterate up through ancestors, dereferencing if necessary as we go - if err := d.iterateAncestors(username, *statusIRI); err != nil { + if err := d.iterateAncestors(ctx, username, *statusIRI); err != nil { return fmt.Errorf("error iterating ancestors of status %s: %s", statusIRI.String(), err) } // now iterate down through descendants, again dereferencing as we go - if err := d.iterateDescendants(username, *statusIRI, statusable); err != nil { + if err := d.iterateDescendants(ctx, username, *statusIRI, statusable); err != nil { return fmt.Errorf("error iterating descendants of status %s: %s", statusIRI.String(), err) } @@ -68,7 +68,7 @@ func (d *deref) DereferenceThread(username string, statusIRI *url.URL) error { } // iterateAncestors has the goal of reaching the oldest ancestor of a given status, and stashing all statuses along the way. -func (d *deref) iterateAncestors(username string, statusIRI url.URL) error { +func (d *deref) iterateAncestors(ctx context.Context, username string, statusIRI url.URL) error { l := d.log.WithFields(logrus.Fields{ "func": "iterateAncestors", "username": username, @@ -86,8 +86,8 @@ func (d *deref) iterateAncestors(username string, statusIRI url.URL) error { return err } - status := >smodel.Status{} - if err := d.db.GetByID(id, status); err != nil { + status, err := d.db.GetStatusByID(ctx, id) + if err != nil { return err } @@ -99,12 +99,12 @@ func (d *deref) iterateAncestors(username string, statusIRI url.URL) error { if err != nil { return err } - return d.iterateAncestors(username, *nextIRI) + return d.iterateAncestors(ctx, username, *nextIRI) } // If we reach here, we're looking at a remote status -- make sure we have it in our db by calling GetRemoteStatus // We call it with refresh to true because we want the statusable representation to parse inReplyTo from. - status, statusable, _, err := d.GetRemoteStatus(username, &statusIRI, true) + status, statusable, _, err := d.GetRemoteStatus(ctx, username, &statusIRI, true) if err != nil { l.Debugf("error getting remote status: %s", err) return nil @@ -117,22 +117,22 @@ func (d *deref) iterateAncestors(username string, statusIRI url.URL) error { } // get the ancestor status into our database if we don't have it yet - if _, _, _, err := d.GetRemoteStatus(username, inReplyTo, false); err != nil { + if _, _, _, err := d.GetRemoteStatus(ctx, username, inReplyTo, false); err != nil { l.Debugf("error getting remote status: %s", err) return nil } // now enrich the current status, since we should have the ancestor in the db - if _, err := d.EnrichRemoteStatus(username, status); err != nil { + if _, err := d.EnrichRemoteStatus(ctx, username, status); err != nil { l.Debugf("error enriching remote status: %s", err) return nil } // now move up to the next ancestor - return d.iterateAncestors(username, *inReplyTo) + return d.iterateAncestors(ctx, username, *inReplyTo) } -func (d *deref) iterateDescendants(username string, statusIRI url.URL, statusable ap.Statusable) error { +func (d *deref) iterateDescendants(ctx context.Context, username string, statusIRI url.URL, statusable ap.Statusable) error { l := d.log.WithFields(logrus.Fields{ "func": "iterateDescendants", "username": username, @@ -182,7 +182,7 @@ func (d *deref) iterateDescendants(username string, statusIRI url.URL, statusabl pageLoop: for { l.Debugf("dereferencing page %s", currentPageIRI) - nextPage, err := d.DereferenceCollectionPage(username, currentPageIRI) + nextPage, err := d.DereferenceCollectionPage(ctx, username, currentPageIRI) if err != nil { return nil } @@ -226,10 +226,10 @@ pageLoop: foundReplies = foundReplies + 1 // get the remote statusable and put it in the db - _, statusable, new, err := d.GetRemoteStatus(username, itemURI, false) + _, statusable, new, err := d.GetRemoteStatus(ctx, username, itemURI, false) if new && err == nil && statusable != nil { // now iterate descendants of *that* status - if err := d.iterateDescendants(username, *itemURI, statusable); err != nil { + if err := d.iterateDescendants(ctx, username, *itemURI, statusable); err != nil { continue } } diff --git a/internal/federation/federatingdb/accept.go b/internal/federation/federatingdb/accept.go index 91d9df86f..0b14e8a6a 100644 --- a/internal/federation/federatingdb/accept.go +++ b/internal/federation/federatingdb/accept.go @@ -86,7 +86,7 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA if util.IsFollowPath(acceptedObjectIRI) { // ACCEPT FOLLOW gtsFollowRequest := >smodel.FollowRequest{} - if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: acceptedObjectIRI.String()}}, gtsFollowRequest); err != nil { + if err := f.db.GetWhere(ctx, []db.Where{{Key: "uri", Value: acceptedObjectIRI.String()}}, gtsFollowRequest); err != nil { return fmt.Errorf("ACCEPT: couldn't get follow request with id %s from the database: %s", acceptedObjectIRI.String(), err) } @@ -94,7 +94,7 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA if gtsFollowRequest.AccountID != targetAcct.ID { return errors.New("ACCEPT: follow object account and inbox account were not the same") } - follow, err := f.db.AcceptFollowRequest(gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID) + follow, err := f.db.AcceptFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID) if err != nil { return err } @@ -123,7 +123,7 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA return errors.New("ACCEPT: couldn't parse follow into vocab.ActivityStreamsFollow") } // convert the follow to something we can understand - gtsFollow, err := f.typeConverter.ASFollowToFollow(asFollow) + gtsFollow, err := f.typeConverter.ASFollowToFollow(ctx, asFollow) if err != nil { return fmt.Errorf("ACCEPT: error converting asfollow to gtsfollow: %s", err) } @@ -131,7 +131,7 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA if gtsFollow.AccountID != targetAcct.ID { return errors.New("ACCEPT: follow object account and inbox account were not the same") } - follow, err := f.db.AcceptFollowRequest(gtsFollow.AccountID, gtsFollow.TargetAccountID) + follow, err := f.db.AcceptFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID) if err != nil { return err } diff --git a/internal/federation/federatingdb/announce.go b/internal/federation/federatingdb/announce.go index 981eaf1ef..5cd34285e 100644 --- a/internal/federation/federatingdb/announce.go +++ b/internal/federation/federatingdb/announce.go @@ -71,7 +71,7 @@ func (f *federatingDB) Announce(ctx context.Context, announce vocab.ActivityStre return nil } - boost, isNew, err := f.typeConverter.ASAnnounceToStatus(announce) + boost, isNew, err := f.typeConverter.ASAnnounceToStatus(ctx, announce) if err != nil { return fmt.Errorf("ANNOUNCE: error converting announce to boost: %s", err) } diff --git a/internal/federation/federatingdb/create.go b/internal/federation/federatingdb/create.go index fb4353cd4..8ea549c5a 100644 --- a/internal/federation/federatingdb/create.go +++ b/internal/federation/federatingdb/create.go @@ -100,7 +100,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error { case gtsmodel.ActivityStreamsNote: // CREATE A NOTE note := objectIter.GetActivityStreamsNote() - status, err := f.typeConverter.ASStatusToStatus(note) + status, err := f.typeConverter.ASStatusToStatus(ctx, note) if err != nil { return fmt.Errorf("CREATE: error converting note to status: %s", err) } @@ -112,7 +112,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error { } status.ID = statusID - if err := f.db.PutStatus(status); err != nil { + if err := f.db.PutStatus(ctx, status); err != nil { if err == db.ErrAlreadyExists { // the status already exists in the database, which means we've already handled everything else, // so we can just return nil here and be done with it. @@ -137,7 +137,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error { return errors.New("CREATE: could not convert type to follow") } - followRequest, err := f.typeConverter.ASFollowToFollowRequest(follow) + followRequest, err := f.typeConverter.ASFollowToFollowRequest(ctx, follow) if err != nil { return fmt.Errorf("CREATE: could not convert Follow to follow request: %s", err) } @@ -148,7 +148,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error { } followRequest.ID = newID - if err := f.db.Put(followRequest); err != nil { + if err := f.db.Put(ctx, followRequest); err != nil { return fmt.Errorf("CREATE: database error inserting follow request: %s", err) } @@ -165,7 +165,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error { return errors.New("CREATE: could not convert type to like") } - fave, err := f.typeConverter.ASLikeToFave(like) + fave, err := f.typeConverter.ASLikeToFave(ctx, like) if err != nil { return fmt.Errorf("CREATE: could not convert Like to fave: %s", err) } @@ -176,7 +176,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error { } fave.ID = newID - if err := f.db.Put(fave); err != nil { + if err := f.db.Put(ctx, fave); err != nil { return fmt.Errorf("CREATE: database error inserting fave: %s", err) } @@ -193,7 +193,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error { return errors.New("CREATE: could not convert type to block") } - block, err := f.typeConverter.ASBlockToBlock(blockable) + block, err := f.typeConverter.ASBlockToBlock(ctx, blockable) if err != nil { return fmt.Errorf("CREATE: could not convert Block to gts model block") } @@ -204,7 +204,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error { } block.ID = newID - if err := f.db.Put(block); err != nil { + if err := f.db.Put(ctx, block); err != nil { return fmt.Errorf("CREATE: database error inserting block: %s", err) } diff --git a/internal/federation/federatingdb/delete.go b/internal/federation/federatingdb/delete.go index ee9310789..11b818168 100644 --- a/internal/federation/federatingdb/delete.go +++ b/internal/federation/federatingdb/delete.go @@ -69,11 +69,11 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error { // in a delete we only get the URI, we can't know if we have a status or a profile or something else, // so we have to try a few different things... - s, err := f.db.GetStatusByURI(id.String()) + s, err := f.db.GetStatusByURI(ctx, id.String()) if err == nil { // it's a status l.Debugf("uri is for status with id: %s", s.ID) - if err := f.db.DeleteByID(s.ID, >smodel.Status{}); err != nil { + if err := f.db.DeleteByID(ctx, s.ID, >smodel.Status{}); err != nil { return fmt.Errorf("DELETE: err deleting status: %s", err) } fromFederatorChan <- gtsmodel.FromFederator{ @@ -84,11 +84,11 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error { } } - a, err := f.db.GetAccountByURI(id.String()) + a, err := f.db.GetAccountByURI(ctx, id.String()) if err == nil { // it's an account - l.Debugf("uri is for an account with id: %s", s.ID) - if err := f.db.DeleteByID(a.ID, >smodel.Account{}); err != nil { + l.Debugf("uri is for an account with id: %s", a.ID) + if err := f.db.DeleteByID(ctx, a.ID, >smodel.Account{}); err != nil { return fmt.Errorf("DELETE: err deleting account: %s", err) } fromFederatorChan <- gtsmodel.FromFederator{ diff --git a/internal/federation/federatingdb/followers.go b/internal/federation/federatingdb/followers.go index 241362fc1..c7f636a12 100644 --- a/internal/federation/federatingdb/followers.go +++ b/internal/federation/federatingdb/followers.go @@ -19,7 +19,7 @@ import ( // If modified, the library will then call Update. // // The library makes this call only after acquiring a lock first. -func (f *federatingDB) Followers(c context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { +func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { l := f.log.WithFields( logrus.Fields{ "func": "Followers", @@ -31,19 +31,19 @@ func (f *federatingDB) Followers(c context.Context, actorIRI *url.URL) (follower acct := >smodel.Account{} if util.IsUserPath(actorIRI) { - acct, err = f.db.GetAccountByURI(actorIRI.String()) + acct, err = f.db.GetAccountByURI(ctx, actorIRI.String()) if err != nil { return nil, fmt.Errorf("FOLLOWERS: db error getting account with uri %s: %s", actorIRI.String(), err) } } else if util.IsFollowersPath(actorIRI) { - if err := f.db.GetWhere([]db.Where{{Key: "followers_uri", Value: actorIRI.String()}}, acct); err != nil { + if err := f.db.GetWhere(ctx, []db.Where{{Key: "followers_uri", Value: actorIRI.String()}}, acct); err != nil { return nil, fmt.Errorf("FOLLOWERS: db error getting account with followers uri %s: %s", actorIRI.String(), err) } } else { return nil, fmt.Errorf("FOLLOWERS: could not parse actor IRI %s as users or followers path", actorIRI.String()) } - acctFollowers, err := f.db.GetAccountFollowedBy(acct.ID, false) + acctFollowers, err := f.db.GetAccountFollowedBy(ctx, acct.ID, false) if err != nil { return nil, fmt.Errorf("FOLLOWERS: db error getting followers for account id %s: %s", acct.ID, err) } @@ -51,13 +51,17 @@ func (f *federatingDB) Followers(c context.Context, actorIRI *url.URL) (follower followers = streams.NewActivityStreamsCollection() items := streams.NewActivityStreamsItemsProperty() for _, follow := range acctFollowers { - gtsFollower := >smodel.Account{} - if err := f.db.GetByID(follow.AccountID, gtsFollower); err != nil { - return nil, fmt.Errorf("FOLLOWERS: db error getting account id %s: %s", follow.AccountID, err) + if follow.Account == nil { + followAccount, err := f.db.GetAccountByID(ctx, follow.AccountID) + if err != nil { + return nil, fmt.Errorf("FOLLOWERS: db error getting account id %s: %s", follow.AccountID, err) + } + follow.Account = followAccount } - uri, err := url.Parse(gtsFollower.URI) + + uri, err := url.Parse(follow.Account.URI) if err != nil { - return nil, fmt.Errorf("FOLLOWERS: error parsing %s as url: %s", gtsFollower.URI, err) + return nil, fmt.Errorf("FOLLOWERS: error parsing %s as url: %s", follow.Account.URI, err) } items.AppendIRI(uri) } diff --git a/internal/federation/federatingdb/following.go b/internal/federation/federatingdb/following.go index 45785c671..9d5c0693c 100644 --- a/internal/federation/federatingdb/following.go +++ b/internal/federation/federatingdb/following.go @@ -18,7 +18,7 @@ import ( // If modified, the library will then call Update. // // The library makes this call only after acquiring a lock first. -func (f *federatingDB) Following(c context.Context, actorIRI *url.URL) (following vocab.ActivityStreamsCollection, err error) { +func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (following vocab.ActivityStreamsCollection, err error) { l := f.log.WithFields( logrus.Fields{ "func": "Following", @@ -34,7 +34,7 @@ func (f *federatingDB) Following(c context.Context, actorIRI *url.URL) (followin return nil, fmt.Errorf("FOLLOWING: error parsing user path: %s", err) } - a, err := f.db.GetLocalAccountByUsername(username) + a, err := f.db.GetLocalAccountByUsername(ctx, username) if err != nil { return nil, fmt.Errorf("FOLLOWING: db error getting account with uri %s: %s", actorIRI.String(), err) } @@ -46,7 +46,7 @@ func (f *federatingDB) Following(c context.Context, actorIRI *url.URL) (followin return nil, fmt.Errorf("FOLLOWING: error parsing following path: %s", err) } - a, err := f.db.GetLocalAccountByUsername(username) + a, err := f.db.GetLocalAccountByUsername(ctx, username) if err != nil { return nil, fmt.Errorf("FOLLOWING: db error getting account with following uri %s: %s", actorIRI.String(), err) } @@ -56,7 +56,7 @@ func (f *federatingDB) Following(c context.Context, actorIRI *url.URL) (followin return nil, fmt.Errorf("FOLLOWING: could not parse actor IRI %s as users or following path", actorIRI.String()) } - acctFollowing, err := f.db.GetAccountFollows(acct.ID) + acctFollowing, err := f.db.GetAccountFollows(ctx, acct.ID) if err != nil { return nil, fmt.Errorf("FOLLOWING: db error getting following for account id %s: %s", acct.ID, err) } @@ -64,13 +64,17 @@ func (f *federatingDB) Following(c context.Context, actorIRI *url.URL) (followin following = streams.NewActivityStreamsCollection() items := streams.NewActivityStreamsItemsProperty() for _, follow := range acctFollowing { - gtsFollowing := >smodel.Account{} - if err := f.db.GetByID(follow.AccountID, gtsFollowing); err != nil { - return nil, fmt.Errorf("FOLLOWING: db error getting account id %s: %s", follow.AccountID, err) + if follow.Account == nil { + followAccount, err := f.db.GetAccountByID(ctx, follow.AccountID) + if err != nil { + return nil, fmt.Errorf("FOLLOWING: db error getting account id %s: %s", follow.AccountID, err) + } + follow.Account = followAccount } - uri, err := url.Parse(gtsFollowing.URI) + + uri, err := url.Parse(follow.Account.URI) if err != nil { - return nil, fmt.Errorf("FOLLOWING: error parsing %s as url: %s", gtsFollowing.URI, err) + return nil, fmt.Errorf("FOLLOWING: error parsing %s as url: %s", follow.Account.URI, err) } items.AppendIRI(uri) } diff --git a/internal/federation/federatingdb/get.go b/internal/federation/federatingdb/get.go index 0265080f9..cc04dd851 100644 --- a/internal/federation/federatingdb/get.go +++ b/internal/federation/federatingdb/get.go @@ -33,7 +33,7 @@ import ( // Get returns the database entry for the specified id. // // The library makes this call only after acquiring a lock first. -func (f *federatingDB) Get(c context.Context, id *url.URL) (value vocab.Type, err error) { +func (f *federatingDB) Get(ctx context.Context, id *url.URL) (value vocab.Type, err error) { l := f.log.WithFields( logrus.Fields{ "func": "Get", @@ -43,17 +43,17 @@ func (f *federatingDB) Get(c context.Context, id *url.URL) (value vocab.Type, er l.Debug("entering GET function") if util.IsUserPath(id) { - acct, err := f.db.GetAccountByURI(id.String()) + acct, err := f.db.GetAccountByURI(ctx, id.String()) if err != nil { return nil, err } l.Debug("is user path! returning account") - return f.typeConverter.AccountToAS(acct) + return f.typeConverter.AccountToAS(ctx, acct) } if util.IsFollowersPath(id) { acct := >smodel.Account{} - if err := f.db.GetWhere([]db.Where{{Key: "followers_uri", Value: id.String()}}, acct); err != nil { + if err := f.db.GetWhere(ctx, []db.Where{{Key: "followers_uri", Value: id.String()}}, acct); err != nil { return nil, err } @@ -62,12 +62,12 @@ func (f *federatingDB) Get(c context.Context, id *url.URL) (value vocab.Type, er return nil, err } - return f.Followers(c, followersURI) + return f.Followers(ctx, followersURI) } if util.IsFollowingPath(id) { acct := >smodel.Account{} - if err := f.db.GetWhere([]db.Where{{Key: "following_uri", Value: id.String()}}, acct); err != nil { + if err := f.db.GetWhere(ctx, []db.Where{{Key: "following_uri", Value: id.String()}}, acct); err != nil { return nil, err } @@ -76,7 +76,7 @@ func (f *federatingDB) Get(c context.Context, id *url.URL) (value vocab.Type, er return nil, err } - return f.Following(c, followingURI) + return f.Following(ctx, followingURI) } return nil, errors.New("could not get") diff --git a/internal/federation/federatingdb/outbox.go b/internal/federation/federatingdb/outbox.go index 849014432..81b90aae2 100644 --- a/internal/federation/federatingdb/outbox.go +++ b/internal/federation/federatingdb/outbox.go @@ -35,7 +35,7 @@ import ( // at the specified IRI, for prepending new items. // // The library makes this call only after acquiring a lock first. -func (f *federatingDB) GetOutbox(c context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) { +func (f *federatingDB) GetOutbox(ctx context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) { l := f.log.WithFields( logrus.Fields{ "func": "GetOutbox", @@ -51,7 +51,7 @@ func (f *federatingDB) GetOutbox(c context.Context, outboxIRI *url.URL) (inbox v // database entries. Separate calls to Create will do that. // // The library makes this call only after acquiring a lock first. -func (f *federatingDB) SetOutbox(c context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error { +func (f *federatingDB) SetOutbox(ctx context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error { l := f.log.WithFields( logrus.Fields{ "func": "SetOutbox", @@ -66,7 +66,7 @@ func (f *federatingDB) SetOutbox(c context.Context, outbox vocab.ActivityStreams // actor's inbox IRI. // // The library makes this call only after acquiring a lock first. -func (f *federatingDB) OutboxForInbox(c context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) { +func (f *federatingDB) OutboxForInbox(ctx context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) { l := f.log.WithFields( logrus.Fields{ "func": "OutboxForInbox", @@ -79,7 +79,7 @@ func (f *federatingDB) OutboxForInbox(c context.Context, inboxIRI *url.URL) (out return nil, fmt.Errorf("%s is not an inbox URI", inboxIRI.String()) } acct := >smodel.Account{} - if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil { + if err := f.db.GetWhere(ctx, []db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil { if err == db.ErrNoEntries { return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String()) } diff --git a/internal/federation/federatingdb/owns.go b/internal/federation/federatingdb/owns.go index 0a65397ff..1c1f2512d 100644 --- a/internal/federation/federatingdb/owns.go +++ b/internal/federation/federatingdb/owns.go @@ -32,7 +32,7 @@ import ( // Owns returns true if the IRI belongs to this instance, and if // the database has an entry for the IRI. // The library makes this call only after acquiring a lock first. -func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) { +func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) { l := f.log.WithFields( logrus.Fields{ "func": "Owns", @@ -54,7 +54,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) { if err != nil { return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) } - status, err := f.db.GetStatusByURI(uid) + status, err := f.db.GetStatusByURI(ctx, uid) if err != nil { if err == db.ErrNoEntries { // there are no entries for this status @@ -71,7 +71,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) { if err != nil { return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) } - if _, err := f.db.GetLocalAccountByUsername(username); err != nil { + if _, err := f.db.GetLocalAccountByUsername(ctx, username); err != nil { if err == db.ErrNoEntries { // there are no entries for this username return false, nil @@ -88,7 +88,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) { if err != nil { return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) } - if _, err := f.db.GetLocalAccountByUsername(username); err != nil { + if _, err := f.db.GetLocalAccountByUsername(ctx, username); err != nil { if err == db.ErrNoEntries { // there are no entries for this username return false, nil @@ -105,7 +105,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) { if err != nil { return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) } - if _, err := f.db.GetLocalAccountByUsername(username); err != nil { + if _, err := f.db.GetLocalAccountByUsername(ctx, username); err != nil { if err == db.ErrNoEntries { // there are no entries for this username return false, nil @@ -122,7 +122,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) { if err != nil { return false, fmt.Errorf("error parsing like path for url %s: %s", id.String(), err) } - if _, err := f.db.GetLocalAccountByUsername(username); err != nil { + if _, err := f.db.GetLocalAccountByUsername(ctx, username); err != nil { if err == db.ErrNoEntries { // there are no entries for this username return false, nil @@ -130,7 +130,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) { // an actual error happened return false, fmt.Errorf("database error fetching account with username %s: %s", username, err) } - if err := f.db.GetByID(likeID, >smodel.StatusFave{}); err != nil { + if err := f.db.GetByID(ctx, likeID, >smodel.StatusFave{}); err != nil { if err == db.ErrNoEntries { // there are no entries return false, nil @@ -147,7 +147,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) { if err != nil { return false, fmt.Errorf("error parsing block path for url %s: %s", id.String(), err) } - if _, err := f.db.GetLocalAccountByUsername(username); err != nil { + if _, err := f.db.GetLocalAccountByUsername(ctx, username); err != nil { if err == db.ErrNoEntries { // there are no entries for this username return false, nil @@ -155,7 +155,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) { // an actual error happened return false, fmt.Errorf("database error fetching account with username %s: %s", username, err) } - if err := f.db.GetByID(blockID, >smodel.Block{}); err != nil { + if err := f.db.GetByID(ctx, blockID, >smodel.Block{}); err != nil { if err == db.ErrNoEntries { // there are no entries return false, nil diff --git a/internal/federation/federatingdb/undo.go b/internal/federation/federatingdb/undo.go index c527833b4..0fa38114d 100644 --- a/internal/federation/federatingdb/undo.go +++ b/internal/federation/federatingdb/undo.go @@ -83,7 +83,7 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo) return errors.New("UNDO: follow actor and activity actor not the same") } // convert the follow to something we can understand - gtsFollow, err := f.typeConverter.ASFollowToFollow(ASFollow) + gtsFollow, err := f.typeConverter.ASFollowToFollow(ctx, ASFollow) if err != nil { return fmt.Errorf("UNDO: error converting asfollow to gtsfollow: %s", err) } @@ -92,11 +92,11 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo) return errors.New("UNDO: follow object account and inbox account were not the same") } // delete any existing FOLLOW - if err := f.db.DeleteWhere([]db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.Follow{}); err != nil { + if err := f.db.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.Follow{}); err != nil { return fmt.Errorf("UNDO: db error removing follow: %s", err) } // delete any existing FOLLOW REQUEST - if err := f.db.DeleteWhere([]db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.FollowRequest{}); err != nil { + if err := f.db.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.FollowRequest{}); err != nil { return fmt.Errorf("UNDO: db error removing follow request: %s", err) } l.Debug("follow undone") @@ -116,7 +116,7 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo) return errors.New("UNDO: block actor and activity actor not the same") } // convert the block to something we can understand - gtsBlock, err := f.typeConverter.ASBlockToBlock(ASBlock) + gtsBlock, err := f.typeConverter.ASBlockToBlock(ctx, ASBlock) if err != nil { return fmt.Errorf("UNDO: error converting asblock to gtsblock: %s", err) } @@ -125,7 +125,7 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo) return errors.New("UNDO: block object account and inbox account were not the same") } // delete any existing BLOCK - if err := f.db.DeleteWhere([]db.Where{{Key: "uri", Value: gtsBlock.URI}}, >smodel.Block{}); err != nil { + if err := f.db.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsBlock.URI}}, >smodel.Block{}); err != nil { return fmt.Errorf("UNDO: db error removing block: %s", err) } l.Debug("block undone") diff --git a/internal/federation/federatingdb/update.go b/internal/federation/federatingdb/update.go index 88ffc23b4..e9dfe5315 100644 --- a/internal/federation/federatingdb/update.go +++ b/internal/federation/federatingdb/update.go @@ -136,7 +136,7 @@ func (f *federatingDB) Update(ctx context.Context, asType vocab.Type) error { accountable = i } - updatedAcct, err := f.typeConverter.ASRepresentationToAccount(accountable, true) + updatedAcct, err := f.typeConverter.ASRepresentationToAccount(ctx, accountable, true) if err != nil { return fmt.Errorf("UPDATE: error converting to account: %s", err) } @@ -152,7 +152,8 @@ func (f *federatingDB) Update(ctx context.Context, asType vocab.Type) error { } updatedAcct.ID = requestingAcct.ID // set this here so the db will update properly instead of trying to PUT this and getting constraint issues - if err := f.db.UpdateByID(requestingAcct.ID, updatedAcct); err != nil { + updatedAcct, err = f.db.UpdateAccount(ctx, updatedAcct) + if err != nil { return fmt.Errorf("UPDATE: database error inserting updated account: %s", err) } diff --git a/internal/federation/federatingdb/util.go b/internal/federation/federatingdb/util.go index eac70d85c..b5befc613 100644 --- a/internal/federation/federatingdb/util.go +++ b/internal/federation/federatingdb/util.go @@ -60,7 +60,7 @@ func sameActor(activityActor vocab.ActivityStreamsActorProperty, followActor voc // // The go-fed library will handle setting the 'id' property on the // activity or object provided with the value returned. -func (f *federatingDB) NewID(c context.Context, t vocab.Type) (idURL *url.URL, err error) { +func (f *federatingDB) NewID(ctx context.Context, t vocab.Type) (idURL *url.URL, err error) { l := f.log.WithFields( logrus.Fields{ "func": "NewID", @@ -98,7 +98,7 @@ func (f *federatingDB) NewID(c context.Context, t vocab.Type) (idURL *url.URL, e // take the IRI of the first actor we can find (there should only be one) if iter.IsIRI() { // if there's an error here, just use the fallback behavior -- we don't need to return an error here - if actorAccount, err := f.db.GetAccountByURI(iter.GetIRI().String()); err == nil { + if actorAccount, err := f.db.GetAccountByURI(ctx, iter.GetIRI().String()); err == nil { newID, err := id.NewRandomULID() if err != nil { return nil, err @@ -199,7 +199,7 @@ func (f *federatingDB) NewID(c context.Context, t vocab.Type) (idURL *url.URL, e // ActorForOutbox fetches the actor's IRI for the given outbox IRI. // // The library makes this call only after acquiring a lock first. -func (f *federatingDB) ActorForOutbox(c context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) { +func (f *federatingDB) ActorForOutbox(ctx context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) { l := f.log.WithFields( logrus.Fields{ "func": "ActorForOutbox", @@ -212,7 +212,7 @@ func (f *federatingDB) ActorForOutbox(c context.Context, outboxIRI *url.URL) (ac return nil, fmt.Errorf("%s is not an outbox URI", outboxIRI.String()) } acct := >smodel.Account{} - if err := f.db.GetWhere([]db.Where{{Key: "outbox_uri", Value: outboxIRI.String()}}, acct); err != nil { + if err := f.db.GetWhere(ctx, []db.Where{{Key: "outbox_uri", Value: outboxIRI.String()}}, acct); err != nil { if err == db.ErrNoEntries { return nil, fmt.Errorf("no actor found that corresponds to outbox %s", outboxIRI.String()) } @@ -224,7 +224,7 @@ func (f *federatingDB) ActorForOutbox(c context.Context, outboxIRI *url.URL) (ac // ActorForInbox fetches the actor's IRI for the given outbox IRI. // // The library makes this call only after acquiring a lock first. -func (f *federatingDB) ActorForInbox(c context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) { +func (f *federatingDB) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) { l := f.log.WithFields( logrus.Fields{ "func": "ActorForInbox", @@ -237,7 +237,7 @@ func (f *federatingDB) ActorForInbox(c context.Context, inboxIRI *url.URL) (acto return nil, fmt.Errorf("%s is not an inbox URI", inboxIRI.String()) } acct := >smodel.Account{} - if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil { + if err := f.db.GetWhere(ctx, []db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil { if err == db.ErrNoEntries { return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String()) } diff --git a/internal/federation/federatingprotocol.go b/internal/federation/federatingprotocol.go index 5da68afd3..7f8958111 100644 --- a/internal/federation/federatingprotocol.go +++ b/internal/federation/federatingprotocol.go @@ -113,7 +113,7 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr return nil, false, errors.New("username was empty") } - requestedAccount, err := f.db.GetLocalAccountByUsername(username) + requestedAccount, err := f.db.GetLocalAccountByUsername(ctx, username) if err != nil { return nil, false, fmt.Errorf("could not fetch requested account with username %s: %s", username, err) } @@ -131,14 +131,14 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr // authentication has passed, so add an instance entry for this instance if it hasn't been done already i := >smodel.Instance{} - if err := f.db.GetWhere([]db.Where{{Key: "domain", Value: publicKeyOwnerURI.Host, CaseInsensitive: true}}, i); err != nil { + if err := f.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: publicKeyOwnerURI.Host, CaseInsensitive: true}}, i); err != nil { if err != db.ErrNoEntries { // there's been an actual error return ctx, false, fmt.Errorf("error getting requesting account with public key id %s: %s", publicKeyOwnerURI.String(), err) } // we don't have an entry for this instance yet so dereference it - i, err = f.GetRemoteInstance(username, &url.URL{ + i, err = f.GetRemoteInstance(ctx, username, &url.URL{ Scheme: publicKeyOwnerURI.Scheme, Host: publicKeyOwnerURI.Host, }) @@ -147,12 +147,12 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr } // and put it in the db - if err := f.db.Put(i); err != nil { + if err := f.db.Put(ctx, i); err != nil { return nil, false, fmt.Errorf("error inserting newly dereferenced instance %s: %s", publicKeyOwnerURI.Host, err) } } - requestingAccount, _, err := f.GetRemoteAccount(username, publicKeyOwnerURI, false) + requestingAccount, _, err := f.GetRemoteAccount(ctx, username, publicKeyOwnerURI, false) if err != nil { return nil, false, fmt.Errorf("couldn't get remote account: %s", err) } @@ -189,7 +189,7 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er return false, errors.New("requested account not set on request context, so couldn't determine blocks") } - blocked, err := f.db.AreURIsBlocked(actorIRIs) + blocked, err := f.db.AreURIsBlocked(ctx, actorIRIs) if err != nil { return false, fmt.Errorf("error checking domain blocks: %s", err) } @@ -198,7 +198,7 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er } for _, uri := range actorIRIs { - requestingAccount, err := f.db.GetAccountByURI(uri.String()) + requestingAccount, err := f.db.GetAccountByURI(ctx, uri.String()) if err != nil { if err == db.ErrNoEntries { // we don't have an entry for this account so it's not blocked @@ -208,7 +208,7 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er return false, fmt.Errorf("error getting account with uri %s: %s", uri.String(), err) } - blocked, err = f.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true) + blocked, err = f.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) if err != nil { return false, fmt.Errorf("error checking account block: %s", err) } diff --git a/internal/federation/federator.go b/internal/federation/federator.go index 1b5f5441a..5eddcbb99 100644 --- a/internal/federation/federator.go +++ b/internal/federation/federator.go @@ -54,21 +54,21 @@ type Federator interface { // FingerRemoteAccount performs a webfinger lookup for a remote account, using the .well-known path. It will return the ActivityPub URI for that // account, or an error if it doesn't exist or can't be retrieved. - FingerRemoteAccount(requestingUsername string, targetUsername string, targetDomain string) (*url.URL, error) + FingerRemoteAccount(ctx context.Context, requestingUsername string, targetUsername string, targetDomain string) (*url.URL, error) - DereferenceRemoteThread(username string, statusURI *url.URL) error - DereferenceAnnounce(announce *gtsmodel.Status, requestingUsername string) error + DereferenceRemoteThread(ctx context.Context, username string, statusURI *url.URL) error + DereferenceAnnounce(ctx context.Context, announce *gtsmodel.Status, requestingUsername string) error - GetRemoteAccount(username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) - EnrichRemoteAccount(username string, account *gtsmodel.Account) (*gtsmodel.Account, error) + GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) + EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) - GetRemoteStatus(username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) - EnrichRemoteStatus(username string, status *gtsmodel.Status) (*gtsmodel.Status, error) + GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) + EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) - GetRemoteInstance(username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) + GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) // Handshaking returns true if the given username is currently in the process of dereferencing the remoteAccountID. - Handshaking(username string, remoteAccountID *url.URL) bool + Handshaking(ctx context.Context, username string, remoteAccountID *url.URL) bool pub.CommonBehavior pub.FederatingProtocol } diff --git a/internal/federation/finger.go b/internal/federation/finger.go index a5a4fa0e7..5cdd4c04d 100644 --- a/internal/federation/finger.go +++ b/internal/federation/finger.go @@ -29,12 +29,12 @@ import ( apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" ) -func (f *federator) FingerRemoteAccount(requestingUsername string, targetUsername string, targetDomain string) (*url.URL, error) { - if blocked, err := f.db.IsDomainBlocked(targetDomain); blocked || err != nil { +func (f *federator) FingerRemoteAccount(ctx context.Context, requestingUsername string, targetUsername string, targetDomain string) (*url.URL, error) { + if blocked, err := f.db.IsDomainBlocked(ctx, targetDomain); blocked || err != nil { return nil, fmt.Errorf("FingerRemoteAccount: domain %s is blocked", targetDomain) } - t, err := f.transportController.NewTransportForUsername(requestingUsername) + t, err := f.transportController.NewTransportForUsername(ctx, requestingUsername) if err != nil { return nil, fmt.Errorf("FingerRemoteAccount: error getting transport for username %s while dereferencing @%s@%s: %s", requestingUsername, targetUsername, targetDomain, err) } diff --git a/internal/federation/handshake.go b/internal/federation/handshake.go index 0671e78a9..b973680b3 100644 --- a/internal/federation/handshake.go +++ b/internal/federation/handshake.go @@ -18,8 +18,11 @@ package federation -import "net/url" +import ( + "context" + "net/url" +) -func (f *federator) Handshaking(username string, remoteAccountID *url.URL) bool { - return f.dereferencer.Handshaking(username, remoteAccountID) +func (f *federator) Handshaking(ctx context.Context, username string, remoteAccountID *url.URL) bool { + return f.dereferencer.Handshaking(ctx, username, remoteAccountID) } diff --git a/internal/federation/transport.go b/internal/federation/transport.go index 20aee964b..9e2e38e19 100644 --- a/internal/federation/transport.go +++ b/internal/federation/transport.go @@ -68,5 +68,5 @@ func (f *federator) NewTransport(ctx context.Context, actorBoxIRI *url.URL, gofe return nil, fmt.Errorf("id %s was neither an inbox path nor an outbox path", actorBoxIRI.String()) } - return f.transportController.NewTransportForUsername(username) + return f.transportController.NewTransportForUsername(ctx, username) } diff --git a/internal/gtsmodel/account.go b/internal/gtsmodel/account.go index 9673dcb2c..98d2dcfc9 100644 --- a/internal/gtsmodel/account.go +++ b/internal/gtsmodel/account.go @@ -18,8 +18,8 @@ // Package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database. // These types should never be serialized and/or sent out via public APIs, as they contain sensitive information. -// The annotation used on these structs is for handling them via the go-pg ORM (hence why they're in this db subdir). -// See here for more info on go-pg model annotations: https://pg.uptrace.dev/models/ +// The annotation used on these structs is for handling them via the bun-db ORM. +// See here for more info on bun model annotations: https://bun.uptrace.dev/guide/models.html package gtsmodel import ( @@ -34,24 +34,24 @@ type Account struct { */ // id of this account in the local database - ID string `pg:"type:CHAR(26),pk,notnull,unique"` + ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // Username of the account, should just be a string of [a-z0-9_]. Can be added to domain to create the full username in the form ``[username]@[domain]`` eg., ``user_96@example.org`` - Username string `pg:",notnull,unique:userdomain"` // username and domain should be unique *with* each other + Username string `bun:",notnull,unique:userdomain,nullzero"` // username and domain should be unique *with* each other // Domain of the account, will be null if this is a local account, otherwise something like ``example.org`` or ``mastodon.social``. Should be unique with username. - Domain string `pg:",unique:userdomain"` // username and domain should be unique *with* each other + Domain string `bun:",unique:userdomain,nullzero"` // username and domain should be unique *with* each other /* ACCOUNT METADATA */ // ID of the avatar as a media attachment - AvatarMediaAttachmentID string `pg:"type:CHAR(26)"` - AvatarMediaAttachment *MediaAttachment `pg:"rel:has-one"` + AvatarMediaAttachmentID string `bun:"type:CHAR(26),nullzero"` + AvatarMediaAttachment *MediaAttachment `bun:"rel:belongs-to"` // For a non-local account, where can the header be fetched? AvatarRemoteURL string // ID of the header as a media attachment - HeaderMediaAttachmentID string `pg:"type:CHAR(26)"` - HeaderMediaAttachment *MediaAttachment `pg:"rel:has-one"` + HeaderMediaAttachmentID string `bun:"type:CHAR(26),nullzero"` + HeaderMediaAttachment *MediaAttachment `bun:"rel:belongs-to"` // For a non-local account, where can the header be fetched? HeaderRemoteURL string // DisplayName for this account. Can be empty, then just the Username will be used for display purposes. @@ -63,11 +63,11 @@ type Account struct { // Is this a memorial account, ie., has the user passed away? Memorial bool // This account has moved this account id in the database - MovedToAccountID string `pg:"type:CHAR(26)"` + MovedToAccountID string `bun:"type:CHAR(26),nullzero"` // When was this account created? - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // When was this account last updated? - UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // Does this account identify itself as a bot? Bot bool // What reason was given for signing up when this account was created? @@ -78,36 +78,36 @@ type Account struct { */ // Does this account need an approval for new followers? - Locked bool `pg:",default:true,use_zero"` + Locked bool `bun:",default:true"` // Should this account be shown in the instance's profile directory? - Discoverable bool `pg:",default:false"` + Discoverable bool `bun:",default:false"` // Default post privacy for this account - Privacy Visibility `pg:",default:'public'"` + Privacy Visibility `bun:",default:'public'"` // Set posts from this account to sensitive by default? - Sensitive bool `pg:",default:false"` + Sensitive bool `bun:",default:false"` // What language does this account post in? - Language string `pg:",default:'en'"` + Language string `bun:",default:'en'"` /* ACTIVITYPUB THINGS */ // What is the activitypub URI for this account discovered by webfinger? - URI string `pg:",unique"` + URI string `bun:",unique,nullzero"` // At which URL can we see the user account in a web browser? - URL string `pg:",unique"` + URL string `bun:",unique,nullzero"` // Last time this account was located using the webfinger API. - LastWebfingeredAt time.Time `pg:"type:timestamp"` + LastWebfingeredAt time.Time `bun:",nullzero"` // Address of this account's activitypub inbox, for sending activity to - InboxURI string `pg:",unique"` + InboxURI string `bun:",unique,nullzero"` // Address of this account's activitypub outbox - OutboxURI string `pg:",unique"` + OutboxURI string `bun:",unique,nullzero"` // URI for getting the following list of this account - FollowingURI string `pg:",unique"` + FollowingURI string `bun:",unique,nullzero"` // URI for getting the followers list of this account - FollowersURI string `pg:",unique"` + FollowersURI string `bun:",unique,nullzero"` // URL for getting the featured collection list of this account - FeaturedCollectionURI string `pg:",unique"` + FeaturedCollectionURI string `bun:",unique,nullzero"` // What type of activitypub actor is this account? ActorType string // This account is associated with x account id @@ -129,15 +129,15 @@ type Account struct { */ // When was this account set to have all its media shown as sensitive? - SensitizedAt time.Time `pg:"type:timestamp"` + SensitizedAt time.Time `bun:",nullzero"` // When was this account silenced (eg., statuses only visible to followers, not public)? - SilencedAt time.Time `pg:"type:timestamp"` + SilencedAt time.Time `bun:",nullzero"` // When was this account suspended (eg., don't allow it to log in/post, don't accept media/posts from this account) - SuspendedAt time.Time `pg:"type:timestamp"` + SuspendedAt time.Time `bun:",nullzero"` // Should we hide this account's collections? HideCollections bool // id of the database entry that caused this account to become suspended -- can be an account ID or a domain block ID - SuspensionOrigin string `pg:"type:CHAR(26)"` + SuspensionOrigin string `bun:"type:CHAR(26),nullzero"` } // Field represents a key value field on an account, for things like pronouns, website, etc. @@ -146,5 +146,5 @@ type Account struct { type Field struct { Name string Value string - VerifiedAt time.Time `pg:"type:timestamp"` + VerifiedAt time.Time `bun:",nullzero"` } diff --git a/internal/gtsmodel/application.go b/internal/gtsmodel/application.go index 91287dff3..a6976eafd 100644 --- a/internal/gtsmodel/application.go +++ b/internal/gtsmodel/application.go @@ -22,7 +22,7 @@ package gtsmodel // It is used to authorize tokens etc, and is associated with an oauth client id in the database. type Application struct { // id of this application in the db - ID string `pg:"type:CHAR(26),pk,notnull"` + ID string `bun:"type:CHAR(26),pk,notnull"` // name of the application given when it was created (eg., 'tusky') Name string // website for the application given when it was created (eg., 'https://tusky.app') @@ -30,7 +30,7 @@ type Application struct { // redirect uri requested by the application for oauth2 flow RedirectURI string // id of the associated oauth client entity in the db - ClientID string `pg:"type:CHAR(26)"` + ClientID string `bun:"type:CHAR(26)"` // secret of the associated oauth client entity in the db ClientSecret string // scopes requested when this app was created diff --git a/internal/gtsmodel/block.go b/internal/gtsmodel/block.go index 32afede55..0c762837d 100644 --- a/internal/gtsmodel/block.go +++ b/internal/gtsmodel/block.go @@ -5,17 +5,17 @@ import "time" // Block refers to the blocking of one account by another. type Block struct { // id of this block in the database - ID string `pg:"type:CHAR(26),pk,notnull"` + ID string `bun:"type:CHAR(26),pk,notnull"` // When was this block created - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // When was this block updated - UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // Who created this block? - AccountID string `pg:"type:CHAR(26),notnull"` - Account *Account `pg:"rel:has-one"` + AccountID string `bun:"type:CHAR(26),notnull"` + Account *Account `bun:"rel:belongs-to"` // Who is targeted by this block? - TargetAccountID string `pg:"type:CHAR(26),notnull"` - TargetAccount *Account `pg:"rel:has-one"` + TargetAccountID string `bun:"type:CHAR(26),notnull"` + TargetAccount *Account `bun:"rel:belongs-to"` // Activitypub URI for this block - URI string `pg:",notnull"` + URI string `bun:",notnull"` } diff --git a/internal/gtsmodel/domainblock.go b/internal/gtsmodel/domainblock.go index 1bed86d8f..03d5ab0af 100644 --- a/internal/gtsmodel/domainblock.go +++ b/internal/gtsmodel/domainblock.go @@ -23,16 +23,16 @@ import "time" // DomainBlock represents a federation block against a particular domain type DomainBlock struct { // ID of this block in the database - ID string `pg:"type:CHAR(26),pk,notnull,unique"` + ID string `bun:"type:CHAR(26),pk,notnull,unique"` // blocked domain - Domain string `pg:",pk,notnull,unique"` + Domain string `bun:",pk,notnull,unique"` // When was this block created - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // When was this block updated - UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // Account ID of the creator of this block - CreatedByAccountID string `pg:"type:CHAR(26),notnull"` - CreatedByAccount *Account `pg:"rel:belongs-to"` + CreatedByAccountID string `bun:"type:CHAR(26),notnull"` + CreatedByAccount *Account `bun:"rel:belongs-to"` // Private comment on this block, viewable to admins PrivateComment string // Public comment on this block, viewable (optionally) by everyone @@ -40,5 +40,5 @@ type DomainBlock struct { // whether the domain name should appear obfuscated when displaying it publicly Obfuscate bool // if this block was created through a subscription, what's the subscription ID? - SubscriptionID string `pg:"type:CHAR(26)"` + SubscriptionID string `bun:"type:CHAR(26),nullzero"` } diff --git a/internal/gtsmodel/emaildomainblock.go b/internal/gtsmodel/emaildomainblock.go index 374454374..1919172fa 100644 --- a/internal/gtsmodel/emaildomainblock.go +++ b/internal/gtsmodel/emaildomainblock.go @@ -23,14 +23,14 @@ import "time" // EmailDomainBlock represents a domain that the server should automatically reject sign-up requests from. type EmailDomainBlock struct { // ID of this block in the database - ID string `pg:"type:CHAR(26),pk,notnull,unique"` + ID string `bun:"type:CHAR(26),pk,notnull,unique"` // Email domain to block. Eg. 'gmail.com' or 'hotmail.com' - Domain string `pg:",notnull"` + Domain string `bun:",notnull"` // When was this block created - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // When was this block updated - UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // Account ID of the creator of this block - CreatedByAccountID string `pg:"type:CHAR(26),notnull"` - CreatedByAccount *Account `pg:"rel:belongs-to"` + CreatedByAccountID string `bun:"type:CHAR(26),notnull"` + CreatedByAccount *Account `bun:"rel:belongs-to"` } diff --git a/internal/gtsmodel/emoji.go b/internal/gtsmodel/emoji.go index f0996d1a3..3b02c14e7 100644 --- a/internal/gtsmodel/emoji.go +++ b/internal/gtsmodel/emoji.go @@ -23,16 +23,16 @@ import "time" // Emoji represents a custom emoji that's been uploaded through the admin UI, and is useable by instance denizens. type Emoji struct { // database ID of this emoji - ID string `pg:"type:CHAR(26),pk,notnull"` + ID string `bun:"type:CHAR(26),pk,notnull"` // String shortcode for this emoji -- the part that's between colons. This should be lowercase a-z_ // eg., 'blob_hug' 'purple_heart' Must be unique with domain. - Shortcode string `pg:",notnull,unique:shortcodedomain"` + Shortcode string `bun:",notnull,unique:shortcodedomain"` // Origin domain of this emoji, eg 'example.org', 'queer.party'. empty string for local emojis. - Domain string `pg:",notnull,default:'',use_zero,unique:shortcodedomain"` + Domain string `bun:",notnull,default:'',unique:shortcodedomain"` // When was this emoji created. Must be unique with shortcode. - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // When was this emoji updated - UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // Where can this emoji be retrieved remotely? Null for local emojis. // For remote emojis, it'll be something like: // https://hackers.town/system/custom_emojis/images/000/049/842/original/1b74481204feabfd.png @@ -51,28 +51,27 @@ type Emoji struct { ImageStaticURL string // Path of the emoji image in the server storage system. Will be something like: // '/gotosocial/storage/6339820e-ef65-4166-a262-5a9f46adb1a7/emoji/original/bfa6c9c5-6c25-4ea4-98b4-d78b8126fb52.png' - ImagePath string `pg:",notnull"` + ImagePath string `bun:",notnull"` // Path of a static version of the emoji image in the server storage system. Will be something like: // '/gotosocial/storage/6339820e-ef65-4166-a262-5a9f46adb1a7/emoji/small/bfa6c9c5-6c25-4ea4-98b4-d78b8126fb52.png' - ImageStaticPath string `pg:",notnull"` + ImageStaticPath string `bun:",notnull"` // MIME content type of the emoji image // Probably "image/png" - ImageContentType string `pg:",notnull"` + ImageContentType string `bun:",notnull"` // MIME content type of the static version of the emoji image. - ImageStaticContentType string `pg:",notnull"` + ImageStaticContentType string `bun:",notnull"` // Size of the emoji image file in bytes, for serving purposes. - ImageFileSize int `pg:",notnull"` + ImageFileSize int `bun:",notnull"` // Size of the static version of the emoji image file in bytes, for serving purposes. - ImageStaticFileSize int `pg:",notnull"` + ImageStaticFileSize int `bun:",notnull"` // When was the emoji image last updated? - ImageUpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + ImageUpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // Has a moderation action disabled this emoji from being shown? - Disabled bool `pg:",notnull,default:false"` + Disabled bool `bun:",notnull,default:false"` // ActivityStreams uri of this emoji. Something like 'https://example.org/emojis/1234' - URI string `pg:",notnull,unique"` + URI string `bun:",notnull,unique"` // Is this emoji visible in the admin emoji picker? - VisibleInPicker bool `pg:",notnull,default:true"` + VisibleInPicker bool `bun:",notnull,default:true"` // In which emoji category is this emoji visible? - CategoryID string `pg:"type:CHAR(26)"` - Status *Status `pg:"rel:belongs-to"` + CategoryID string `bun:"type:CHAR(26),nullzero"` } diff --git a/internal/gtsmodel/follow.go b/internal/gtsmodel/follow.go index 8f169f8c4..3d3eb1f1b 100644 --- a/internal/gtsmodel/follow.go +++ b/internal/gtsmodel/follow.go @@ -23,21 +23,21 @@ import "time" // Follow represents one account following another, and the metadata around that follow. type Follow struct { // id of this follow in the database - ID string `pg:"type:CHAR(26),pk,notnull,unique"` + ID string `bun:"type:CHAR(26),pk,notnull,unique"` // When was this follow created? - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // When was this follow last updated? - UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // Who does this follow belong to? - AccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"` - Account *Account `pg:"rel:belongs-to"` + AccountID string `bun:"type:CHAR(26),unique:srctarget,notnull"` + Account *Account `bun:"rel:belongs-to"` // Who does AccountID follow? - TargetAccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"` - TargetAccount *Account `pg:"rel:has-one"` + TargetAccountID string `bun:"type:CHAR(26),unique:srctarget,notnull"` + TargetAccount *Account `bun:"rel:belongs-to"` // Does this follow also want to see reblogs and not just posts? - ShowReblogs bool `pg:"default:true"` + ShowReblogs bool `bun:"default:true"` // What is the activitypub URI of this follow? - URI string `pg:",unique"` + URI string `bun:",unique"` // does the following account want to be notified when the followed account posts? Notify bool } diff --git a/internal/gtsmodel/followrequest.go b/internal/gtsmodel/followrequest.go index 752c7d0a2..5a6cb5e02 100644 --- a/internal/gtsmodel/followrequest.go +++ b/internal/gtsmodel/followrequest.go @@ -23,21 +23,21 @@ import "time" // FollowRequest represents one account requesting to follow another, and the metadata around that request. type FollowRequest struct { // id of this follow request in the database - ID string `pg:"type:CHAR(26),pk,notnull,unique"` + ID string `bun:"type:CHAR(26),pk,notnull,unique"` // When was this follow request created? - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // When was this follow request last updated? - UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // Who does this follow request originate from? - AccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"` - Account Account `pg:"rel:has-one"` + AccountID string `bun:"type:CHAR(26),unique:frsrctarget,notnull"` + Account *Account `bun:"rel:belongs-to"` // Who is the target of this follow request? - TargetAccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"` - TargetAccount Account `pg:"rel:has-one"` + TargetAccountID string `bun:"type:CHAR(26),unique:frsrctarget,notnull"` + TargetAccount *Account `bun:"rel:belongs-to"` // Does this follow also want to see reblogs and not just posts? - ShowReblogs bool `pg:"default:true"` + ShowReblogs bool `bun:"default:true"` // What is the activitypub URI of this follow request? - URI string `pg:",unique"` + URI string `bun:",unique,nullzero"` // does the following account want to be notified when the followed account posts? Notify bool } diff --git a/internal/gtsmodel/instance.go b/internal/gtsmodel/instance.go index 7b453a0b3..5bfe942f7 100644 --- a/internal/gtsmodel/instance.go +++ b/internal/gtsmodel/instance.go @@ -5,22 +5,22 @@ import "time" // Instance represents a federated instance, either local or remote. type Instance struct { // ID of this instance in the database - ID string `pg:"type:CHAR(26),pk,notnull,unique"` + ID string `bun:"type:CHAR(26),pk,notnull,unique"` // Instance domain eg example.org - Domain string `pg:",pk,notnull,unique"` + Domain string `bun:",pk,notnull,unique"` // Title of this instance as it would like to be displayed. Title string // base URI of this instance eg https://example.org - URI string `pg:",notnull,unique"` + URI string `bun:",notnull,unique"` // When was this instance created in the db? - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // When was this instance last updated in the db? - UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // When was this instance suspended, if at all? - SuspendedAt time.Time + SuspendedAt time.Time `bun:",nullzero"` // ID of any existing domain block for this instance in the database - DomainBlockID string `pg:"type:CHAR(26)"` - DomainBlock *DomainBlock `pg:"rel:has-one"` + DomainBlockID string `bun:"type:CHAR(26),nullzero"` + DomainBlock *DomainBlock `bun:"rel:belongs-to"` // Short description of this instance ShortDescription string // Longer description of this instance @@ -32,10 +32,10 @@ type Instance struct { // Username of the contact account for this instance ContactAccountUsername string // Contact account ID in the database for this instance - ContactAccountID string `pg:"type:CHAR(26)"` - ContactAccount *Account `pg:"rel:has-one"` + ContactAccountID string `bun:"type:CHAR(26),nullzero"` + ContactAccount *Account `bun:"rel:belongs-to"` // Reputation score of this instance - Reputation int64 `pg:",notnull,default:0"` + Reputation int64 `bun:",notnull,default:0"` // Version of the software used on this instance Version string } diff --git a/internal/gtsmodel/mediaattachment.go b/internal/gtsmodel/mediaattachment.go index 0f12caaad..b767e538c 100644 --- a/internal/gtsmodel/mediaattachment.go +++ b/internal/gtsmodel/mediaattachment.go @@ -26,28 +26,28 @@ import ( // somewhere in storage and that can be retrieved and served by the router. type MediaAttachment struct { // ID of the attachment in the database - ID string `pg:"type:CHAR(26),pk,notnull,unique"` + ID string `bun:"type:CHAR(26),pk,notnull,unique"` // ID of the status to which this is attached - StatusID string `pg:"type:CHAR(26)"` + StatusID string `bun:"type:CHAR(26),nullzero"` // Where can the attachment be retrieved on *this* server URL string // Where can the attachment be retrieved on a remote server (empty for local media) RemoteURL string // When was the attachment created - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // When was the attachment last updated - UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // Type of file (image/gif/audio/video) - Type FileType `pg:",notnull"` + Type FileType `bun:",notnull"` // Metadata about the file FileMeta FileMeta // To which account does this attachment belong - AccountID string `pg:"type:CHAR(26),notnull"` - Account *Account `pg:"rel:belongs-to"` + AccountID string `bun:"type:CHAR(26),notnull"` + Account *Account `bun:"rel:has-one"` // Description of the attachment (for screenreaders) Description string // To which scheduled status does this attachment belong - ScheduledStatusID string `pg:"type:CHAR(26)"` + ScheduledStatusID string `bun:"type:CHAR(26),nullzero"` // What is the generated blurhash of this attachment Blurhash string // What is the processing status of this attachment @@ -71,7 +71,7 @@ type File struct { // What is the size of the file in bytes. FileSize int // When was the file last updated. - UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + UpdatedAt time.Time `bun:"type:timestamp,notnull,default:current_timestamp"` } // Thumbnail refers to a small image thumbnail derived from a larger image, video, or audio file. @@ -83,7 +83,7 @@ type Thumbnail struct { // What is the size of the file in bytes FileSize int // When was the file last updated - UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + UpdatedAt time.Time `bun:"type:timestamp,notnull,default:current_timestamp"` // What is the URL of the thumbnail on the local server URL string // What is the remote URL of the thumbnail (empty for local media) diff --git a/internal/gtsmodel/mention.go b/internal/gtsmodel/mention.go index 931e681db..ce5977659 100644 --- a/internal/gtsmodel/mention.go +++ b/internal/gtsmodel/mention.go @@ -23,22 +23,22 @@ import "time" // Mention refers to the 'tagging' or 'mention' of a user within a status. type Mention struct { // ID of this mention in the database - ID string `pg:"type:CHAR(26),pk,notnull,unique"` + ID string `bun:"type:CHAR(26),pk,notnull,unique"` // ID of the status this mention originates from - StatusID string `pg:"type:CHAR(26),notnull"` - Status *Status `pg:"rel:belongs-to"` + StatusID string `bun:"type:CHAR(26),notnull"` + Status *Status `bun:"rel:belongs-to"` // When was this mention created? - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // When was this mention last updated? - UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // What's the internal account ID of the originator of the mention? - OriginAccountID string `pg:"type:CHAR(26),notnull"` - OriginAccount *Account `pg:"rel:has-one"` + OriginAccountID string `bun:"type:CHAR(26),notnull"` + OriginAccount *Account `bun:"rel:belongs-to"` // What's the AP URI of the originator of the mention? - OriginAccountURI string `pg:",notnull"` + OriginAccountURI string `bun:",notnull"` // What's the internal account ID of the mention target? - TargetAccountID string `pg:"type:CHAR(26),notnull"` - TargetAccount *Account `pg:"rel:has-one"` + TargetAccountID string `bun:"type:CHAR(26),notnull"` + TargetAccount *Account `bun:"rel:belongs-to"` // Prevent this mention from generating a notification? Silent bool @@ -54,15 +54,15 @@ type Mention struct { // @whatever_username@example.org // // This will not be put in the database, it's just for convenience. - NameString string `pg:"-"` + NameString string `bun:"-"` // TargetAccountURI is the AP ID (uri) of the user mentioned. // // This will not be put in the database, it's just for convenience. - TargetAccountURI string `pg:"-"` + TargetAccountURI string `bun:"-"` // TargetAccountURL is the web url of the user mentioned. // // This will not be put in the database, it's just for convenience. - TargetAccountURL string `pg:"-"` + TargetAccountURL string `bun:"-"` // A pointer to the gtsmodel account of the mentioned account. } diff --git a/internal/gtsmodel/messages.go b/internal/gtsmodel/messages.go index 910c74898..62beb0adc 100644 --- a/internal/gtsmodel/messages.go +++ b/internal/gtsmodel/messages.go @@ -1,11 +1,22 @@ -package gtsmodel +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. -// // ToClientAPI wraps a message that travels from the processor into the client API -// type ToClientAPI struct { -// APObjectType ActivityStreamsObject -// APActivityType ActivityStreamsActivity -// Activity interface{} -// } + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package gtsmodel // FromClientAPI wraps a message that travels from client API into the processor type FromClientAPI struct { @@ -16,13 +27,6 @@ type FromClientAPI struct { TargetAccount *Account } -// // ToFederator wraps a message that travels from the processor into the federator -// type ToFederator struct { -// APObjectType ActivityStreamsObject -// APActivityType ActivityStreamsActivity -// GTSModel interface{} -// } - // FromFederator wraps a message that travels from the federator into the processor type FromFederator struct { APObjectType string diff --git a/internal/gtsmodel/notification.go b/internal/gtsmodel/notification.go index b85bc969e..14ab90802 100644 --- a/internal/gtsmodel/notification.go +++ b/internal/gtsmodel/notification.go @@ -23,20 +23,20 @@ import "time" // Notification models an alert/notification sent to an account about something like a reblog, like, new follow request, etc. type Notification struct { // ID of this notification in the database - ID string `pg:"type:CHAR(26),pk,notnull"` + ID string `bun:"type:CHAR(26),pk,notnull"` // Type of this notification - NotificationType NotificationType `pg:",notnull"` + NotificationType NotificationType `bun:",notnull"` // Creation time of this notification - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // Which account does this notification target (ie., who will receive the notification?) - TargetAccountID string `pg:"type:CHAR(26),notnull"` - TargetAccount *Account `pg:"rel:has-one"` + TargetAccountID string `bun:"type:CHAR(26),notnull"` + TargetAccount *Account `bun:"rel:belongs-to"` // Which account performed the action that created this notification? - OriginAccountID string `pg:"type:CHAR(26),notnull"` - OriginAccount *Account `pg:"rel:has-one"` + OriginAccountID string `bun:"type:CHAR(26),notnull"` + OriginAccount *Account `bun:"rel:belongs-to"` // If the notification pertains to a status, what is the database ID of that status? - StatusID string `pg:"type:CHAR(26)"` - Status *Status `pg:"rel:has-one"` + StatusID string `bun:"type:CHAR(26),nullzero"` + Status *Status `bun:"rel:belongs-to"` // Has this notification been read already? Read bool } diff --git a/internal/gtsmodel/routersession.go b/internal/gtsmodel/routersession.go index c0f8e1f4d..7f3bd85c3 100644 --- a/internal/gtsmodel/routersession.go +++ b/internal/gtsmodel/routersession.go @@ -20,7 +20,7 @@ package gtsmodel // RouterSession is used to store and retrieve settings for a router session. type RouterSession struct { - ID string `pg:"type:CHAR(26),pk,notnull"` - Auth []byte `pg:",notnull"` - Crypt []byte `pg:",notnull"` + ID string `bun:"type:CHAR(26),pk,notnull"` + Auth []byte `bun:"type:bytea,notnull"` + Crypt []byte `bun:"type:bytea,notnull"` } diff --git a/internal/gtsmodel/status.go b/internal/gtsmodel/status.go index 354f37e04..fd33bd788 100644 --- a/internal/gtsmodel/status.go +++ b/internal/gtsmodel/status.go @@ -25,61 +25,61 @@ import ( // Status represents a user-created 'post' or 'status' in the database, either remote or local type Status struct { // id of the status in the database - ID string `pg:"type:CHAR(26),pk,notnull"` + ID string `bun:"type:CHAR(26),pk,notnull"` // uri at which this status is reachable - URI string `pg:",unique"` + URI string `bun:",unique,nullzero"` // web url for viewing this status - URL string `pg:",unique"` + URL string `bun:",unique,nullzero"` // the html-formatted content of this status Content string // Database IDs of any media attachments associated with this status - AttachmentIDs []string `pg:"attachments,array"` - Attachments []*MediaAttachment `pg:"attached_media,rel:has-many"` + AttachmentIDs []string `bun:"attachments,array"` + Attachments []*MediaAttachment `bun:"attached_media,rel:has-many"` // Database IDs of any tags used in this status - TagIDs []string `pg:"tags,array"` - Tags []*Tag `pg:"attached_tags,many2many:status_to_tags"` // https://pg.uptrace.dev/orm/many-to-many-relation/ + TagIDs []string `bun:"tags,array"` + Tags []*Tag `bun:"attached_tags,m2m:status_to_tags"` // https://bun.uptrace.dev/guide/relations.html#many-to-many-relation // Database IDs of any mentions in this status - MentionIDs []string `pg:"mentions,array"` - Mentions []*Mention `pg:"attached_mentions,rel:has-many"` + MentionIDs []string `bun:"mentions,array"` + Mentions []*Mention `bun:"attached_mentions,rel:has-many"` // Database IDs of any emojis used in this status - EmojiIDs []string `pg:"emojis,array"` - Emojis []*Emoji `pg:"attached_emojis,many2many:status_to_emojis"` // https://pg.uptrace.dev/orm/many-to-many-relation/ + EmojiIDs []string `bun:"emojis,array"` + Emojis []*Emoji `bun:"attached_emojis,m2m:status_to_emojis"` // https://bun.uptrace.dev/guide/relations.html#many-to-many-relation // when was this status created? - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `bun:",notnull,nullzero,default:current_timestamp"` // when was this status updated? - UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + UpdatedAt time.Time `bun:",notnull,nullzero,default:current_timestamp"` // is this status from a local account? Local bool // which account posted this status? - AccountID string `pg:"type:CHAR(26),notnull"` - Account *Account `pg:"rel:has-one"` + AccountID string `bun:"type:CHAR(26),notnull"` + Account *Account `bun:"rel:belongs-to"` // AP uri of the owner of this status AccountURI string // id of the status this status is a reply to - InReplyToID string `pg:"type:CHAR(26)"` - InReplyTo *Status `pg:"rel:has-one"` + InReplyToID string `bun:"type:CHAR(26),nullzero"` + InReplyTo *Status `bun:"-"` // AP uri of the status this status is a reply to InReplyToURI string // id of the account that this status replies to - InReplyToAccountID string `pg:"type:CHAR(26)"` - InReplyToAccount *Account `pg:"rel:has-one"` + InReplyToAccountID string `bun:"type:CHAR(26),nullzero"` + InReplyToAccount *Account `bun:"rel:belongs-to"` // id of the status this status is a boost of - BoostOfID string `pg:"type:CHAR(26)"` - BoostOf *Status `pg:"rel:has-one"` + BoostOfID string `bun:"type:CHAR(26),nullzero"` + BoostOf *Status `bun:"-"` // id of the account that owns the boosted status - BoostOfAccountID string `pg:"type:CHAR(26)"` - BoostOfAccount *Account `pg:"rel:has-one"` + BoostOfAccountID string `bun:"type:CHAR(26),nullzero"` + BoostOfAccount *Account `bun:"rel:belongs-to"` // cw string for this status ContentWarning string // visibility entry for this status - Visibility Visibility `pg:",notnull"` + Visibility Visibility `bun:",notnull"` // mark the status as sensitive? Sensitive bool // what language is this status written in? Language string // Which application was used to create this status? - CreatedWithApplicationID string `pg:"type:CHAR(26)"` - CreatedWithApplication *Application `pg:"rel:has-one"` + CreatedWithApplicationID string `bun:"type:CHAR(26),nullzero"` + CreatedWithApplication *Application `bun:"rel:belongs-to"` // advanced visibility for this status VisibilityAdvanced *VisibilityAdvanced // What is the activitystreams type of this status? See: https://www.w3.org/TR/activitystreams-vocabulary/#object-types @@ -93,14 +93,18 @@ type Status struct { // StatusToTag is an intermediate struct to facilitate the many2many relationship between a status and one or more tags. type StatusToTag struct { - StatusID string `pg:"unique:statustag"` - TagID string `pg:"unique:statustag"` + StatusID string `bun:"type:CHAR(26),unique:statustag,nullzero"` + Status *Status `bun:"rel:belongs-to"` + TagID string `bun:"type:CHAR(26),unique:statustag,nullzero"` + Tag *Tag `bun:"rel:belongs-to"` } // StatusToEmoji is an intermediate struct to facilitate the many2many relationship between a status and one or more emojis. type StatusToEmoji struct { - StatusID string `pg:"unique:statusemoji"` - EmojiID string `pg:"unique:statusemoji"` + StatusID string `bun:"type:CHAR(26),unique:statusemoji,nullzero"` + Status *Status `bun:"rel:belongs-to"` + EmojiID string `bun:"type:CHAR(26),unique:statusemoji,nullzero"` + Emoji *Emoji `bun:"rel:belongs-to"` } // Visibility represents the visibility granularity of a status. @@ -134,11 +138,11 @@ const ( // If DIRECT is selected, boostable will be FALSE, and all other flags will be TRUE. type VisibilityAdvanced struct { // This status will be federated beyond the local timeline(s) - Federated bool `pg:"default:true"` + Federated bool `bun:"default:true"` // This status can be boosted/reblogged - Boostable bool `pg:"default:true"` + Boostable bool `bun:"default:true"` // This status can be replied to - Replyable bool `pg:"default:true"` + Replyable bool `bun:"default:true"` // This status can be liked/faved - Likeable bool `pg:"default:true"` + Likeable bool `bun:"default:true"` } diff --git a/internal/gtsmodel/statusbookmark.go b/internal/gtsmodel/statusbookmark.go index 468939bae..26dafa420 100644 --- a/internal/gtsmodel/statusbookmark.go +++ b/internal/gtsmodel/statusbookmark.go @@ -23,15 +23,15 @@ import "time" // StatusBookmark refers to one account having a 'bookmark' of the status of another account type StatusBookmark struct { // id of this bookmark in the database - ID string `pg:"type:CHAR(26),pk,notnull,unique"` + ID string `bun:"type:CHAR(26),pk,notnull,unique"` // when was this bookmark created - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // id of the account that created ('did') the bookmarking - AccountID string `pg:"type:CHAR(26),notnull"` - Account *Account `pg:"rel:belongs-to"` + AccountID string `bun:"type:CHAR(26),notnull"` + Account *Account `bun:"rel:belongs-to"` // id the account owning the bookmarked status - TargetAccountID string `pg:"type:CHAR(26),notnull"` - TargetAccount *Account `pg:"rel:has-one"` + TargetAccountID string `bun:"type:CHAR(26),notnull"` + TargetAccount *Account `bun:"rel:belongs-to"` // database id of the status that has been bookmarked - StatusID string `pg:"type:CHAR(26),notnull"` + StatusID string `bun:"type:CHAR(26),notnull"` } diff --git a/internal/gtsmodel/statusfave.go b/internal/gtsmodel/statusfave.go index 17952673a..3b816af56 100644 --- a/internal/gtsmodel/statusfave.go +++ b/internal/gtsmodel/statusfave.go @@ -23,18 +23,18 @@ import "time" // StatusFave refers to a 'fave' or 'like' in the database, from one account, targeting the status of another account type StatusFave struct { // id of this fave in the database - ID string `pg:"type:CHAR(26),pk,notnull,unique"` + ID string `bun:"type:CHAR(26),pk,notnull,unique"` // when was this fave created - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // id of the account that created ('did') the fave - AccountID string `pg:"type:CHAR(26),notnull"` - Account *Account `pg:"rel:has-one"` + AccountID string `bun:"type:CHAR(26),notnull"` + Account *Account `bun:"rel:belongs-to"` // id the account owning the faved status - TargetAccountID string `pg:"type:CHAR(26),notnull"` - TargetAccount *Account `pg:"rel:has-one"` + TargetAccountID string `bun:"type:CHAR(26),notnull"` + TargetAccount *Account `bun:"rel:belongs-to"` // database id of the status that has been 'faved' - StatusID string `pg:"type:CHAR(26),notnull"` - Status *Status `pg:"rel:has-one"` + StatusID string `bun:"type:CHAR(26),notnull"` + Status *Status `bun:"rel:belongs-to"` // ActivityPub URI of this fave - URI string `pg:",notnull"` + URI string `bun:",notnull"` } diff --git a/internal/gtsmodel/statusmute.go b/internal/gtsmodel/statusmute.go index 472a5ec09..56a792ab4 100644 --- a/internal/gtsmodel/statusmute.go +++ b/internal/gtsmodel/statusmute.go @@ -23,16 +23,16 @@ import "time" // StatusMute refers to one account having muted the status of another account or its own type StatusMute struct { // id of this mute in the database - ID string `pg:"type:CHAR(26),pk,notnull,unique"` + ID string `bun:"type:CHAR(26),pk,notnull,unique"` // when was this mute created - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // id of the account that created ('did') the mute - AccountID string `pg:"type:CHAR(26),notnull"` - Account *Account `pg:"rel:belongs-to"` + AccountID string `bun:"type:CHAR(26),notnull"` + Account *Account `bun:"rel:belongs-to"` // id the account owning the muted status (can be the same as accountID) - TargetAccountID string `pg:"type:CHAR(26),notnull"` - TargetAccount *Account `pg:"rel:has-one"` + TargetAccountID string `bun:"type:CHAR(26),notnull"` + TargetAccount *Account `bun:"rel:belongs-to"` // database id of the status that has been muted - StatusID string `pg:"type:CHAR(26),notnull"` - Status *Status `pg:"rel:has-one"` + StatusID string `bun:"type:CHAR(26),notnull"` + Status *Status `bun:"rel:belongs-to"` } diff --git a/internal/gtsmodel/tag.go b/internal/gtsmodel/tag.go index 27cce1c8b..5006a36f4 100644 --- a/internal/gtsmodel/tag.go +++ b/internal/gtsmodel/tag.go @@ -23,21 +23,21 @@ import "time" // Tag represents a hashtag for gathering public statuses together type Tag struct { // id of this tag in the database - ID string `pg:",unique,type:CHAR(26),pk,notnull"` + ID string `bun:",unique,type:CHAR(26),pk,notnull"` // Href of this tag, eg https://example.org/tags/somehashtag URL string // name of this tag -- the tag without the hash part - Name string `pg:",unique,notnull"` + Name string `bun:",unique,notnull"` // Which account ID is the first one we saw using this tag? - FirstSeenFromAccountID string `pg:"type:CHAR(26)"` + FirstSeenFromAccountID string `bun:"type:CHAR(26),nullzero"` // when was this tag created - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // when was this tag last updated - UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // can our instance users use this tag? - Useable bool `pg:",notnull,default:true"` + Useable bool `bun:",notnull,default:true"` // can our instance users look up this tag? - Listable bool `pg:",notnull,default:true"` + Listable bool `bun:",notnull,default:true"` // when was this tag last used? - LastStatusAt time.Time `pg:"type:timestamp,notnull,default:now()"` + LastStatusAt time.Time `bun:",nullzero"` } diff --git a/internal/gtsmodel/user.go b/internal/gtsmodel/user.go index fe8ebcabe..f439be439 100644 --- a/internal/gtsmodel/user.go +++ b/internal/gtsmodel/user.go @@ -31,37 +31,37 @@ type User struct { */ // id of this user in the local database; the end-user will never need to know this, it's strictly internal - ID string `pg:"type:CHAR(26),pk,notnull,unique"` + ID string `bun:"type:CHAR(26),pk,notnull,unique"` // confirmed email address for this user, this should be unique -- only one email address registered per instance, multiple users per email are not supported - Email string `pg:"default:null,unique"` + Email string `bun:"default:null,unique,nullzero"` // The id of the local gtsmodel.Account entry for this user, if it exists (unconfirmed users don't have an account yet) - AccountID string `pg:"type:CHAR(26),unique"` - Account *Account `pg:"rel:has-one"` + AccountID string `bun:"type:CHAR(26),unique,nullzero"` + Account *Account `bun:"rel:belongs-to"` // The encrypted password of this user, generated using https://pkg.go.dev/golang.org/x/crypto/bcrypt#GenerateFromPassword. A salt is included so we're safe against 🌈 tables - EncryptedPassword string `pg:",notnull"` + EncryptedPassword string `bun:",notnull"` /* USER METADATA */ // When was this user created? - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // From what IP was this user created? - SignUpIP net.IP + SignUpIP net.IP `bun:",nullzero"` // When was this user updated (eg., password changed, email address changed)? - UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // When did this user sign in for their current session? - CurrentSignInAt time.Time `pg:"type:timestamp"` + CurrentSignInAt time.Time `bun:",nullzero"` // What's the most recent IP of this user - CurrentSignInIP net.IP + CurrentSignInIP net.IP `bun:",nullzero"` // When did this user last sign in? - LastSignInAt time.Time `pg:"type:timestamp"` + LastSignInAt time.Time `bun:",nullzero"` // What's the previous IP of this user? - LastSignInIP net.IP + LastSignInIP net.IP `bun:",nullzero"` // How many times has this user signed in? SignInCount int // id of the user who invited this user (who let this guy in?) - InviteID string `pg:"type:CHAR(26)"` + InviteID string `bun:"type:CHAR(26),nullzero"` // What languages does this user want to see? ChosenLanguages []string // What languages does this user not want to see? @@ -69,10 +69,10 @@ type User struct { // In what timezone/locale is this user located? Locale string // Which application id created this user? See gtsmodel.Application - CreatedByApplicationID string `pg:"type:CHAR(26)"` - CreatedByApplication *Application `pg:"rel:has-one"` + CreatedByApplicationID string `bun:"type:CHAR(26),nullzero"` + CreatedByApplication *Application `bun:"rel:belongs-to"` // When did we last contact this user - LastEmailedAt time.Time `pg:"type:timestamp"` + LastEmailedAt time.Time `bun:",nullzero"` /* USER CONFIRMATION @@ -81,9 +81,9 @@ type User struct { // What confirmation token did we send this user/what are we expecting back? ConfirmationToken string // When did the user confirm their email address - ConfirmedAt time.Time `pg:"type:timestamp"` + ConfirmedAt time.Time `bun:",nullzero"` // When did we send email confirmation to this user? - ConfirmationSentAt time.Time `pg:"type:timestamp"` + ConfirmationSentAt time.Time `bun:",nullzero"` // Email address that hasn't yet been confirmed UnconfirmedEmail string @@ -107,7 +107,7 @@ type User struct { // The generated token that the user can use to reset their password ResetPasswordToken string // When did we email the user their reset-password email? - ResetPasswordSentAt time.Time `pg:"type:timestamp"` + ResetPasswordSentAt time.Time `bun:",nullzero"` EncryptedOTPSecret string EncryptedOTPSecretIv string @@ -117,6 +117,6 @@ type User struct { ConsumedTimestamp int RememberToken string SignInToken string - SignInTokenSentAt time.Time `pg:"type:timestamp"` + SignInTokenSentAt time.Time `bun:",nullzero"` WebauthnID string } diff --git a/internal/media/handler.go b/internal/media/handler.go index c383a922e..1150f7e87 100644 --- a/internal/media/handler.go +++ b/internal/media/handler.go @@ -67,27 +67,27 @@ type Handler interface { // ProcessHeaderOrAvatar takes a new header image for an account, checks it out, removes exif data from it, // puts it in whatever storage backend we're using, sets the relevant fields in the database for the new image, // and then returns information to the caller about the new header. - ProcessHeaderOrAvatar(attachment []byte, accountID string, mediaType Type, remoteURL string) (*gtsmodel.MediaAttachment, error) + ProcessHeaderOrAvatar(ctx context.Context, attachment []byte, accountID string, mediaType Type, remoteURL string) (*gtsmodel.MediaAttachment, error) // ProcessLocalAttachment takes a new attachment and the requesting account, checks it out, removes exif data from it, // puts it in whatever storage backend we're using, sets the relevant fields in the database for the new media, // and then returns information to the caller about the attachment. It's the caller's responsibility to put the returned struct // in the database. - ProcessAttachment(attachment []byte, accountID string, remoteURL string) (*gtsmodel.MediaAttachment, error) + ProcessAttachment(ctx context.Context, attachment []byte, accountID string, remoteURL string) (*gtsmodel.MediaAttachment, error) // ProcessLocalEmoji takes a new emoji and a shortcode, cleans it up, puts it in storage, and creates a new // *gts.Emoji for it, then returns it to the caller. It's the caller's responsibility to put the returned struct // in the database. - ProcessLocalEmoji(emojiBytes []byte, shortcode string) (*gtsmodel.Emoji, error) + ProcessLocalEmoji(ctx context.Context, emojiBytes []byte, shortcode string) (*gtsmodel.Emoji, error) // ProcessRemoteAttachment takes a transport, a bare-bones current attachment, and an accountID that the attachment belongs to. // It then dereferences the attachment (ie., fetches the attachment bytes from the remote server), ensuring that the bytes are // the correct content type. It stores the attachment in whatever storage backend the Handler has been initalized with, and returns // information to the caller about the new attachment. It's the caller's responsibility to put the returned struct // in the database. - ProcessRemoteAttachment(t transport.Transport, currentAttachment *gtsmodel.MediaAttachment, accountID string) (*gtsmodel.MediaAttachment, error) + ProcessRemoteAttachment(ctx context.Context, t transport.Transport, currentAttachment *gtsmodel.MediaAttachment, accountID string) (*gtsmodel.MediaAttachment, error) - ProcessRemoteHeaderOrAvatar(t transport.Transport, currentAttachment *gtsmodel.MediaAttachment, accountID string) (*gtsmodel.MediaAttachment, error) + ProcessRemoteHeaderOrAvatar(ctx context.Context, t transport.Transport, currentAttachment *gtsmodel.MediaAttachment, accountID string) (*gtsmodel.MediaAttachment, error) } type mediaHandler struct { @@ -114,7 +114,7 @@ func New(config *config.Config, database db.DB, storage blob.Storage, log *logru // ProcessHeaderOrAvatar takes a new header image for an account, checks it out, removes exif data from it, // puts it in whatever storage backend we're using, sets the relevant fields in the database for the new image, // and then returns information to the caller about the new header. -func (mh *mediaHandler) ProcessHeaderOrAvatar(attachment []byte, accountID string, mediaType Type, remoteURL string) (*gtsmodel.MediaAttachment, error) { +func (mh *mediaHandler) ProcessHeaderOrAvatar(ctx context.Context, attachment []byte, accountID string, mediaType Type, remoteURL string) (*gtsmodel.MediaAttachment, error) { l := mh.log.WithField("func", "SetHeaderForAccountID") if mediaType != Header && mediaType != Avatar { @@ -142,7 +142,7 @@ func (mh *mediaHandler) ProcessHeaderOrAvatar(attachment []byte, accountID strin } // set it in the database - if err := mh.db.SetAccountHeaderOrAvatar(ma, accountID); err != nil { + if err := mh.db.SetAccountHeaderOrAvatar(ctx, ma, accountID); err != nil { return nil, fmt.Errorf("error putting %s in database: %s", mediaType, err) } @@ -152,7 +152,7 @@ func (mh *mediaHandler) ProcessHeaderOrAvatar(attachment []byte, accountID strin // ProcessAttachment takes a new attachment and the owning account, checks it out, removes exif data from it, // puts it in whatever storage backend we're using, sets the relevant fields in the database for the new media, // and then returns information to the caller about the attachment. -func (mh *mediaHandler) ProcessAttachment(attachment []byte, accountID string, remoteURL string) (*gtsmodel.MediaAttachment, error) { +func (mh *mediaHandler) ProcessAttachment(ctx context.Context, attachment []byte, accountID string, remoteURL string) (*gtsmodel.MediaAttachment, error) { contentType, err := parseContentType(attachment) if err != nil { return nil, err @@ -184,7 +184,7 @@ func (mh *mediaHandler) ProcessAttachment(attachment []byte, accountID string, r // ProcessLocalEmoji takes a new emoji and a shortcode, cleans it up, puts it in storage, and creates a new // *gts.Emoji for it, then returns it to the caller. It's the caller's responsibility to put the returned struct // in the database. -func (mh *mediaHandler) ProcessLocalEmoji(emojiBytes []byte, shortcode string) (*gtsmodel.Emoji, error) { +func (mh *mediaHandler) ProcessLocalEmoji(ctx context.Context, emojiBytes []byte, shortcode string) (*gtsmodel.Emoji, error) { var clean []byte var err error var original *imageAndMeta @@ -231,7 +231,7 @@ func (mh *mediaHandler) ProcessLocalEmoji(emojiBytes []byte, shortcode string) ( // since emoji aren't 'owned' by an account, but we still want to use the same pattern for serving them through the filserver, // (ie., fileserver/ACCOUNT_ID/etc etc) we need to fetch the INSTANCE ACCOUNT from the database. That is, the account that's created // with the same username as the instance hostname, which doesn't belong to any particular user. - instanceAccount, err := mh.db.GetInstanceAccount("") + instanceAccount, err := mh.db.GetInstanceAccount(ctx, "") if err != nil { return nil, fmt.Errorf("error fetching instance account: %s", err) } @@ -296,7 +296,7 @@ func (mh *mediaHandler) ProcessLocalEmoji(emojiBytes []byte, shortcode string) ( return e, nil } -func (mh *mediaHandler) ProcessRemoteAttachment(t transport.Transport, currentAttachment *gtsmodel.MediaAttachment, accountID string) (*gtsmodel.MediaAttachment, error) { +func (mh *mediaHandler) ProcessRemoteAttachment(ctx context.Context, t transport.Transport, currentAttachment *gtsmodel.MediaAttachment, accountID string) (*gtsmodel.MediaAttachment, error) { if currentAttachment.RemoteURL == "" { return nil, errors.New("no remote URL on media attachment to dereference") } @@ -317,10 +317,10 @@ func (mh *mediaHandler) ProcessRemoteAttachment(t transport.Transport, currentAt return nil, fmt.Errorf("dereferencing remote media with url %s: %s", remoteIRI.String(), err) } - return mh.ProcessAttachment(attachmentBytes, accountID, currentAttachment.RemoteURL) + return mh.ProcessAttachment(ctx, attachmentBytes, accountID, currentAttachment.RemoteURL) } -func (mh *mediaHandler) ProcessRemoteHeaderOrAvatar(t transport.Transport, currentAttachment *gtsmodel.MediaAttachment, accountID string) (*gtsmodel.MediaAttachment, error) { +func (mh *mediaHandler) ProcessRemoteHeaderOrAvatar(ctx context.Context, t transport.Transport, currentAttachment *gtsmodel.MediaAttachment, accountID string) (*gtsmodel.MediaAttachment, error) { if !currentAttachment.Header && !currentAttachment.Avatar { return nil, errors.New("provided attachment was set to neither header nor avatar") @@ -357,5 +357,5 @@ func (mh *mediaHandler) ProcessRemoteHeaderOrAvatar(t transport.Transport, curre return nil, fmt.Errorf("dereferencing remote media with url %s: %s", remoteIRI.String(), err) } - return mh.ProcessHeaderOrAvatar(attachmentBytes, accountID, headerOrAvi, currentAttachment.RemoteURL) + return mh.ProcessHeaderOrAvatar(ctx, attachmentBytes, accountID, headerOrAvi, currentAttachment.RemoteURL) } diff --git a/internal/oauth/clientstore.go b/internal/oauth/clientstore.go index 2e7e0ae88..a642f6cfa 100644 --- a/internal/oauth/clientstore.go +++ b/internal/oauth/clientstore.go @@ -39,10 +39,8 @@ func NewClientStore(db db.Basic) oauth2.ClientStore { } func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { - poc := &Client{ - ID: clientID, - } - if err := cs.db.GetByID(clientID, poc); err != nil { + poc := &Client{} + if err := cs.db.GetByID(ctx, clientID, poc); err != nil { return nil, err } return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil @@ -55,19 +53,19 @@ func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo Domain: cli.GetDomain(), UserID: cli.GetUserID(), } - return cs.db.UpdateByID(id, poc) + return cs.db.Put(ctx, poc) } func (cs *clientStore) Delete(ctx context.Context, id string) error { poc := &Client{ ID: id, } - return cs.db.DeleteByID(id, poc) + return cs.db.DeleteByID(ctx, id, poc) } // Client is a handy little wrapper for typical oauth client details type Client struct { - ID string `pg:"type:CHAR(26),pk,notnull"` + ID string `bun:"type:CHAR(26),pk,notnull"` Secret string Domain string UserID string diff --git a/internal/oauth/tokenstore.go b/internal/oauth/tokenstore.go index 4fd3183fc..264678ff5 100644 --- a/internal/oauth/tokenstore.go +++ b/internal/oauth/tokenstore.go @@ -43,13 +43,13 @@ type tokenStore struct { // In order to allow tokens to 'expire', it will also set off a goroutine that iterates through // the tokens in the DB once per minute and deletes any that have expired. func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2.TokenStore { - pts := &tokenStore{ + ts := &tokenStore{ db: db, log: log, } // set the token store to clean out expired tokens once per minute, or return if we're done - go func(ctx context.Context, pts *tokenStore, log *logrus.Logger) { + go func(ctx context.Context, ts *tokenStore, log *logrus.Logger) { cleanloop: for { select { @@ -58,32 +58,32 @@ func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2. break cleanloop case <-time.After(1 * time.Minute): log.Trace("sweeping out old oauth entries broom broom") - if err := pts.sweep(); err != nil { + if err := ts.sweep(ctx); err != nil { log.Errorf("error while sweeping oauth entries: %s", err) } } } - }(ctx, pts, log) - return pts + }(ctx, ts, log) + return ts } // sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so. -func (pts *tokenStore) sweep() error { +func (ts *tokenStore) sweep(ctx context.Context) error { // select *all* tokens from the db // todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. tokens := new([]*Token) - if err := pts.db.GetAll(tokens); err != nil { + if err := ts.db.GetAll(ctx, tokens); err != nil { return err } // iterate through and remove expired tokens now := time.Now() - for _, pgt := range *tokens { + for _, dbt := range *tokens { // The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So: // we only want to check if a token expired before now if the expiry time is *not zero*; // ie., if it's been explicity set. - if !pgt.CodeExpiresAt.IsZero() && pgt.CodeExpiresAt.Before(now) || !pgt.RefreshExpiresAt.IsZero() && pgt.RefreshExpiresAt.Before(now) || !pgt.AccessExpiresAt.IsZero() && pgt.AccessExpiresAt.Before(now) { - if err := pts.db.DeleteByID(pgt.ID, pgt); err != nil { + if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) { + if err := ts.db.DeleteByID(ctx, dbt.ID, dbt); err != nil { return err } } @@ -94,92 +94,92 @@ func (pts *tokenStore) sweep() error { // Create creates and store the new token information. // For the original implementation, see https://github.com/superseriousbusiness/oauth2/blob/master/store/token.go#L34 -func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { +func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { t, ok := info.(*models.Token) if !ok { return errors.New("info param was not a models.Token") } - pgt := TokenToPGToken(t) - if pgt.ID == "" { - pgtID, err := id.NewRandomULID() + dbt := TokenToDBToken(t) + if dbt.ID == "" { + dbtID, err := id.NewRandomULID() if err != nil { return err } - pgt.ID = pgtID + dbt.ID = dbtID } - if err := pts.db.Put(pgt); err != nil { + if err := ts.db.Put(ctx, dbt); err != nil { return fmt.Errorf("error in tokenstore create: %s", err) } return nil } // RemoveByCode deletes a token from the DB based on the Code field -func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error { - return pts.db.DeleteWhere([]db.Where{{Key: "code", Value: code}}, &Token{}) +func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error { + return ts.db.DeleteWhere(ctx, []db.Where{{Key: "code", Value: code}}, &Token{}) } // RemoveByAccess deletes a token from the DB based on the Access field -func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { - return pts.db.DeleteWhere([]db.Where{{Key: "access", Value: access}}, &Token{}) +func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { + return ts.db.DeleteWhere(ctx, []db.Where{{Key: "access", Value: access}}, &Token{}) } // RemoveByRefresh deletes a token from the DB based on the Refresh field -func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { - return pts.db.DeleteWhere([]db.Where{{Key: "refresh", Value: refresh}}, &Token{}) +func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { + return ts.db.DeleteWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, &Token{}) } // GetByCode selects a token from the DB based on the Code field -func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { +func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { if code == "" { return nil, nil } - pgt := &Token{ + dbt := &Token{ Code: code, } - if err := pts.db.GetWhere([]db.Where{{Key: "code", Value: code}}, pgt); err != nil { + if err := ts.db.GetWhere(ctx, []db.Where{{Key: "code", Value: code}}, dbt); err != nil { return nil, err } - return TokenToOauthToken(pgt), nil + return DBTokenToToken(dbt), nil } // GetByAccess selects a token from the DB based on the Access field -func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { +func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { if access == "" { return nil, nil } - pgt := &Token{ + dbt := &Token{ Access: access, } - if err := pts.db.GetWhere([]db.Where{{Key: "access", Value: access}}, pgt); err != nil { + if err := ts.db.GetWhere(ctx, []db.Where{{Key: "access", Value: access}}, dbt); err != nil { return nil, err } - return TokenToOauthToken(pgt), nil + return DBTokenToToken(dbt), nil } // GetByRefresh selects a token from the DB based on the Refresh field -func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { +func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { if refresh == "" { return nil, nil } - pgt := &Token{ + dbt := &Token{ Refresh: refresh, } - if err := pts.db.GetWhere([]db.Where{{Key: "refresh", Value: refresh}}, pgt); err != nil { + if err := ts.db.GetWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, dbt); err != nil { return nil, err } - return TokenToOauthToken(pgt), nil + return DBTokenToToken(dbt), nil } /* - The following models are basically helpers for the postgres token store implementation, they should only be used internally. + The following models are basically helpers for the token store implementation, they should only be used internally. */ // Token is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt. // // Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined, -// and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and +// and tokens with expired TTLs are automatically removed. Since some databases don't have that feature, it's easier to set an expiry time and // then periodically sweep out tokens when that time has passed. // // Note that this struct does *not* satisfy the token interface shown here: https://github.com/superseriousbusiness/oauth2/blob/master/model.go#L22 @@ -187,26 +187,26 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2 // As such, manual translation is always required between Token and the gotosocial *model.Token. The helper functions oauthTokenToPGToken // and pgTokenToOauthToken can be used for that. type Token struct { - ID string `pg:"type:CHAR(26),pk,notnull"` + ID string `bun:"type:CHAR(26),pk,notnull"` ClientID string UserID string RedirectURI string Scope string - Code string `pg:"default:'',pk"` + Code string `bun:"default:'',pk"` CodeChallenge string CodeChallengeMethod string - CodeCreateAt time.Time `pg:"type:timestamp"` - CodeExpiresAt time.Time `pg:"type:timestamp"` - Access string `pg:"default:'',pk"` - AccessCreateAt time.Time `pg:"type:timestamp"` - AccessExpiresAt time.Time `pg:"type:timestamp"` - Refresh string `pg:"default:'',pk"` - RefreshCreateAt time.Time `pg:"type:timestamp"` - RefreshExpiresAt time.Time `pg:"type:timestamp"` + CodeCreateAt time.Time `bun:",nullzero"` + CodeExpiresAt time.Time `bun:",nullzero"` + Access string `bun:"default:'',pk"` + AccessCreateAt time.Time `bun:",nullzero"` + AccessExpiresAt time.Time `bun:",nullzero"` + Refresh string `bun:"default:'',pk"` + RefreshCreateAt time.Time `bun:",nullzero"` + RefreshExpiresAt time.Time `bun:",nullzero"` } -// TokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres -func TokenToPGToken(tkn *models.Token) *Token { +// TokenToDBToken is a lil util function that takes a gotosocial token and gives back a token for inserting into a database. +func TokenToDBToken(tkn *models.Token) *Token { now := time.Now() // For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's @@ -247,40 +247,40 @@ func TokenToPGToken(tkn *models.Token) *Token { } } -// TokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token -func TokenToOauthToken(pgt *Token) *models.Token { +// DBTokenToToken is a lil util function that takes a database token and gives back a gotosocial token +func DBTokenToToken(dbt *Token) *models.Token { now := time.Now() var codeExpiresIn time.Duration - if !pgt.CodeExpiresAt.IsZero() { - codeExpiresIn = pgt.CodeExpiresAt.Sub(now) + if !dbt.CodeExpiresAt.IsZero() { + codeExpiresIn = dbt.CodeExpiresAt.Sub(now) } var accessExpiresIn time.Duration - if !pgt.AccessExpiresAt.IsZero() { - accessExpiresIn = pgt.AccessExpiresAt.Sub(now) + if !dbt.AccessExpiresAt.IsZero() { + accessExpiresIn = dbt.AccessExpiresAt.Sub(now) } var refreshExpiresIn time.Duration - if !pgt.RefreshExpiresAt.IsZero() { - refreshExpiresIn = pgt.RefreshExpiresAt.Sub(now) + if !dbt.RefreshExpiresAt.IsZero() { + refreshExpiresIn = dbt.RefreshExpiresAt.Sub(now) } return &models.Token{ - ClientID: pgt.ClientID, - UserID: pgt.UserID, - RedirectURI: pgt.RedirectURI, - Scope: pgt.Scope, - Code: pgt.Code, - CodeChallenge: pgt.CodeChallenge, - CodeChallengeMethod: pgt.CodeChallengeMethod, - CodeCreateAt: pgt.CodeCreateAt, + ClientID: dbt.ClientID, + UserID: dbt.UserID, + RedirectURI: dbt.RedirectURI, + Scope: dbt.Scope, + Code: dbt.Code, + CodeChallenge: dbt.CodeChallenge, + CodeChallengeMethod: dbt.CodeChallengeMethod, + CodeCreateAt: dbt.CodeCreateAt, CodeExpiresIn: codeExpiresIn, - Access: pgt.Access, - AccessCreateAt: pgt.AccessCreateAt, + Access: dbt.Access, + AccessCreateAt: dbt.AccessCreateAt, AccessExpiresIn: accessExpiresIn, - Refresh: pgt.Refresh, - RefreshCreateAt: pgt.RefreshCreateAt, + Refresh: dbt.Refresh, + RefreshCreateAt: dbt.RefreshCreateAt, RefreshExpiresIn: refreshExpiresIn, } } diff --git a/internal/processing/account.go b/internal/processing/account.go index f722c88eb..94ba596ac 100644 --- a/internal/processing/account.go +++ b/internal/processing/account.go @@ -19,51 +19,53 @@ package processing import ( + "context" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -func (p *processor) AccountCreate(authed *oauth.Auth, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) { - return p.accountProcessor.Create(authed.Token, authed.Application, form) +func (p *processor) AccountCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) { + return p.accountProcessor.Create(ctx, authed.Token, authed.Application, form) } -func (p *processor) AccountGet(authed *oauth.Auth, targetAccountID string) (*apimodel.Account, error) { - return p.accountProcessor.Get(authed.Account, targetAccountID) +func (p *processor) AccountGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Account, error) { + return p.accountProcessor.Get(ctx, authed.Account, targetAccountID) } -func (p *processor) AccountUpdate(authed *oauth.Auth, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) { - return p.accountProcessor.Update(authed.Account, form) +func (p *processor) AccountUpdate(ctx context.Context, authed *oauth.Auth, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) { + return p.accountProcessor.Update(ctx, authed.Account, form) } -func (p *processor) AccountStatusesGet(authed *oauth.Auth, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) { - return p.accountProcessor.StatusesGet(authed.Account, targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly) +func (p *processor) AccountStatusesGet(ctx context.Context, authed *oauth.Auth, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) { + return p.accountProcessor.StatusesGet(ctx, authed.Account, targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly) } -func (p *processor) AccountFollowersGet(authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { - return p.accountProcessor.FollowersGet(authed.Account, targetAccountID) +func (p *processor) AccountFollowersGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { + return p.accountProcessor.FollowersGet(ctx, authed.Account, targetAccountID) } -func (p *processor) AccountFollowingGet(authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { - return p.accountProcessor.FollowingGet(authed.Account, targetAccountID) +func (p *processor) AccountFollowingGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { + return p.accountProcessor.FollowingGet(ctx, authed.Account, targetAccountID) } -func (p *processor) AccountRelationshipGet(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { - return p.accountProcessor.RelationshipGet(authed.Account, targetAccountID) +func (p *processor) AccountRelationshipGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { + return p.accountProcessor.RelationshipGet(ctx, authed.Account, targetAccountID) } -func (p *processor) AccountFollowCreate(authed *oauth.Auth, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) { - return p.accountProcessor.FollowCreate(authed.Account, form) +func (p *processor) AccountFollowCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) { + return p.accountProcessor.FollowCreate(ctx, authed.Account, form) } -func (p *processor) AccountFollowRemove(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { - return p.accountProcessor.FollowRemove(authed.Account, targetAccountID) +func (p *processor) AccountFollowRemove(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { + return p.accountProcessor.FollowRemove(ctx, authed.Account, targetAccountID) } -func (p *processor) AccountBlockCreate(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { - return p.accountProcessor.BlockCreate(authed.Account, targetAccountID) +func (p *processor) AccountBlockCreate(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { + return p.accountProcessor.BlockCreate(ctx, authed.Account, targetAccountID) } -func (p *processor) AccountBlockRemove(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { - return p.accountProcessor.BlockRemove(authed.Account, targetAccountID) +func (p *processor) AccountBlockRemove(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { + return p.accountProcessor.BlockRemove(ctx, authed.Account, targetAccountID) } diff --git a/internal/processing/account/account.go b/internal/processing/account/account.go index 7b8910149..81701fd7c 100644 --- a/internal/processing/account/account.go +++ b/internal/processing/account/account.go @@ -19,6 +19,7 @@ package account import ( + "context" "mime/multipart" "github.com/sirupsen/logrus" @@ -38,40 +39,40 @@ import ( // Processor wraps a bunch of functions for processing account actions. type Processor interface { // Create processes the given form for creating a new account, returning an oauth token for that account if successful. - Create(applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) + Create(ctx context.Context, applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) // Delete deletes an account, and all of that account's statuses, media, follows, notifications, etc etc etc. // The origin passed here should be either the ID of the account doing the delete (can be itself), or the ID of a domain block. - Delete(account *gtsmodel.Account, origin string) error + Delete(ctx context.Context, account *gtsmodel.Account, origin string) error // Get processes the given request for account information. - Get(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) + Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) // Update processes the update of an account with the given form - Update(account *gtsmodel.Account, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) + Update(ctx context.Context, account *gtsmodel.Account, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) // StatusesGet fetches a number of statuses (in time descending order) from the given account, filtered by visibility for // the account given in authed. - StatusesGet(requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) + StatusesGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) // FollowersGet fetches a list of the target account's followers. - FollowersGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) + FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) // FollowingGet fetches a list of the accounts that target account is following. - FollowingGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) + FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) // RelationshipGet returns a relationship model describing the relationship of the targetAccount to the Authed account. - RelationshipGet(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) + RelationshipGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) // FollowCreate handles a follow request to an account, either remote or local. - FollowCreate(requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) + FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) // FollowRemove handles the removal of a follow/follow request to an account, either remote or local. - FollowRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) + FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) // BlockCreate handles the creation of a block from requestingAccount to targetAccountID, either remote or local. - BlockCreate(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) + BlockCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) // BlockRemove handles the removal of a block from requestingAccount to targetAccountID, either remote or local. - BlockRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) + BlockRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) // UpdateHeader does the dirty work of checking the header part of an account update form, // parsing and checking the image, and doing the necessary updates in the database for this to become // the account's new header image. - UpdateAvatar(avatar *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) + UpdateAvatar(ctx context.Context, avatar *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) // UpdateAvatar does the dirty work of checking the avatar part of an account update form, // parsing and checking the image, and doing the necessary updates in the database for this to become // the account's new avatar image. - UpdateHeader(header *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) + UpdateHeader(ctx context.Context, header *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) } type processor struct { diff --git a/internal/processing/account/create.go b/internal/processing/account/create.go index 83e76973d..37c742b45 100644 --- a/internal/processing/account/create.go +++ b/internal/processing/account/create.go @@ -19,6 +19,7 @@ package account import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -27,16 +28,24 @@ import ( "github.com/superseriousbusiness/oauth2/v4" ) -func (p *processor) Create(applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) { +func (p *processor) Create(ctx context.Context, applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) { l := p.log.WithField("func", "accountCreate") - if err := p.db.IsEmailAvailable(form.Email); err != nil { + emailAvailable, err := p.db.IsEmailAvailable(ctx, form.Email) + if err != nil { return nil, err } + if !emailAvailable { + return nil, fmt.Errorf("email address %s in use", form.Email) + } - if err := p.db.IsUsernameAvailable(form.Username); err != nil { + usernameAvailable, err := p.db.IsUsernameAvailable(ctx, form.Username) + if err != nil { return nil, err } + if !usernameAvailable { + return nil, fmt.Errorf("username %s in use", form.Username) + } // don't store a reason if we don't require one reason := form.Reason @@ -45,7 +54,7 @@ func (p *processor) Create(applicationToken oauth2.TokenInfo, application *gtsmo } l.Trace("creating new username and account") - user, err := p.db.NewSignup(form.Username, text.RemoveHTML(reason), p.config.AccountsConfig.RequireApproval, form.Email, form.Password, form.IP, form.Locale, application.ID, false, false) + user, err := p.db.NewSignup(ctx, form.Username, text.RemoveHTML(reason), p.config.AccountsConfig.RequireApproval, form.Email, form.Password, form.IP, form.Locale, application.ID, false, false) if err != nil { return nil, fmt.Errorf("error creating new signup in the database: %s", err) } diff --git a/internal/processing/account/createblock.go b/internal/processing/account/createblock.go index f10a2efa3..06f82b37d 100644 --- a/internal/processing/account/createblock.go +++ b/internal/processing/account/createblock.go @@ -19,6 +19,7 @@ package account import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -29,18 +30,18 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/util" ) -func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { +func (p *processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { // make sure the target account actually exists in our db - targetAccount, err := p.db.GetAccountByID(targetAccountID) + targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err)) } // if requestingAccount already blocks target account, we don't need to do anything - if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, false); err != nil { + if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, false); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error checking existence of block: %s", err)) } else if blocked { - return p.RelationshipGet(requestingAccount, targetAccountID) + return p.RelationshipGet(ctx, requestingAccount, targetAccountID) } // make the block @@ -57,18 +58,18 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou block.URI = util.GenerateURIForBlock(requestingAccount.Username, p.config.Protocol, p.config.Host, newBlockID) // whack it in the database - if err := p.db.Put(block); err != nil { + if err := p.db.Put(ctx, block); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error creating block in db: %s", err)) } // clear any follows or follow requests from the blocked account to the target account -- this is a simple delete - if err := p.db.DeleteWhere([]db.Where{ + if err := p.db.DeleteWhere(ctx, []db.Where{ {Key: "account_id", Value: targetAccountID}, {Key: "target_account_id", Value: requestingAccount.ID}, }, >smodel.Follow{}); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow in db: %s", err)) } - if err := p.db.DeleteWhere([]db.Where{ + if err := p.db.DeleteWhere(ctx, []db.Where{ {Key: "account_id", Value: targetAccountID}, {Key: "target_account_id", Value: requestingAccount.ID}, }, >smodel.FollowRequest{}); err != nil { @@ -82,12 +83,12 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou var frChanged bool var frURI string fr := >smodel.FollowRequest{} - if err := p.db.GetWhere([]db.Where{ + if err := p.db.GetWhere(ctx, []db.Where{ {Key: "account_id", Value: requestingAccount.ID}, {Key: "target_account_id", Value: targetAccountID}, }, fr); err == nil { frURI = fr.URI - if err := p.db.DeleteByID(fr.ID, fr); err != nil { + if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow request from db: %s", err)) } frChanged = true @@ -97,12 +98,12 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou var fChanged bool var fURI string f := >smodel.Follow{} - if err := p.db.GetWhere([]db.Where{ + if err := p.db.GetWhere(ctx, []db.Where{ {Key: "account_id", Value: requestingAccount.ID}, {Key: "target_account_id", Value: targetAccountID}, }, f); err == nil { fURI = f.URI - if err := p.db.DeleteByID(f.ID, f); err != nil { + if err := p.db.DeleteByID(ctx, f.ID, f); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow from db: %s", err)) } fChanged = true @@ -147,5 +148,5 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou TargetAccount: targetAccount, } - return p.RelationshipGet(requestingAccount, targetAccountID) + return p.RelationshipGet(ctx, requestingAccount, targetAccountID) } diff --git a/internal/processing/account/createfollow.go b/internal/processing/account/createfollow.go index 8c856a50e..a7767afea 100644 --- a/internal/processing/account/createfollow.go +++ b/internal/processing/account/createfollow.go @@ -19,6 +19,7 @@ package account import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -29,16 +30,16 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/util" ) -func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) { +func (p *processor) FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) { // if there's a block between the accounts we shouldn't create the request ofc - if blocked, err := p.db.IsBlocked(requestingAccount.ID, form.ID, true); err != nil { + if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, form.ID, true); err != nil { return nil, gtserror.NewErrorInternalError(err) } else if blocked { return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) } // make sure the target account actually exists in our db - targetAcct, err := p.db.GetAccountByID(form.ID) + targetAcct, err := p.db.GetAccountByID(ctx, form.ID) if err != nil { if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: account %s not found in the db: %s", form.ID, err)) @@ -47,19 +48,19 @@ func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apim } // check if a follow exists already - if follows, err := p.db.IsFollowing(requestingAccount, targetAcct); err != nil { + if follows, err := p.db.IsFollowing(ctx, requestingAccount, targetAcct); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow in db: %s", err)) } else if follows { // already follows so just return the relationship - return p.RelationshipGet(requestingAccount, form.ID) + return p.RelationshipGet(ctx, requestingAccount, form.ID) } // check if a follow request exists already - if followRequested, err := p.db.IsFollowRequested(requestingAccount, targetAcct); err != nil { + if followRequested, err := p.db.IsFollowRequested(ctx, requestingAccount, targetAcct); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow request in db: %s", err)) } else if followRequested { // already follow requested so just return the relationship - return p.RelationshipGet(requestingAccount, form.ID) + return p.RelationshipGet(ctx, requestingAccount, form.ID) } // make the follow request @@ -84,17 +85,17 @@ func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apim } // whack it in the database - if err := p.db.Put(fr); err != nil { + if err := p.db.Put(ctx, fr); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error creating follow request in db: %s", err)) } // if it's a local account that's not locked we can just straight up accept the follow request if !targetAcct.Locked && targetAcct.Domain == "" { - if _, err := p.db.AcceptFollowRequest(requestingAccount.ID, form.ID); err != nil { + if _, err := p.db.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error accepting folow request for local unlocked account: %s", err)) } // return the new relationship - return p.RelationshipGet(requestingAccount, form.ID) + return p.RelationshipGet(ctx, requestingAccount, form.ID) } // otherwise we leave the follow request as it is and we handle the rest of the process asynchronously @@ -107,5 +108,5 @@ func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apim } // return whatever relationship results from this - return p.RelationshipGet(requestingAccount, form.ID) + return p.RelationshipGet(ctx, requestingAccount, form.ID) } diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go index e8840abae..d97af4d2e 100644 --- a/internal/processing/account/delete.go +++ b/internal/processing/account/delete.go @@ -19,6 +19,7 @@ package account import ( + "context" "time" "github.com/sirupsen/logrus" @@ -48,7 +49,7 @@ import ( // 16. Delete account's user // 17. Delete account's timeline // 18. Delete account itself -func (p *processor) Delete(account *gtsmodel.Account, origin string) error { +func (p *processor) Delete(ctx context.Context, account *gtsmodel.Account, origin string) error { l := p.log.WithFields(logrus.Fields{ "func": "Delete", "username": account.Username, @@ -61,22 +62,22 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error { if account.Domain == "" { // see if we can get a user for this account u := >smodel.User{} - if err := p.db.GetWhere([]db.Where{{Key: "account_id", Value: account.ID}}, u); err == nil { + if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, u); err == nil { // we got one! select all tokens with the user's ID tokens := []*oauth.Token{} - if err := p.db.GetWhere([]db.Where{{Key: "user_id", Value: u.ID}}, &tokens); err == nil { + if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: u.ID}}, &tokens); err == nil { // we have some tokens to delete for _, t := range tokens { // delete client(s) associated with this token - if err := p.db.DeleteByID(t.ClientID, &oauth.Client{}); err != nil { + if err := p.db.DeleteByID(ctx, t.ClientID, &oauth.Client{}); err != nil { l.Errorf("error deleting oauth client: %s", err) } // delete application(s) associated with this token - if err := p.db.DeleteWhere([]db.Where{{Key: "client_id", Value: t.ClientID}}, >smodel.Application{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, >smodel.Application{}); err != nil { l.Errorf("error deleting application: %s", err) } // delete the token itself - if err := p.db.DeleteByID(t.ID, t); err != nil { + if err := p.db.DeleteByID(ctx, t.ID, t); err != nil { l.Errorf("error deleting oauth token: %s", err) } } @@ -87,12 +88,12 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error { // 2. Delete account's blocks l.Debug("deleting account blocks") // first delete any blocks that this account created - if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Block{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Block{}); err != nil { l.Errorf("error deleting blocks created by account: %s", err) } // now delete any blocks that target this account - if err := p.db.DeleteWhere([]db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Block{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Block{}); err != nil { l.Errorf("error deleting blocks targeting account: %s", err) } @@ -103,12 +104,12 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error { // TODO: federate these if necessary l.Debug("deleting account follow requests") // first delete any follow requests that this account created - if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil { l.Errorf("error deleting follow requests created by account: %s", err) } // now delete any follow requests that target this account - if err := p.db.DeleteWhere([]db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil { l.Errorf("error deleting follow requests targeting account: %s", err) } @@ -116,12 +117,12 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error { // TODO: federate these if necessary l.Debug("deleting account follows") // first delete any follows that this account created - if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil { l.Errorf("error deleting follows created by account: %s", err) } // now delete any follows that target this account - if err := p.db.DeleteWhere([]db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil { l.Errorf("error deleting follows targeting account: %s", err) } @@ -133,7 +134,7 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error { var maxID string selectStatusesLoop: for { - statuses, err := p.db.GetAccountStatuses(account.ID, 20, false, maxID, false, false) + statuses, err := p.db.GetAccountStatuses(ctx, account.ID, 20, false, maxID, false, false) if err != nil { if err == db.ErrNoEntries { // no statuses left for this instance so we're done @@ -157,7 +158,7 @@ selectStatusesLoop: TargetAccount: account, } - if err := p.db.DeleteByID(s.ID, s); err != nil { + if err := p.db.DeleteByID(ctx, s.ID, s); err != nil { if err != db.ErrNoEntries { // actual error has occurred l.Errorf("Delete: db error status %s for account %s: %s", s.ID, account.Username, err) @@ -167,7 +168,7 @@ selectStatusesLoop: // if there are any boosts of this status, delete them as well boosts := []*gtsmodel.Status{} - if err := p.db.GetWhere([]db.Where{{Key: "boost_of_id", Value: s.ID}}, &boosts); err != nil { + if err := p.db.GetWhere(ctx, []db.Where{{Key: "boost_of_id", Value: s.ID}}, &boosts); err != nil { if err != db.ErrNoEntries { // an actual error has occurred l.Errorf("Delete: db error selecting boosts of status %s for account %s: %s", s.ID, account.Username, err) @@ -176,20 +177,24 @@ selectStatusesLoop: } for _, b := range boosts { - oa := >smodel.Account{} - if err := p.db.GetByID(b.AccountID, oa); err == nil { - - l.Debug("putting boost undo in the client api channel") - p.fromClientAPI <- gtsmodel.FromClientAPI{ - APObjectType: gtsmodel.ActivityStreamsAnnounce, - APActivityType: gtsmodel.ActivityStreamsUndo, - GTSModel: s, - OriginAccount: oa, - TargetAccount: account, + if b.Account == nil { + bAccount, err := p.db.GetAccountByID(ctx, b.AccountID) + if err != nil { + continue } + b.Account = bAccount } - if err := p.db.DeleteByID(b.ID, b); err != nil { + l.Debug("putting boost undo in the client api channel") + p.fromClientAPI <- gtsmodel.FromClientAPI{ + APObjectType: gtsmodel.ActivityStreamsAnnounce, + APActivityType: gtsmodel.ActivityStreamsUndo, + GTSModel: s, + OriginAccount: b.Account, + TargetAccount: account, + } + + if err := p.db.DeleteByID(ctx, b.ID, b); err != nil { if err != db.ErrNoEntries { // actual error has occurred l.Errorf("Delete: db error deleting boost with id %s: %s", b.ID, err) @@ -208,26 +213,26 @@ selectStatusesLoop: // 10. Delete account's notifications l.Debug("deleting account notifications") - if err := p.db.DeleteWhere([]db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil { l.Errorf("error deleting notifications created by account: %s", err) } // 11. Delete account's bookmarks l.Debug("deleting account bookmarks") - if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil { l.Errorf("error deleting bookmarks created by account: %s", err) } // 12. Delete account's faves // TODO: federate these if necessary l.Debug("deleting account faves") - if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil { l.Errorf("error deleting faves created by account: %s", err) } // 13. Delete account's mutes l.Debug("deleting account mutes") - if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil { l.Errorf("error deleting status mutes created by account: %s", err) } @@ -239,7 +244,7 @@ selectStatusesLoop: // 16. Delete account's user l.Debug("deleting account user") - if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, >smodel.User{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, >smodel.User{}); err != nil { return err } @@ -266,7 +271,8 @@ selectStatusesLoop: account.SuspendedAt = time.Now() account.SuspensionOrigin = origin - if err := p.db.UpdateByID(account.ID, account); err != nil { + account, err := p.db.UpdateAccount(ctx, account) + if err != nil { return err } diff --git a/internal/processing/account/get.go b/internal/processing/account/get.go index 3dfc54b51..5f039127c 100644 --- a/internal/processing/account/get.go +++ b/internal/processing/account/get.go @@ -19,6 +19,7 @@ package account import ( + "context" "errors" "fmt" @@ -27,9 +28,9 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) { - targetAccount := >smodel.Account{} - if err := p.db.GetByID(targetAccountID, targetAccount); err != nil { +func (p *processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) { + targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID) + if err != nil { if err == db.ErrNoEntries { return nil, errors.New("account not found") } @@ -37,9 +38,8 @@ func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID str } var blocked bool - var err error if requestingAccount != nil { - blocked, err = p.db.IsBlocked(requestingAccount.ID, targetAccountID, true) + blocked, err = p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true) if err != nil { return nil, fmt.Errorf("error checking account block: %s", err) } @@ -47,7 +47,7 @@ func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID str var mastoAccount *apimodel.Account if blocked { - mastoAccount, err = p.tc.AccountToMastoBlocked(targetAccount) + mastoAccount, err = p.tc.AccountToMastoBlocked(ctx, targetAccount) if err != nil { return nil, fmt.Errorf("error converting account: %s", err) } @@ -56,16 +56,16 @@ func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID str // last-minute check to make sure we have remote account header/avi cached if targetAccount.Domain != "" { - a, err := p.federator.EnrichRemoteAccount(requestingAccount.Username, targetAccount) + a, err := p.federator.EnrichRemoteAccount(ctx, requestingAccount.Username, targetAccount) if err == nil { targetAccount = a } } if requestingAccount != nil && targetAccount.ID == requestingAccount.ID { - mastoAccount, err = p.tc.AccountToMastoSensitive(targetAccount) + mastoAccount, err = p.tc.AccountToMastoSensitive(ctx, targetAccount) } else { - mastoAccount, err = p.tc.AccountToMastoPublic(targetAccount) + mastoAccount, err = p.tc.AccountToMastoPublic(ctx, targetAccount) } if err != nil { return nil, fmt.Errorf("error converting account: %s", err) diff --git a/internal/processing/account/getfollowers.go b/internal/processing/account/getfollowers.go index 4f66b40ee..517467085 100644 --- a/internal/processing/account/getfollowers.go +++ b/internal/processing/account/getfollowers.go @@ -19,6 +19,7 @@ package account import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -27,15 +28,15 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { - if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil { +func (p *processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { + if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { return nil, gtserror.NewErrorInternalError(err) } else if blocked { return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) } accounts := []apimodel.Account{} - follows, err := p.db.GetAccountFollowedBy(targetAccountID, false) + follows, err := p.db.GetAccountFollowedBy(ctx, targetAccountID, false) if err != nil { if err == db.ErrNoEntries { return accounts, nil @@ -44,7 +45,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco } for _, f := range follows { - blocked, err := p.db.IsBlocked(requestingAccount.ID, f.AccountID, true) + blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -53,7 +54,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco } if f.Account == nil { - a, err := p.db.GetAccountByID(f.AccountID) + a, err := p.db.GetAccountByID(ctx, f.AccountID) if err != nil { if err == db.ErrNoEntries { continue @@ -63,7 +64,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco f.Account = a } - account, err := p.tc.AccountToMastoPublic(f.Account) + account, err := p.tc.AccountToMastoPublic(ctx, f.Account) if err != nil { return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/account/getfollowing.go b/internal/processing/account/getfollowing.go index c7fb426f9..543213f90 100644 --- a/internal/processing/account/getfollowing.go +++ b/internal/processing/account/getfollowing.go @@ -19,6 +19,7 @@ package account import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -27,15 +28,15 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { - if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil { +func (p *processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { + if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { return nil, gtserror.NewErrorInternalError(err) } else if blocked { return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) } accounts := []apimodel.Account{} - follows, err := p.db.GetAccountFollows(targetAccountID) + follows, err := p.db.GetAccountFollows(ctx, targetAccountID) if err != nil { if err == db.ErrNoEntries { return accounts, nil @@ -44,7 +45,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco } for _, f := range follows { - blocked, err := p.db.IsBlocked(requestingAccount.ID, f.AccountID, true) + blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -53,7 +54,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco } if f.TargetAccount == nil { - a, err := p.db.GetAccountByID(f.TargetAccountID) + a, err := p.db.GetAccountByID(ctx, f.TargetAccountID) if err != nil { if err == db.ErrNoEntries { continue @@ -63,7 +64,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco f.TargetAccount = a } - account, err := p.tc.AccountToMastoPublic(f.TargetAccount) + account, err := p.tc.AccountToMastoPublic(ctx, f.TargetAccount) if err != nil { return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/account/getrelationship.go b/internal/processing/account/getrelationship.go index a0a93a4c2..ebfd9b479 100644 --- a/internal/processing/account/getrelationship.go +++ b/internal/processing/account/getrelationship.go @@ -19,6 +19,7 @@ package account import ( + "context" "errors" "fmt" @@ -27,17 +28,17 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) RelationshipGet(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { +func (p *processor) RelationshipGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { if requestingAccount == nil { return nil, gtserror.NewErrorForbidden(errors.New("not authed")) } - gtsR, err := p.db.GetRelationship(requestingAccount.ID, targetAccountID) + gtsR, err := p.db.GetRelationship(ctx, requestingAccount.ID, targetAccountID) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting relationship: %s", err)) } - r, err := p.tc.RelationshipToMasto(gtsR) + r, err := p.tc.RelationshipToMasto(ctx, gtsR) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting relationship: %s", err)) } diff --git a/internal/processing/account/getstatuses.go b/internal/processing/account/getstatuses.go index dc21e7006..dc157e43c 100644 --- a/internal/processing/account/getstatuses.go +++ b/internal/processing/account/getstatuses.go @@ -19,6 +19,7 @@ package account import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -27,8 +28,8 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) { - if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil { +func (p *processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) { + if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { return nil, gtserror.NewErrorInternalError(err) } else if blocked { return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) @@ -36,7 +37,7 @@ func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccou apiStatuses := []apimodel.Status{} - statuses, err := p.db.GetAccountStatuses(targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly) + statuses, err := p.db.GetAccountStatuses(ctx, targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly) if err != nil { if err == db.ErrNoEntries { return apiStatuses, nil @@ -45,12 +46,12 @@ func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccou } for _, s := range statuses { - visible, err := p.filter.StatusVisible(s, requestingAccount) + visible, err := p.filter.StatusVisible(ctx, s, requestingAccount) if err != nil || !visible { continue } - apiStatus, err := p.tc.StatusToMasto(s, requestingAccount) + apiStatus, err := p.tc.StatusToMasto(ctx, s, requestingAccount) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status to masto: %s", err)) } diff --git a/internal/processing/account/removeblock.go b/internal/processing/account/removeblock.go index 7c1f2bc17..7e3d78076 100644 --- a/internal/processing/account/removeblock.go +++ b/internal/processing/account/removeblock.go @@ -19,6 +19,7 @@ package account import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -27,9 +28,9 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { +func (p *processor) BlockRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { // make sure the target account actually exists in our db - targetAccount, err := p.db.GetAccountByID(targetAccountID) + targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err)) } @@ -37,13 +38,13 @@ func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccou // check if a block exists, and remove it if it does (storing the URI for later) var blockChanged bool block := >smodel.Block{} - if err := p.db.GetWhere([]db.Where{ + if err := p.db.GetWhere(ctx, []db.Where{ {Key: "account_id", Value: requestingAccount.ID}, {Key: "target_account_id", Value: targetAccountID}, }, block); err == nil { block.Account = requestingAccount block.TargetAccount = targetAccount - if err := p.db.DeleteByID(block.ID, >smodel.Block{}); err != nil { + if err := p.db.DeleteByID(ctx, block.ID, >smodel.Block{}); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockRemove: error removing block from db: %s", err)) } blockChanged = true @@ -61,5 +62,5 @@ func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccou } // return whatever relationship results from all this - return p.RelationshipGet(requestingAccount, targetAccountID) + return p.RelationshipGet(ctx, requestingAccount, targetAccountID) } diff --git a/internal/processing/account/removefollow.go b/internal/processing/account/removefollow.go index 6646d694e..6186c550f 100644 --- a/internal/processing/account/removefollow.go +++ b/internal/processing/account/removefollow.go @@ -19,6 +19,7 @@ package account import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -27,9 +28,9 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { +func (p *processor) FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { // if there's a block between the accounts we shouldn't do anything - blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true) + blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -38,8 +39,8 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco } // make sure the target account actually exists in our db - targetAcct := >smodel.Account{} - if err := p.db.GetByID(targetAccountID, targetAcct); err != nil { + targetAcct, err := p.db.GetAccountByID(ctx, targetAccountID) + if err != nil { if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(fmt.Errorf("AccountFollowRemove: account %s not found in the db: %s", targetAccountID, err)) } @@ -49,12 +50,12 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco var frChanged bool var frURI string fr := >smodel.FollowRequest{} - if err := p.db.GetWhere([]db.Where{ + if err := p.db.GetWhere(ctx, []db.Where{ {Key: "account_id", Value: requestingAccount.ID}, {Key: "target_account_id", Value: targetAccountID}, }, fr); err == nil { frURI = fr.URI - if err := p.db.DeleteByID(fr.ID, fr); err != nil { + if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow request from db: %s", err)) } frChanged = true @@ -64,12 +65,12 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco var fChanged bool var fURI string f := >smodel.Follow{} - if err := p.db.GetWhere([]db.Where{ + if err := p.db.GetWhere(ctx, []db.Where{ {Key: "account_id", Value: requestingAccount.ID}, {Key: "target_account_id", Value: targetAccountID}, }, f); err == nil { fURI = f.URI - if err := p.db.DeleteByID(f.ID, f); err != nil { + if err := p.db.DeleteByID(ctx, f.ID, f); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow from db: %s", err)) } fChanged = true @@ -106,5 +107,5 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco } // return whatever relationship results from all this - return p.RelationshipGet(requestingAccount, targetAccountID) + return p.RelationshipGet(ctx, requestingAccount, targetAccountID) } diff --git a/internal/processing/account/update.go b/internal/processing/account/update.go index df842bacd..99ccbf5a0 100644 --- a/internal/processing/account/update.go +++ b/internal/processing/account/update.go @@ -20,6 +20,7 @@ package account import ( "bytes" + "context" "errors" "fmt" "io" @@ -32,17 +33,17 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/util" ) -func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) { +func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) { l := p.log.WithField("func", "AccountUpdate") if form.Discoverable != nil { - if err := p.db.UpdateOneByID(account.ID, "discoverable", *form.Discoverable, >smodel.Account{}); err != nil { + if err := p.db.UpdateOneByID(ctx, account.ID, "discoverable", *form.Discoverable, >smodel.Account{}); err != nil { return nil, fmt.Errorf("error updating discoverable: %s", err) } } if form.Bot != nil { - if err := p.db.UpdateOneByID(account.ID, "bot", *form.Bot, >smodel.Account{}); err != nil { + if err := p.db.UpdateOneByID(ctx, account.ID, "bot", *form.Bot, >smodel.Account{}); err != nil { return nil, fmt.Errorf("error updating bot: %s", err) } } @@ -52,7 +53,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede return nil, err } displayName := text.RemoveHTML(*form.DisplayName) // no html allowed in display name - if err := p.db.UpdateOneByID(account.ID, "display_name", displayName, >smodel.Account{}); err != nil { + if err := p.db.UpdateOneByID(ctx, account.ID, "display_name", displayName, >smodel.Account{}); err != nil { return nil, err } } @@ -62,13 +63,13 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede return nil, err } note := text.SanitizeHTML(*form.Note) // html OK in note but sanitize it - if err := p.db.UpdateOneByID(account.ID, "note", note, >smodel.Account{}); err != nil { + if err := p.db.UpdateOneByID(ctx, account.ID, "note", note, >smodel.Account{}); err != nil { return nil, err } } if form.Avatar != nil && form.Avatar.Size != 0 { - avatarInfo, err := p.UpdateAvatar(form.Avatar, account.ID) + avatarInfo, err := p.UpdateAvatar(ctx, form.Avatar, account.ID) if err != nil { return nil, err } @@ -76,7 +77,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede } if form.Header != nil && form.Header.Size != 0 { - headerInfo, err := p.UpdateHeader(form.Header, account.ID) + headerInfo, err := p.UpdateHeader(ctx, form.Header, account.ID) if err != nil { return nil, err } @@ -84,7 +85,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede } if form.Locked != nil { - if err := p.db.UpdateOneByID(account.ID, "locked", *form.Locked, >smodel.Account{}); err != nil { + if err := p.db.UpdateOneByID(ctx, account.ID, "locked", *form.Locked, >smodel.Account{}); err != nil { return nil, err } } @@ -94,13 +95,13 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede if err := util.ValidateLanguage(*form.Source.Language); err != nil { return nil, err } - if err := p.db.UpdateOneByID(account.ID, "language", *form.Source.Language, >smodel.Account{}); err != nil { + if err := p.db.UpdateOneByID(ctx, account.ID, "language", *form.Source.Language, >smodel.Account{}); err != nil { return nil, err } } if form.Source.Sensitive != nil { - if err := p.db.UpdateOneByID(account.ID, "locked", *form.Locked, >smodel.Account{}); err != nil { + if err := p.db.UpdateOneByID(ctx, account.ID, "locked", *form.Locked, >smodel.Account{}); err != nil { return nil, err } } @@ -109,15 +110,15 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede if err := util.ValidatePrivacy(*form.Source.Privacy); err != nil { return nil, err } - if err := p.db.UpdateOneByID(account.ID, "privacy", *form.Source.Privacy, >smodel.Account{}); err != nil { + if err := p.db.UpdateOneByID(ctx, account.ID, "privacy", *form.Source.Privacy, >smodel.Account{}); err != nil { return nil, err } } } // fetch the account with all updated values set - updatedAccount := >smodel.Account{} - if err := p.db.GetByID(account.ID, updatedAccount); err != nil { + updatedAccount, err := p.db.GetAccountByID(ctx, account.ID) + if err != nil { return nil, fmt.Errorf("could not fetch updated account %s: %s", account.ID, err) } @@ -128,7 +129,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede OriginAccount: updatedAccount, } - acctSensitive, err := p.tc.AccountToMastoSensitive(updatedAccount) + acctSensitive, err := p.tc.AccountToMastoSensitive(ctx, updatedAccount) if err != nil { return nil, fmt.Errorf("could not convert account into mastosensitive account: %s", err) } @@ -138,7 +139,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede // UpdateAvatar does the dirty work of checking the avatar part of an account update form, // parsing and checking the image, and doing the necessary updates in the database for this to become // the account's new avatar image. -func (p *processor) UpdateAvatar(avatar *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) { +func (p *processor) UpdateAvatar(ctx context.Context, avatar *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) { var err error if int(avatar.Size) > p.config.MediaConfig.MaxImageSize { err = fmt.Errorf("avatar with size %d exceeded max image size of %d bytes", avatar.Size, p.config.MediaConfig.MaxImageSize) @@ -160,7 +161,7 @@ func (p *processor) UpdateAvatar(avatar *multipart.FileHeader, accountID string) } // do the setting - avatarInfo, err := p.mediaHandler.ProcessHeaderOrAvatar(buf.Bytes(), accountID, media.Avatar, "") + avatarInfo, err := p.mediaHandler.ProcessHeaderOrAvatar(ctx, buf.Bytes(), accountID, media.Avatar, "") if err != nil { return nil, fmt.Errorf("error processing avatar: %s", err) } @@ -171,7 +172,7 @@ func (p *processor) UpdateAvatar(avatar *multipart.FileHeader, accountID string) // UpdateHeader does the dirty work of checking the header part of an account update form, // parsing and checking the image, and doing the necessary updates in the database for this to become // the account's new header image. -func (p *processor) UpdateHeader(header *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) { +func (p *processor) UpdateHeader(ctx context.Context, header *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) { var err error if int(header.Size) > p.config.MediaConfig.MaxImageSize { err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", header.Size, p.config.MediaConfig.MaxImageSize) @@ -193,7 +194,7 @@ func (p *processor) UpdateHeader(header *multipart.FileHeader, accountID string) } // do the setting - headerInfo, err := p.mediaHandler.ProcessHeaderOrAvatar(buf.Bytes(), accountID, media.Header, "") + headerInfo, err := p.mediaHandler.ProcessHeaderOrAvatar(ctx, buf.Bytes(), accountID, media.Header, "") if err != nil { return nil, fmt.Errorf("error processing header: %s", err) } diff --git a/internal/processing/admin.go b/internal/processing/admin.go index 9a38f5ec1..48faee986 100644 --- a/internal/processing/admin.go +++ b/internal/processing/admin.go @@ -19,31 +19,33 @@ package processing import ( + "context" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -func (p *processor) AdminEmojiCreate(authed *oauth.Auth, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) { - return p.adminProcessor.EmojiCreate(authed.Account, authed.User, form) +func (p *processor) AdminEmojiCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) { + return p.adminProcessor.EmojiCreate(ctx, authed.Account, authed.User, form) } -func (p *processor) AdminDomainBlockCreate(authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) (*apimodel.DomainBlock, gtserror.WithCode) { - return p.adminProcessor.DomainBlockCreate(authed.Account, form.Domain, form.Obfuscate, form.PublicComment, form.PrivateComment, "") +func (p *processor) AdminDomainBlockCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) (*apimodel.DomainBlock, gtserror.WithCode) { + return p.adminProcessor.DomainBlockCreate(ctx, authed.Account, form.Domain, form.Obfuscate, form.PublicComment, form.PrivateComment, "") } -func (p *processor) AdminDomainBlocksImport(authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) ([]*apimodel.DomainBlock, gtserror.WithCode) { - return p.adminProcessor.DomainBlocksImport(authed.Account, form.Domains) +func (p *processor) AdminDomainBlocksImport(ctx context.Context, authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) ([]*apimodel.DomainBlock, gtserror.WithCode) { + return p.adminProcessor.DomainBlocksImport(ctx, authed.Account, form.Domains) } -func (p *processor) AdminDomainBlocksGet(authed *oauth.Auth, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) { - return p.adminProcessor.DomainBlocksGet(authed.Account, export) +func (p *processor) AdminDomainBlocksGet(ctx context.Context, authed *oauth.Auth, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) { + return p.adminProcessor.DomainBlocksGet(ctx, authed.Account, export) } -func (p *processor) AdminDomainBlockGet(authed *oauth.Auth, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) { - return p.adminProcessor.DomainBlockGet(authed.Account, id, export) +func (p *processor) AdminDomainBlockGet(ctx context.Context, authed *oauth.Auth, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) { + return p.adminProcessor.DomainBlockGet(ctx, authed.Account, id, export) } -func (p *processor) AdminDomainBlockDelete(authed *oauth.Auth, id string) (*apimodel.DomainBlock, gtserror.WithCode) { - return p.adminProcessor.DomainBlockDelete(authed.Account, id) +func (p *processor) AdminDomainBlockDelete(ctx context.Context, authed *oauth.Auth, id string) (*apimodel.DomainBlock, gtserror.WithCode) { + return p.adminProcessor.DomainBlockDelete(ctx, authed.Account, id) } diff --git a/internal/processing/admin/admin.go b/internal/processing/admin/admin.go index fd63d8a10..de288811b 100644 --- a/internal/processing/admin/admin.go +++ b/internal/processing/admin/admin.go @@ -19,6 +19,7 @@ package admin import ( + "context" "mime/multipart" "github.com/sirupsen/logrus" @@ -33,12 +34,12 @@ import ( // Processor wraps a bunch of functions for processing admin actions. type Processor interface { - DomainBlockCreate(account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode) - DomainBlocksImport(account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode) - DomainBlocksGet(account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) - DomainBlockGet(account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) - DomainBlockDelete(account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) - EmojiCreate(account *gtsmodel.Account, user *gtsmodel.User, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) + DomainBlockCreate(ctx context.Context, account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode) + DomainBlocksImport(ctx context.Context, account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode) + DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) + DomainBlockGet(ctx context.Context, account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) + DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) + EmojiCreate(ctx context.Context, account *gtsmodel.Account, user *gtsmodel.User, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) } type processor struct { diff --git a/internal/processing/admin/createdomainblock.go b/internal/processing/admin/createdomainblock.go index 624f632dc..a34c03a44 100644 --- a/internal/processing/admin/createdomainblock.go +++ b/internal/processing/admin/createdomainblock.go @@ -19,6 +19,7 @@ package admin import ( + "context" "fmt" "time" @@ -31,10 +32,10 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/text" ) -func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode) { +func (p *processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode) { // first check if we already have a block -- if err == nil we already had a block so we can skip a whole lot of work domainBlock := >smodel.DomainBlock{} - err := p.db.GetWhere([]db.Where{{Key: "domain", Value: domain, CaseInsensitive: true}}, domainBlock) + err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: domain, CaseInsensitive: true}}, domainBlock) if err != nil { if err != db.ErrNoEntries { // something went wrong in the DB @@ -59,7 +60,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string, } // put the new block in the database - if err := p.db.Put(domainBlock); err != nil { + if err := p.db.Put(ctx, domainBlock); err != nil { if err != db.ErrNoEntries { // there's a real error creating the block return nil, gtserror.NewErrorInternalError(fmt.Errorf("DomainBlockCreate: db error putting new domain block %s: %s", domain, err)) @@ -67,10 +68,10 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string, } // process the side effects of the domain block asynchronously since it might take a while - go p.initiateDomainBlockSideEffects(account, domainBlock) // TODO: add this to a queuing system so it can retry/resume + go p.initiateDomainBlockSideEffects(ctx, account, domainBlock) // TODO: add this to a queuing system so it can retry/resume } - mastoDomainBlock, err := p.tc.DomainBlockToMasto(domainBlock, false) + mastoDomainBlock, err := p.tc.DomainBlockToMasto(ctx, domainBlock, false) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("DomainBlockCreate: error converting domain block to frontend/masto representation %s: %s", domain, err)) } @@ -83,7 +84,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string, // 1. Strip most info away from the instance entry for the domain. // 2. Delete the instance account for that instance if it exists. // 3. Select all accounts from this instance and pass them through the delete functionality of the processor. -func (p *processor) initiateDomainBlockSideEffects(account *gtsmodel.Account, block *gtsmodel.DomainBlock) { +func (p *processor) initiateDomainBlockSideEffects(ctx context.Context, account *gtsmodel.Account, block *gtsmodel.DomainBlock) { l := p.log.WithFields(logrus.Fields{ "func": "domainBlockProcessSideEffects", "domain": block.Domain, @@ -93,7 +94,7 @@ func (p *processor) initiateDomainBlockSideEffects(account *gtsmodel.Account, bl // if we have an instance entry for this domain, update it with the new block ID and clear all fields instance := >smodel.Instance{} - if err := p.db.GetWhere([]db.Where{{Key: "domain", Value: block.Domain, CaseInsensitive: true}}, instance); err == nil { + if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain, CaseInsensitive: true}}, instance); err == nil { instance.Title = "" instance.UpdatedAt = time.Now() instance.SuspendedAt = time.Now() @@ -105,14 +106,14 @@ func (p *processor) initiateDomainBlockSideEffects(account *gtsmodel.Account, bl instance.ContactAccountUsername = "" instance.ContactAccountID = "" instance.Version = "" - if err := p.db.UpdateByID(instance.ID, instance); err != nil { + if err := p.db.UpdateByID(ctx, instance.ID, instance); err != nil { l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err) } l.Debug("domainBlockProcessSideEffects: instance entry updated") } // if we have an instance account for this instance, delete it - if err := p.db.DeleteWhere([]db.Where{{Key: "username", Value: block.Domain, CaseInsensitive: true}}, >smodel.Account{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "username", Value: block.Domain, CaseInsensitive: true}}, >smodel.Account{}); err != nil { l.Errorf("domainBlockProcessSideEffects: db error removing instance account: %s", err) } @@ -123,7 +124,7 @@ func (p *processor) initiateDomainBlockSideEffects(account *gtsmodel.Account, bl selectAccountsLoop: for { - accounts, err := p.db.GetInstanceAccounts(block.Domain, maxID, limit) + accounts, err := p.db.GetInstanceAccounts(ctx, block.Domain, maxID, limit) if err != nil { if err == db.ErrNoEntries { // no accounts left for this instance so we're done diff --git a/internal/processing/admin/deletedomainblock.go b/internal/processing/admin/deletedomainblock.go index edb0a58f9..2563b557d 100644 --- a/internal/processing/admin/deletedomainblock.go +++ b/internal/processing/admin/deletedomainblock.go @@ -19,6 +19,7 @@ package admin import ( + "context" "fmt" "time" @@ -28,10 +29,10 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) DomainBlockDelete(account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) { +func (p *processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) { domainBlock := >smodel.DomainBlock{} - if err := p.db.GetByID(id, domainBlock); err != nil { + if err := p.db.GetByID(ctx, id, domainBlock); err != nil { if err != db.ErrNoEntries { // something has gone really wrong return nil, gtserror.NewErrorInternalError(err) @@ -41,39 +42,39 @@ func (p *processor) DomainBlockDelete(account *gtsmodel.Account, id string) (*ap } // prepare the domain block to return - mastoDomainBlock, err := p.tc.DomainBlockToMasto(domainBlock, false) + mastoDomainBlock, err := p.tc.DomainBlockToMasto(ctx, domainBlock, false) if err != nil { return nil, gtserror.NewErrorInternalError(err) } // delete the domain block - if err := p.db.DeleteByID(id, domainBlock); err != nil { + if err := p.db.DeleteByID(ctx, id, domainBlock); err != nil { return nil, gtserror.NewErrorInternalError(err) } // remove the domain block reference from the instance, if we have an entry for it i := >smodel.Instance{} - if err := p.db.GetWhere([]db.Where{ + if err := p.db.GetWhere(ctx, []db.Where{ {Key: "domain", Value: domainBlock.Domain, CaseInsensitive: true}, {Key: "domain_block_id", Value: id}, }, i); err == nil { i.SuspendedAt = time.Time{} i.DomainBlockID = "" - if err := p.db.UpdateByID(i.ID, i); err != nil { + if err := p.db.UpdateByID(ctx, i.ID, i); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err)) } } // unsuspend all accounts whose suspension origin was this domain block // 1. remove the 'suspended_at' entry from their accounts - if err := p.db.UpdateWhere([]db.Where{ + if err := p.db.UpdateWhere(ctx, []db.Where{ {Key: "suspension_origin", Value: domainBlock.ID}, }, "suspended_at", nil, &[]*gtsmodel.Account{}); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspended_at from accounts: %s", err)) } // 2. remove the 'suspension_origin' entry from their accounts - if err := p.db.UpdateWhere([]db.Where{ + if err := p.db.UpdateWhere(ctx, []db.Where{ {Key: "suspension_origin", Value: domainBlock.ID}, }, "suspension_origin", nil, &[]*gtsmodel.Account{}); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspension_origin from accounts: %s", err)) diff --git a/internal/processing/admin/emoji.go b/internal/processing/admin/emoji.go index f19e173b5..f56bde8e0 100644 --- a/internal/processing/admin/emoji.go +++ b/internal/processing/admin/emoji.go @@ -20,6 +20,7 @@ package admin import ( "bytes" + "context" "errors" "fmt" "io" @@ -29,7 +30,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/id" ) -func (p *processor) EmojiCreate(account *gtsmodel.Account, user *gtsmodel.User, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) { +func (p *processor) EmojiCreate(ctx context.Context, account *gtsmodel.Account, user *gtsmodel.User, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) { if user.Admin { return nil, fmt.Errorf("user %s not an admin", user.ID) } @@ -49,7 +50,7 @@ func (p *processor) EmojiCreate(account *gtsmodel.Account, user *gtsmodel.User, } // allow the mediaHandler to work its magic of processing the emoji bytes, and putting them in whatever storage backend we're using - emoji, err := p.mediaHandler.ProcessLocalEmoji(buf.Bytes(), form.Shortcode) + emoji, err := p.mediaHandler.ProcessLocalEmoji(ctx, buf.Bytes(), form.Shortcode) if err != nil { return nil, fmt.Errorf("error reading emoji: %s", err) } @@ -60,12 +61,12 @@ func (p *processor) EmojiCreate(account *gtsmodel.Account, user *gtsmodel.User, } emoji.ID = emojiID - mastoEmoji, err := p.tc.EmojiToMasto(emoji) + mastoEmoji, err := p.tc.EmojiToMasto(ctx, emoji) if err != nil { return nil, fmt.Errorf("error converting emoji to mastotype: %s", err) } - if err := p.db.Put(emoji); err != nil { + if err := p.db.Put(ctx, emoji); err != nil { return nil, fmt.Errorf("database error while processing emoji: %s", err) } diff --git a/internal/processing/admin/getdomainblock.go b/internal/processing/admin/getdomainblock.go index f74010627..19bc9fe09 100644 --- a/internal/processing/admin/getdomainblock.go +++ b/internal/processing/admin/getdomainblock.go @@ -19,6 +19,7 @@ package admin import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -27,10 +28,10 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) DomainBlockGet(account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) { +func (p *processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) { domainBlock := >smodel.DomainBlock{} - if err := p.db.GetByID(id, domainBlock); err != nil { + if err := p.db.GetByID(ctx, id, domainBlock); err != nil { if err != db.ErrNoEntries { // something has gone really wrong return nil, gtserror.NewErrorInternalError(err) @@ -39,7 +40,7 @@ func (p *processor) DomainBlockGet(account *gtsmodel.Account, id string, export return nil, gtserror.NewErrorNotFound(fmt.Errorf("no entry for ID %s", id)) } - mastoDomainBlock, err := p.tc.DomainBlockToMasto(domainBlock, export) + mastoDomainBlock, err := p.tc.DomainBlockToMasto(ctx, domainBlock, export) if err != nil { return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/admin/getdomainblocks.go b/internal/processing/admin/getdomainblocks.go index f827d03fc..0ec33cfff 100644 --- a/internal/processing/admin/getdomainblocks.go +++ b/internal/processing/admin/getdomainblocks.go @@ -19,16 +19,18 @@ package admin import ( + "context" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) DomainBlocksGet(account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) { +func (p *processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) { domainBlocks := []*gtsmodel.DomainBlock{} - if err := p.db.GetAll(&domainBlocks); err != nil { + if err := p.db.GetAll(ctx, &domainBlocks); err != nil { if err != db.ErrNoEntries { // something has gone really wrong return nil, gtserror.NewErrorInternalError(err) @@ -37,7 +39,7 @@ func (p *processor) DomainBlocksGet(account *gtsmodel.Account, export bool) ([]* mastoDomainBlocks := []*apimodel.DomainBlock{} for _, b := range domainBlocks { - mastoDomainBlock, err := p.tc.DomainBlockToMasto(b, export) + mastoDomainBlock, err := p.tc.DomainBlockToMasto(ctx, b, export) if err != nil { return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/admin/importdomainblocks.go b/internal/processing/admin/importdomainblocks.go index ab171b712..66326bd62 100644 --- a/internal/processing/admin/importdomainblocks.go +++ b/internal/processing/admin/importdomainblocks.go @@ -20,6 +20,7 @@ package admin import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -32,7 +33,7 @@ import ( ) // DomainBlocksImport handles the import of a bunch of domain blocks at once, by calling the DomainBlockCreate function for each domain in the provided file. -func (p *processor) DomainBlocksImport(account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode) { +func (p *processor) DomainBlocksImport(ctx context.Context, account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode) { f, err := domains.Open() if err != nil { @@ -54,7 +55,7 @@ func (p *processor) DomainBlocksImport(account *gtsmodel.Account, domains *multi blocks := []*apimodel.DomainBlock{} for _, d := range d { - block, err := p.DomainBlockCreate(account, d.Domain, false, d.PublicComment, "", "") + block, err := p.DomainBlockCreate(ctx, account, d.Domain, false, d.PublicComment, "", "") if err != nil { return nil, err diff --git a/internal/processing/app.go b/internal/processing/app.go index 7da5344ac..4f805572b 100644 --- a/internal/processing/app.go +++ b/internal/processing/app.go @@ -19,6 +19,8 @@ package processing import ( + "context" + "github.com/google/uuid" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -26,7 +28,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -func (p *processor) AppCreate(authed *oauth.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, error) { +func (p *processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, error) { // set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/ var scopes string if form.Scopes == "" { @@ -61,7 +63,7 @@ func (p *processor) AppCreate(authed *oauth.Auth, form *apimodel.ApplicationCrea } // chuck it in the db - if err := p.db.Put(app); err != nil { + if err := p.db.Put(ctx, app); err != nil { return nil, err } @@ -74,11 +76,11 @@ func (p *processor) AppCreate(authed *oauth.Auth, form *apimodel.ApplicationCrea } // chuck it in the db - if err := p.db.Put(oc); err != nil { + if err := p.db.Put(ctx, oc); err != nil { return nil, err } - mastoApp, err := p.tc.AppToMastoSensitive(app) + mastoApp, err := p.tc.AppToMastoSensitive(ctx, app) if err != nil { return nil, err } diff --git a/internal/processing/blocks.go b/internal/processing/blocks.go index 809cbde8e..7c8371989 100644 --- a/internal/processing/blocks.go +++ b/internal/processing/blocks.go @@ -19,6 +19,7 @@ package processing import ( + "context" "fmt" "net/url" @@ -28,8 +29,8 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -func (p *processor) BlocksGet(authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) { - accounts, nextMaxID, prevMinID, err := p.db.GetAccountBlocks(authed.Account.ID, maxID, sinceID, limit) +func (p *processor) BlocksGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) { + accounts, nextMaxID, prevMinID, err := p.db.GetAccountBlocks(ctx, authed.Account.ID, maxID, sinceID, limit) if err != nil { if err == db.ErrNoEntries { // there are just no entries @@ -43,7 +44,7 @@ func (p *processor) BlocksGet(authed *oauth.Auth, maxID string, sinceID string, apiAccounts := []*apimodel.Account{} for _, a := range accounts { - apiAccount, err := p.tc.AccountToMastoBlocked(a) + apiAccount, err := p.tc.AccountToMastoBlocked(ctx, a) if err != nil { continue } diff --git a/internal/processing/federation.go b/internal/processing/federation.go index cea14b4de..352a6ddc2 100644 --- a/internal/processing/federation.go +++ b/internal/processing/federation.go @@ -36,7 +36,7 @@ import ( func (p *processor) GetFediUser(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) { // get the account the request is referring to - requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername) + requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) } @@ -44,7 +44,7 @@ func (p *processor) GetFediUser(ctx context.Context, requestedUsername string, r var requestedPerson vocab.ActivityStreamsPerson if util.IsPublicKeyPath(requestURL) { // if it's a public key path, we don't need to authenticate but we'll only serve the bare minimum user profile needed for the public key - requestedPerson, err = p.tc.AccountToASMinimal(requestedAccount) + requestedPerson, err = p.tc.AccountToASMinimal(ctx, requestedAccount) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -56,13 +56,13 @@ func (p *processor) GetFediUser(ctx context.Context, requestedUsername string, r } // if we're not already handshaking/dereferencing a remote account, dereference it now - if !p.federator.Handshaking(requestedUsername, requestingAccountURI) { - requestingAccount, _, err := p.federator.GetRemoteAccount(requestedUsername, requestingAccountURI, false) + if !p.federator.Handshaking(ctx, requestedUsername, requestingAccountURI) { + requestingAccount, _, err := p.federator.GetRemoteAccount(ctx, requestedUsername, requestingAccountURI, false) if err != nil { return nil, gtserror.NewErrorNotAuthorized(err) } - blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true) + blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -72,7 +72,7 @@ func (p *processor) GetFediUser(ctx context.Context, requestedUsername string, r } } - requestedPerson, err = p.tc.AccountToAS(requestedAccount) + requestedPerson, err = p.tc.AccountToAS(ctx, requestedAccount) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -90,7 +90,7 @@ func (p *processor) GetFediUser(ctx context.Context, requestedUsername string, r func (p *processor) GetFediFollowers(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) { // get the account the request is referring to - requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername) + requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) } @@ -101,12 +101,12 @@ func (p *processor) GetFediFollowers(ctx context.Context, requestedUsername stri return nil, gtserror.NewErrorNotAuthorized(errors.New("not authorized"), "not authorized") } - requestingAccount, _, err := p.federator.GetRemoteAccount(requestedUsername, requestingAccountURI, false) + requestingAccount, _, err := p.federator.GetRemoteAccount(ctx, requestedUsername, requestingAccountURI, false) if err != nil { return nil, gtserror.NewErrorNotAuthorized(err) } - blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true) + blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -135,7 +135,7 @@ func (p *processor) GetFediFollowers(ctx context.Context, requestedUsername stri func (p *processor) GetFediFollowing(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) { // get the account the request is referring to - requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername) + requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) } @@ -146,12 +146,12 @@ func (p *processor) GetFediFollowing(ctx context.Context, requestedUsername stri return nil, gtserror.NewErrorNotAuthorized(errors.New("not authorized"), "not authorized") } - requestingAccount, _, err := p.federator.GetRemoteAccount(requestedUsername, requestingAccountURI, false) + requestingAccount, _, err := p.federator.GetRemoteAccount(ctx, requestedUsername, requestingAccountURI, false) if err != nil { return nil, gtserror.NewErrorNotAuthorized(err) } - blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true) + blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -180,7 +180,7 @@ func (p *processor) GetFediFollowing(ctx context.Context, requestedUsername stri func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string, requestedStatusID string, requestURL *url.URL) (interface{}, gtserror.WithCode) { // get the account the request is referring to - requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername) + requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) } @@ -191,14 +191,14 @@ func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string, return nil, gtserror.NewErrorNotAuthorized(errors.New("not authorized"), "not authorized") } - requestingAccount, _, err := p.federator.GetRemoteAccount(requestedUsername, requestingAccountURI, false) + requestingAccount, _, err := p.federator.GetRemoteAccount(ctx, requestedUsername, requestingAccountURI, false) if err != nil { return nil, gtserror.NewErrorNotAuthorized(err) } // authorize the request: // 1. check if a block exists between the requester and the requestee - blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true) + blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -209,14 +209,14 @@ func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string, // get the status out of the database here s := >smodel.Status{} - if err := p.db.GetWhere([]db.Where{ + if err := p.db.GetWhere(ctx, []db.Where{ {Key: "id", Value: requestedStatusID}, {Key: "account_id", Value: requestedAccount.ID}, }, s); err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting status with id %s and account id %s: %s", requestedStatusID, requestedAccount.ID, err)) } - visible, err := p.filter.StatusVisible(s, requestingAccount) + visible, err := p.filter.StatusVisible(ctx, s, requestingAccount) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -225,7 +225,7 @@ func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string, } // requester is authorized to view the status, so convert it to AP representation and serialize it - asStatus, err := p.tc.StatusToAS(s) + asStatus, err := p.tc.StatusToAS(ctx, s) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -240,7 +240,7 @@ func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string, func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername string, requestedStatusID string, page bool, onlyOtherAccounts bool, minID string, requestURL *url.URL) (interface{}, gtserror.WithCode) { // get the account the request is referring to - requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername) + requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) } @@ -251,14 +251,14 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername return nil, gtserror.NewErrorNotAuthorized(errors.New("not authorized"), "not authorized") } - requestingAccount, _, err := p.federator.GetRemoteAccount(requestedUsername, requestingAccountURI, false) + requestingAccount, _, err := p.federator.GetRemoteAccount(ctx, requestedUsername, requestingAccountURI, false) if err != nil { return nil, gtserror.NewErrorNotAuthorized(err) } // authorize the request: // 1. check if a block exists between the requester and the requestee - blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true) + blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -269,14 +269,14 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername // get the status out of the database here s := >smodel.Status{} - if err := p.db.GetWhere([]db.Where{ + if err := p.db.GetWhere(ctx, []db.Where{ {Key: "id", Value: requestedStatusID}, {Key: "account_id", Value: requestedAccount.ID}, }, s); err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting status with id %s and account id %s: %s", requestedStatusID, requestedAccount.ID, err)) } - visible, err := p.filter.StatusVisible(s, requestingAccount) + visible, err := p.filter.StatusVisible(ctx, s, requestingAccount) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -295,7 +295,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername // scenario 1 // get the collection - collection, err := p.tc.StatusToASRepliesCollection(s, onlyOtherAccounts) + collection, err := p.tc.StatusToASRepliesCollection(ctx, s, onlyOtherAccounts) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -308,7 +308,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername // scenario 2 // get the collection - collection, err := p.tc.StatusToASRepliesCollection(s, onlyOtherAccounts) + collection, err := p.tc.StatusToASRepliesCollection(ctx, s, onlyOtherAccounts) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -320,7 +320,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername } else { // scenario 3 // get immediate children - replies, err := p.db.GetStatusChildren(s, true, minID) + replies, err := p.db.GetStatusChildren(ctx, s, true, minID) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -339,13 +339,13 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername } // only show replies that the status owner can see - visibleToStatusOwner, err := p.filter.StatusVisible(r, requestedAccount) + visibleToStatusOwner, err := p.filter.StatusVisible(ctx, r, requestedAccount) if err != nil || !visibleToStatusOwner { continue } // only show replies that the requester can see - visibleToRequester, err := p.filter.StatusVisible(r, requestingAccount) + visibleToRequester, err := p.filter.StatusVisible(ctx, r, requestingAccount) if err != nil || !visibleToRequester { continue } @@ -358,7 +358,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername replyURIs[r.ID] = rURI } - repliesPage, err := p.tc.StatusURIsToASRepliesPage(s, onlyOtherAccounts, minID, replyURIs) + repliesPage, err := p.tc.StatusURIsToASRepliesPage(ctx, s, onlyOtherAccounts, minID, replyURIs) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -373,7 +373,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername func (p *processor) GetWebfingerAccount(ctx context.Context, requestedUsername string, requestURL *url.URL) (*apimodel.WellKnownResponse, gtserror.WithCode) { // get the account the request is referring to - requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername) + requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) } @@ -400,7 +400,7 @@ func (p *processor) GetWebfingerAccount(ctx context.Context, requestedUsername s }, nil } -func (p *processor) GetNodeInfoRel(request *http.Request) (*apimodel.WellKnownResponse, gtserror.WithCode) { +func (p *processor) GetNodeInfoRel(ctx context.Context, request *http.Request) (*apimodel.WellKnownResponse, gtserror.WithCode) { return &apimodel.WellKnownResponse{ Links: []apimodel.Link{ { @@ -411,7 +411,7 @@ func (p *processor) GetNodeInfoRel(request *http.Request) (*apimodel.WellKnownRe }, nil } -func (p *processor) GetNodeInfo(request *http.Request) (*apimodel.Nodeinfo, gtserror.WithCode) { +func (p *processor) GetNodeInfo(ctx context.Context, request *http.Request) (*apimodel.Nodeinfo, gtserror.WithCode) { return &apimodel.Nodeinfo{ Version: "2.0", Software: apimodel.NodeInfoSoftware{ diff --git a/internal/processing/followrequest.go b/internal/processing/followrequest.go index 867725023..3dd6432e2 100644 --- a/internal/processing/followrequest.go +++ b/internal/processing/followrequest.go @@ -19,6 +19,8 @@ package processing import ( + "context" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" @@ -26,8 +28,8 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -func (p *processor) FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) { - frs, err := p.db.GetAccountFollowRequests(auth.Account.ID) +func (p *processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) { + frs, err := p.db.GetAccountFollowRequests(ctx, auth.Account.ID) if err != nil { if err != db.ErrNoEntries { return nil, gtserror.NewErrorInternalError(err) @@ -36,11 +38,15 @@ func (p *processor) FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gts accts := []apimodel.Account{} for _, fr := range frs { - acct := >smodel.Account{} - if err := p.db.GetByID(fr.AccountID, acct); err != nil { - return nil, gtserror.NewErrorInternalError(err) + if fr.Account == nil { + frAcct, err := p.db.GetAccountByID(ctx, fr.AccountID) + if err != nil { + return nil, gtserror.NewErrorInternalError(err) + } + fr.Account = frAcct } - mastoAcct, err := p.tc.AccountToMastoPublic(acct) + + mastoAcct, err := p.tc.AccountToMastoPublic(ctx, fr.Account) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -49,36 +55,42 @@ func (p *processor) FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gts return accts, nil } -func (p *processor) FollowRequestAccept(auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) { - follow, err := p.db.AcceptFollowRequest(accountID, auth.Account.ID) +func (p *processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) { + follow, err := p.db.AcceptFollowRequest(ctx, accountID, auth.Account.ID) if err != nil { return nil, gtserror.NewErrorNotFound(err) } - originAccount := >smodel.Account{} - if err := p.db.GetByID(follow.AccountID, originAccount); err != nil { - return nil, gtserror.NewErrorInternalError(err) + if follow.Account == nil { + followAccount, err := p.db.GetAccountByID(ctx, follow.AccountID) + if err != nil { + return nil, gtserror.NewErrorInternalError(err) + } + follow.Account = followAccount } - targetAccount := >smodel.Account{} - if err := p.db.GetByID(follow.TargetAccountID, targetAccount); err != nil { - return nil, gtserror.NewErrorInternalError(err) + if follow.TargetAccount == nil { + followTargetAccount, err := p.db.GetAccountByID(ctx, follow.TargetAccountID) + if err != nil { + return nil, gtserror.NewErrorInternalError(err) + } + follow.TargetAccount = followTargetAccount } p.fromClientAPI <- gtsmodel.FromClientAPI{ APObjectType: gtsmodel.ActivityStreamsFollow, APActivityType: gtsmodel.ActivityStreamsAccept, GTSModel: follow, - OriginAccount: originAccount, - TargetAccount: targetAccount, + OriginAccount: follow.Account, + TargetAccount: follow.TargetAccount, } - gtsR, err := p.db.GetRelationship(auth.Account.ID, accountID) + gtsR, err := p.db.GetRelationship(ctx, auth.Account.ID, accountID) if err != nil { return nil, gtserror.NewErrorInternalError(err) } - r, err := p.tc.RelationshipToMasto(gtsR) + r, err := p.tc.RelationshipToMasto(ctx, gtsR) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -86,6 +98,6 @@ func (p *processor) FollowRequestAccept(auth *oauth.Auth, accountID string) (*ap return r, nil } -func (p *processor) FollowRequestDeny(auth *oauth.Auth) gtserror.WithCode { +func (p *processor) FollowRequestDeny(ctx context.Context, auth *oauth.Auth) gtserror.WithCode { return nil } diff --git a/internal/processing/fromclientapi.go b/internal/processing/fromclientapi.go index beed283c1..a6ea0068b 100644 --- a/internal/processing/fromclientapi.go +++ b/internal/processing/fromclientapi.go @@ -29,7 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error { +func (p *processor) processFromClientAPI(ctx context.Context, clientMsg gtsmodel.FromClientAPI) error { switch clientMsg.APActivityType { case gtsmodel.ActivityStreamsCreate: // CREATE @@ -41,16 +41,16 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error return errors.New("note was not parseable as *gtsmodel.Status") } - if err := p.timelineStatus(status); err != nil { + if err := p.timelineStatus(ctx, status); err != nil { return err } - if err := p.notifyStatus(status); err != nil { + if err := p.notifyStatus(ctx, status); err != nil { return err } if status.VisibilityAdvanced != nil && status.VisibilityAdvanced.Federated { - return p.federateStatus(status) + return p.federateStatus(ctx, status) } case gtsmodel.ActivityStreamsFollow: // CREATE FOLLOW REQUEST @@ -59,11 +59,11 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error return errors.New("followrequest was not parseable as *gtsmodel.FollowRequest") } - if err := p.notifyFollowRequest(followRequest, clientMsg.TargetAccount); err != nil { + if err := p.notifyFollowRequest(ctx, followRequest, clientMsg.TargetAccount); err != nil { return err } - return p.federateFollow(followRequest, clientMsg.OriginAccount, clientMsg.TargetAccount) + return p.federateFollow(ctx, followRequest, clientMsg.OriginAccount, clientMsg.TargetAccount) case gtsmodel.ActivityStreamsLike: // CREATE LIKE/FAVE fave, ok := clientMsg.GTSModel.(*gtsmodel.StatusFave) @@ -71,11 +71,11 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error return errors.New("fave was not parseable as *gtsmodel.StatusFave") } - if err := p.notifyFave(fave, clientMsg.TargetAccount); err != nil { + if err := p.notifyFave(ctx, fave, clientMsg.TargetAccount); err != nil { return err } - return p.federateFave(fave, clientMsg.OriginAccount, clientMsg.TargetAccount) + return p.federateFave(ctx, fave, clientMsg.OriginAccount, clientMsg.TargetAccount) case gtsmodel.ActivityStreamsAnnounce: // CREATE BOOST/ANNOUNCE boostWrapperStatus, ok := clientMsg.GTSModel.(*gtsmodel.Status) @@ -83,15 +83,15 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error return errors.New("boost was not parseable as *gtsmodel.Status") } - if err := p.timelineStatus(boostWrapperStatus); err != nil { + if err := p.timelineStatus(ctx, boostWrapperStatus); err != nil { return err } - if err := p.notifyAnnounce(boostWrapperStatus); err != nil { + if err := p.notifyAnnounce(ctx, boostWrapperStatus); err != nil { return err } - return p.federateAnnounce(boostWrapperStatus, clientMsg.OriginAccount, clientMsg.TargetAccount) + return p.federateAnnounce(ctx, boostWrapperStatus, clientMsg.OriginAccount, clientMsg.TargetAccount) case gtsmodel.ActivityStreamsBlock: // CREATE BLOCK block, ok := clientMsg.GTSModel.(*gtsmodel.Block) @@ -100,17 +100,17 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error } // remove any of the blocking account's statuses from the blocked account's timeline, and vice versa - if err := p.timelineManager.WipeStatusesFromAccountID(block.AccountID, block.TargetAccountID); err != nil { + if err := p.timelineManager.WipeStatusesFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil { return err } - if err := p.timelineManager.WipeStatusesFromAccountID(block.TargetAccountID, block.AccountID); err != nil { + if err := p.timelineManager.WipeStatusesFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil { return err } // TODO: same with notifications // TODO: same with bookmarks - return p.federateBlock(block) + return p.federateBlock(ctx, block) } case gtsmodel.ActivityStreamsUpdate: // UPDATE @@ -122,7 +122,7 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error return errors.New("account was not parseable as *gtsmodel.Account") } - return p.federateAccountUpdate(account, clientMsg.OriginAccount) + return p.federateAccountUpdate(ctx, account, clientMsg.OriginAccount) } case gtsmodel.ActivityStreamsAccept: // ACCEPT @@ -134,11 +134,11 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error return errors.New("accept was not parseable as *gtsmodel.Follow") } - if err := p.notifyFollow(follow, clientMsg.TargetAccount); err != nil { + if err := p.notifyFollow(ctx, follow, clientMsg.TargetAccount); err != nil { return err } - return p.federateAcceptFollowRequest(follow, clientMsg.OriginAccount, clientMsg.TargetAccount) + return p.federateAcceptFollowRequest(ctx, follow, clientMsg.OriginAccount, clientMsg.TargetAccount) } case gtsmodel.ActivityStreamsUndo: // UNDO @@ -149,21 +149,21 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error if !ok { return errors.New("undo was not parseable as *gtsmodel.Follow") } - return p.federateUnfollow(follow, clientMsg.OriginAccount, clientMsg.TargetAccount) + return p.federateUnfollow(ctx, follow, clientMsg.OriginAccount, clientMsg.TargetAccount) case gtsmodel.ActivityStreamsBlock: // UNDO BLOCK block, ok := clientMsg.GTSModel.(*gtsmodel.Block) if !ok { return errors.New("undo was not parseable as *gtsmodel.Block") } - return p.federateUnblock(block) + return p.federateUnblock(ctx, block) case gtsmodel.ActivityStreamsLike: // UNDO LIKE/FAVE fave, ok := clientMsg.GTSModel.(*gtsmodel.StatusFave) if !ok { return errors.New("undo was not parseable as *gtsmodel.StatusFave") } - return p.federateUnfave(fave, clientMsg.OriginAccount, clientMsg.TargetAccount) + return p.federateUnfave(ctx, fave, clientMsg.OriginAccount, clientMsg.TargetAccount) case gtsmodel.ActivityStreamsAnnounce: // UNDO ANNOUNCE/BOOST boost, ok := clientMsg.GTSModel.(*gtsmodel.Status) @@ -171,11 +171,11 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error return errors.New("undo was not parseable as *gtsmodel.Status") } - if err := p.deleteStatusFromTimelines(boost); err != nil { + if err := p.deleteStatusFromTimelines(ctx, boost); err != nil { return err } - return p.federateUnannounce(boost, clientMsg.OriginAccount, clientMsg.TargetAccount) + return p.federateUnannounce(ctx, boost, clientMsg.OriginAccount, clientMsg.TargetAccount) } case gtsmodel.ActivityStreamsDelete: // DELETE @@ -193,29 +193,29 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error // delete all attachments for this status for _, a := range statusToDelete.AttachmentIDs { - if err := p.mediaProcessor.Delete(a); err != nil { + if err := p.mediaProcessor.Delete(ctx, a); err != nil { return err } } // delete all mentions for this status for _, m := range statusToDelete.MentionIDs { - if err := p.db.DeleteByID(m, >smodel.Mention{}); err != nil { + if err := p.db.DeleteByID(ctx, m, >smodel.Mention{}); err != nil { return err } } // delete all notifications for this status - if err := p.db.DeleteWhere([]db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil { return err } // delete this status from any and all timelines - if err := p.deleteStatusFromTimelines(statusToDelete); err != nil { + if err := p.deleteStatusFromTimelines(ctx, statusToDelete); err != nil { return err } - return p.federateStatusDelete(statusToDelete) + return p.federateStatusDelete(ctx, statusToDelete) case gtsmodel.ActivityStreamsProfile, gtsmodel.ActivityStreamsPerson: // DELETE ACCOUNT/PROFILE @@ -228,7 +228,7 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error // origin is whichever account caused this message origin = clientMsg.OriginAccount.ID } - return p.accountProcessor.Delete(clientMsg.TargetAccount, origin) + return p.accountProcessor.Delete(ctx, clientMsg.TargetAccount, origin) } } return nil @@ -236,13 +236,13 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error // TODO: move all the below functions into federation.Federator -func (p *processor) federateStatus(status *gtsmodel.Status) error { +func (p *processor) federateStatus(ctx context.Context, status *gtsmodel.Status) error { if status.Account == nil { - a := >smodel.Account{} - if err := p.db.GetByID(status.AccountID, a); err != nil { + statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID) + if err != nil { return fmt.Errorf("federateStatus: error fetching status author account: %s", err) } - status.Account = a + status.Account = statusAccount } // do nothing if this isn't our status @@ -250,7 +250,7 @@ func (p *processor) federateStatus(status *gtsmodel.Status) error { return nil } - asStatus, err := p.tc.StatusToAS(status) + asStatus, err := p.tc.StatusToAS(ctx, status) if err != nil { return fmt.Errorf("federateStatus: error converting status to as format: %s", err) } @@ -260,17 +260,17 @@ func (p *processor) federateStatus(status *gtsmodel.Status) error { return fmt.Errorf("federateStatus: error parsing outboxURI %s: %s", status.Account.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, asStatus) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, asStatus) return err } -func (p *processor) federateStatusDelete(status *gtsmodel.Status) error { +func (p *processor) federateStatusDelete(ctx context.Context, status *gtsmodel.Status) error { if status.Account == nil { - a := >smodel.Account{} - if err := p.db.GetByID(status.AccountID, a); err != nil { - return fmt.Errorf("federateStatus: error fetching status author account: %s", err) + statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID) + if err != nil { + return fmt.Errorf("federateStatusDelete: error fetching status author account: %s", err) } - status.Account = a + status.Account = statusAccount } // do nothing if this isn't our status @@ -278,7 +278,7 @@ func (p *processor) federateStatusDelete(status *gtsmodel.Status) error { return nil } - asStatus, err := p.tc.StatusToAS(status) + asStatus, err := p.tc.StatusToAS(ctx, status) if err != nil { return fmt.Errorf("federateStatusDelete: error converting status to as format: %s", err) } @@ -310,19 +310,19 @@ func (p *processor) federateStatusDelete(status *gtsmodel.Status) error { delete.SetActivityStreamsTo(asStatus.GetActivityStreamsTo()) delete.SetActivityStreamsCc(asStatus.GetActivityStreamsCc()) - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, delete) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, delete) return err } -func (p *processor) federateFollow(followRequest *gtsmodel.FollowRequest, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { +func (p *processor) federateFollow(ctx context.Context, followRequest *gtsmodel.FollowRequest, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { // if both accounts are local there's nothing to do here if originAccount.Domain == "" && targetAccount.Domain == "" { return nil } - follow := p.tc.FollowRequestToFollow(followRequest) + follow := p.tc.FollowRequestToFollow(ctx, followRequest) - asFollow, err := p.tc.FollowToAS(follow, originAccount, targetAccount) + asFollow, err := p.tc.FollowToAS(ctx, follow, originAccount, targetAccount) if err != nil { return fmt.Errorf("federateFollow: error converting follow to as format: %s", err) } @@ -332,18 +332,18 @@ func (p *processor) federateFollow(followRequest *gtsmodel.FollowRequest, origin return fmt.Errorf("federateFollow: error parsing outboxURI %s: %s", originAccount.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, asFollow) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, asFollow) return err } -func (p *processor) federateUnfollow(follow *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { +func (p *processor) federateUnfollow(ctx context.Context, follow *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { // if both accounts are local there's nothing to do here if originAccount.Domain == "" && targetAccount.Domain == "" { return nil } // recreate the follow - asFollow, err := p.tc.FollowToAS(follow, originAccount, targetAccount) + asFollow, err := p.tc.FollowToAS(ctx, follow, originAccount, targetAccount) if err != nil { return fmt.Errorf("federateUnfollow: error converting follow to as format: %s", err) } @@ -373,18 +373,18 @@ func (p *processor) federateUnfollow(follow *gtsmodel.Follow, originAccount *gts } // send off the Undo - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, undo) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, undo) return err } -func (p *processor) federateUnfave(fave *gtsmodel.StatusFave, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { +func (p *processor) federateUnfave(ctx context.Context, fave *gtsmodel.StatusFave, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { // if both accounts are local there's nothing to do here if originAccount.Domain == "" && targetAccount.Domain == "" { return nil } // create the AS fave - asFave, err := p.tc.FaveToAS(fave) + asFave, err := p.tc.FaveToAS(ctx, fave) if err != nil { return fmt.Errorf("federateFave: error converting fave to as format: %s", err) } @@ -412,17 +412,17 @@ func (p *processor) federateUnfave(fave *gtsmodel.StatusFave, originAccount *gts if err != nil { return fmt.Errorf("federateFave: error parsing outboxURI %s: %s", originAccount.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, undo) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, undo) return err } -func (p *processor) federateUnannounce(boost *gtsmodel.Status, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { +func (p *processor) federateUnannounce(ctx context.Context, boost *gtsmodel.Status, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { if originAccount.Domain != "" { // nothing to do here return nil } - asAnnounce, err := p.tc.BoostToAS(boost, originAccount, targetAccount) + asAnnounce, err := p.tc.BoostToAS(ctx, boost, originAccount, targetAccount) if err != nil { return fmt.Errorf("federateUnannounce: error converting status to announce: %s", err) } @@ -447,18 +447,18 @@ func (p *processor) federateUnannounce(boost *gtsmodel.Status, originAccount *gt return fmt.Errorf("federateUnannounce: error parsing outboxURI %s: %s", originAccount.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, undo) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, undo) return err } -func (p *processor) federateAcceptFollowRequest(follow *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { +func (p *processor) federateAcceptFollowRequest(ctx context.Context, follow *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { // if both accounts are local there's nothing to do here if originAccount.Domain == "" && targetAccount.Domain == "" { return nil } // recreate the AS follow - asFollow, err := p.tc.FollowToAS(follow, originAccount, targetAccount) + asFollow, err := p.tc.FollowToAS(ctx, follow, originAccount, targetAccount) if err != nil { return fmt.Errorf("federateUnfollow: error converting follow to as format: %s", err) } @@ -497,18 +497,18 @@ func (p *processor) federateAcceptFollowRequest(follow *gtsmodel.Follow, originA } // send off the accept using the accepter's outbox - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, accept) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, accept) return err } -func (p *processor) federateFave(fave *gtsmodel.StatusFave, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { +func (p *processor) federateFave(ctx context.Context, fave *gtsmodel.StatusFave, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error { // if both accounts are local there's nothing to do here if originAccount.Domain == "" && targetAccount.Domain == "" { return nil } // create the AS fave - asFave, err := p.tc.FaveToAS(fave) + asFave, err := p.tc.FaveToAS(ctx, fave) if err != nil { return fmt.Errorf("federateFave: error converting fave to as format: %s", err) } @@ -517,12 +517,12 @@ func (p *processor) federateFave(fave *gtsmodel.StatusFave, originAccount *gtsmo if err != nil { return fmt.Errorf("federateFave: error parsing outboxURI %s: %s", originAccount.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, asFave) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, asFave) return err } -func (p *processor) federateAnnounce(boostWrapperStatus *gtsmodel.Status, boostingAccount *gtsmodel.Account, boostedAccount *gtsmodel.Account) error { - announce, err := p.tc.BoostToAS(boostWrapperStatus, boostingAccount, boostedAccount) +func (p *processor) federateAnnounce(ctx context.Context, boostWrapperStatus *gtsmodel.Status, boostingAccount *gtsmodel.Account, boostedAccount *gtsmodel.Account) error { + announce, err := p.tc.BoostToAS(ctx, boostWrapperStatus, boostingAccount, boostedAccount) if err != nil { return fmt.Errorf("federateAnnounce: error converting status to announce: %s", err) } @@ -532,12 +532,12 @@ func (p *processor) federateAnnounce(boostWrapperStatus *gtsmodel.Status, boosti return fmt.Errorf("federateAnnounce: error parsing outboxURI %s: %s", boostingAccount.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, announce) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, announce) return err } -func (p *processor) federateAccountUpdate(updatedAccount *gtsmodel.Account, originAccount *gtsmodel.Account) error { - person, err := p.tc.AccountToAS(updatedAccount) +func (p *processor) federateAccountUpdate(ctx context.Context, updatedAccount *gtsmodel.Account, originAccount *gtsmodel.Account) error { + person, err := p.tc.AccountToAS(ctx, updatedAccount) if err != nil { return fmt.Errorf("federateAccountUpdate: error converting account to person: %s", err) } @@ -552,25 +552,25 @@ func (p *processor) federateAccountUpdate(updatedAccount *gtsmodel.Account, orig return fmt.Errorf("federateAnnounce: error parsing outboxURI %s: %s", originAccount.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, update) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, update) return err } -func (p *processor) federateBlock(block *gtsmodel.Block) error { +func (p *processor) federateBlock(ctx context.Context, block *gtsmodel.Block) error { if block.Account == nil { - a := >smodel.Account{} - if err := p.db.GetByID(block.AccountID, a); err != nil { + blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID) + if err != nil { return fmt.Errorf("federateBlock: error getting block account from database: %s", err) } - block.Account = a + block.Account = blockAccount } if block.TargetAccount == nil { - a := >smodel.Account{} - if err := p.db.GetByID(block.TargetAccountID, a); err != nil { + blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID) + if err != nil { return fmt.Errorf("federateBlock: error getting block target account from database: %s", err) } - block.TargetAccount = a + block.TargetAccount = blockTargetAccount } // if both accounts are local there's nothing to do here @@ -578,7 +578,7 @@ func (p *processor) federateBlock(block *gtsmodel.Block) error { return nil } - asBlock, err := p.tc.BlockToAS(block) + asBlock, err := p.tc.BlockToAS(ctx, block) if err != nil { return fmt.Errorf("federateBlock: error converting block to AS format: %s", err) } @@ -588,25 +588,25 @@ func (p *processor) federateBlock(block *gtsmodel.Block) error { return fmt.Errorf("federateBlock: error parsing outboxURI %s: %s", block.Account.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, asBlock) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, asBlock) return err } -func (p *processor) federateUnblock(block *gtsmodel.Block) error { +func (p *processor) federateUnblock(ctx context.Context, block *gtsmodel.Block) error { if block.Account == nil { - a := >smodel.Account{} - if err := p.db.GetByID(block.AccountID, a); err != nil { + blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID) + if err != nil { return fmt.Errorf("federateUnblock: error getting block account from database: %s", err) } - block.Account = a + block.Account = blockAccount } if block.TargetAccount == nil { - a := >smodel.Account{} - if err := p.db.GetByID(block.TargetAccountID, a); err != nil { + blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID) + if err != nil { return fmt.Errorf("federateUnblock: error getting block target account from database: %s", err) } - block.TargetAccount = a + block.TargetAccount = blockTargetAccount } // if both accounts are local there's nothing to do here @@ -614,7 +614,7 @@ func (p *processor) federateUnblock(block *gtsmodel.Block) error { return nil } - asBlock, err := p.tc.BlockToAS(block) + asBlock, err := p.tc.BlockToAS(ctx, block) if err != nil { return fmt.Errorf("federateUnblock: error converting block to AS format: %s", err) } @@ -642,6 +642,6 @@ func (p *processor) federateUnblock(block *gtsmodel.Block) error { if err != nil { return fmt.Errorf("federateUnblock: error parsing outboxURI %s: %s", block.Account.OutboxURI, err) } - _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, undo) + _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, undo) return err } diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go index 2c2635175..b7a6defc3 100644 --- a/internal/processing/fromcommon.go +++ b/internal/processing/fromcommon.go @@ -19,6 +19,7 @@ package processing import ( + "context" "fmt" "strings" "sync" @@ -28,7 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/id" ) -func (p *processor) notifyStatus(status *gtsmodel.Status) error { +func (p *processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) error { // if there are no mentions in this status then just bail if len(status.MentionIDs) == 0 { return nil @@ -36,7 +37,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error { if status.Mentions == nil { // there are mentions but they're not fully populated on the status yet so do this - menchies, err := p.db.GetMentions(status.MentionIDs) + menchies, err := p.db.GetMentions(ctx, status.MentionIDs) if err != nil { return fmt.Errorf("notifyStatus: error getting mentions for status %s from the db: %s", status.ID, err) } @@ -47,7 +48,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error { for _, m := range status.Mentions { // make sure this is a local account, otherwise we don't need to create a notification for it if m.TargetAccount == nil { - a, err := p.db.GetAccountByID(m.TargetAccountID) + a, err := p.db.GetAccountByID(ctx, m.TargetAccountID) if err != nil { // we don't have the account or there's been an error return fmt.Errorf("notifyStatus: error getting account with id %s from the db: %s", m.TargetAccountID, err) @@ -60,7 +61,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error { } // make sure a notif doesn't already exist for this mention - err := p.db.GetWhere([]db.Where{ + err := p.db.GetWhere(ctx, []db.Where{ {Key: "notification_type", Value: gtsmodel.NotificationMention}, {Key: "target_account_id", Value: m.TargetAccountID}, {Key: "origin_account_id", Value: status.AccountID}, @@ -92,12 +93,12 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error { Status: status, } - if err := p.db.Put(notif); err != nil { + if err := p.db.Put(ctx, notif); err != nil { return fmt.Errorf("notifyStatus: error putting notification in database: %s", err) } // now stream the notification to the user - mastoNotif, err := p.tc.NotificationToMasto(notif) + mastoNotif, err := p.tc.NotificationToMasto(ctx, notif) if err != nil { return fmt.Errorf("notifyStatus: error converting notification to masto representation: %s", err) } @@ -110,7 +111,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error { return nil } -func (p *processor) notifyFollowRequest(followRequest *gtsmodel.FollowRequest, receivingAccount *gtsmodel.Account) error { +func (p *processor) notifyFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest, receivingAccount *gtsmodel.Account) error { // return if this isn't a local account if receivingAccount.Domain != "" { return nil @@ -128,12 +129,12 @@ func (p *processor) notifyFollowRequest(followRequest *gtsmodel.FollowRequest, r OriginAccountID: followRequest.AccountID, } - if err := p.db.Put(notif); err != nil { + if err := p.db.Put(ctx, notif); err != nil { return fmt.Errorf("notifyFollowRequest: error putting notification in database: %s", err) } // now stream the notification to the user - mastoNotif, err := p.tc.NotificationToMasto(notif) + mastoNotif, err := p.tc.NotificationToMasto(ctx, notif) if err != nil { return fmt.Errorf("notifyStatus: error converting notification to masto representation: %s", err) } @@ -145,14 +146,14 @@ func (p *processor) notifyFollowRequest(followRequest *gtsmodel.FollowRequest, r return nil } -func (p *processor) notifyFollow(follow *gtsmodel.Follow, targetAccount *gtsmodel.Account) error { +func (p *processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, targetAccount *gtsmodel.Account) error { // return if this isn't a local account if targetAccount.Domain != "" { return nil } // first remove the follow request notification - if err := p.db.DeleteWhere([]db.Where{ + if err := p.db.DeleteWhere(ctx, []db.Where{ {Key: "notification_type", Value: gtsmodel.NotificationFollowRequest}, {Key: "target_account_id", Value: follow.TargetAccountID}, {Key: "origin_account_id", Value: follow.AccountID}, @@ -174,12 +175,12 @@ func (p *processor) notifyFollow(follow *gtsmodel.Follow, targetAccount *gtsmode OriginAccountID: follow.AccountID, OriginAccount: follow.Account, } - if err := p.db.Put(notif); err != nil { + if err := p.db.Put(ctx, notif); err != nil { return fmt.Errorf("notifyFollow: error putting notification in database: %s", err) } // now stream the notification to the user - mastoNotif, err := p.tc.NotificationToMasto(notif) + mastoNotif, err := p.tc.NotificationToMasto(ctx, notif) if err != nil { return fmt.Errorf("notifyStatus: error converting notification to masto representation: %s", err) } @@ -191,7 +192,7 @@ func (p *processor) notifyFollow(follow *gtsmodel.Follow, targetAccount *gtsmode return nil } -func (p *processor) notifyFave(fave *gtsmodel.StatusFave, targetAccount *gtsmodel.Account) error { +func (p *processor) notifyFave(ctx context.Context, fave *gtsmodel.StatusFave, targetAccount *gtsmodel.Account) error { // return if this isn't a local account if targetAccount.Domain != "" { return nil @@ -213,12 +214,12 @@ func (p *processor) notifyFave(fave *gtsmodel.StatusFave, targetAccount *gtsmode Status: fave.Status, } - if err := p.db.Put(notif); err != nil { + if err := p.db.Put(ctx, notif); err != nil { return fmt.Errorf("notifyFave: error putting notification in database: %s", err) } // now stream the notification to the user - mastoNotif, err := p.tc.NotificationToMasto(notif) + mastoNotif, err := p.tc.NotificationToMasto(ctx, notif) if err != nil { return fmt.Errorf("notifyStatus: error converting notification to masto representation: %s", err) } @@ -230,14 +231,14 @@ func (p *processor) notifyFave(fave *gtsmodel.StatusFave, targetAccount *gtsmode return nil } -func (p *processor) notifyAnnounce(status *gtsmodel.Status) error { +func (p *processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status) error { if status.BoostOfID == "" { // not a boost, nothing to do return nil } if status.BoostOf == nil { - boostedStatus, err := p.db.GetStatusByID(status.BoostOfID) + boostedStatus, err := p.db.GetStatusByID(ctx, status.BoostOfID) if err != nil { return fmt.Errorf("notifyAnnounce: error getting status with id %s: %s", status.BoostOfID, err) } @@ -245,7 +246,7 @@ func (p *processor) notifyAnnounce(status *gtsmodel.Status) error { } if status.BoostOfAccount == nil { - boostedAcct, err := p.db.GetAccountByID(status.BoostOfAccountID) + boostedAcct, err := p.db.GetAccountByID(ctx, status.BoostOfAccountID) if err != nil { return fmt.Errorf("notifyAnnounce: error getting account with id %s: %s", status.BoostOfAccountID, err) } @@ -264,7 +265,7 @@ func (p *processor) notifyAnnounce(status *gtsmodel.Status) error { } // make sure a notif doesn't already exist for this announce - err := p.db.GetWhere([]db.Where{ + err := p.db.GetWhere(ctx, []db.Where{ {Key: "notification_type", Value: gtsmodel.NotificationReblog}, {Key: "target_account_id", Value: status.BoostOfAccountID}, {Key: "origin_account_id", Value: status.AccountID}, @@ -292,12 +293,12 @@ func (p *processor) notifyAnnounce(status *gtsmodel.Status) error { Status: status, } - if err := p.db.Put(notif); err != nil { + if err := p.db.Put(ctx, notif); err != nil { return fmt.Errorf("notifyAnnounce: error putting notification in database: %s", err) } // now stream the notification to the user - mastoNotif, err := p.tc.NotificationToMasto(notif) + mastoNotif, err := p.tc.NotificationToMasto(ctx, notif) if err != nil { return fmt.Errorf("notifyStatus: error converting notification to masto representation: %s", err) } @@ -309,10 +310,10 @@ func (p *processor) notifyAnnounce(status *gtsmodel.Status) error { return nil } -func (p *processor) timelineStatus(status *gtsmodel.Status) error { +func (p *processor) timelineStatus(ctx context.Context, status *gtsmodel.Status) error { // make sure the author account is pinned onto the status if status.Account == nil { - a, err := p.db.GetAccountByID(status.AccountID) + a, err := p.db.GetAccountByID(ctx, status.AccountID) if err != nil { return fmt.Errorf("timelineStatus: error getting author account with id %s: %s", status.AccountID, err) } @@ -320,7 +321,7 @@ func (p *processor) timelineStatus(status *gtsmodel.Status) error { } // get local followers of the account that posted the status - follows, err := p.db.GetAccountFollowedBy(status.AccountID, true) + follows, err := p.db.GetAccountFollowedBy(ctx, status.AccountID, true) if err != nil { return fmt.Errorf("timelineStatus: error getting followers for account id %s: %s", status.AccountID, err) } @@ -338,7 +339,7 @@ func (p *processor) timelineStatus(status *gtsmodel.Status) error { errors := make(chan error, len(follows)) for _, f := range follows { - go p.timelineStatusForAccount(status, f.AccountID, errors, &wg) + go p.timelineStatusForAccount(ctx, status, f.AccountID, errors, &wg) } // read any errors that come in from the async functions @@ -365,18 +366,18 @@ func (p *processor) timelineStatus(status *gtsmodel.Status) error { return nil } -func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID string, errors chan error, wg *sync.WaitGroup) { +func (p *processor) timelineStatusForAccount(ctx context.Context, status *gtsmodel.Status, accountID string, errors chan error, wg *sync.WaitGroup) { defer wg.Done() // get the timeline owner account - timelineAccount, err := p.db.GetAccountByID(accountID) + timelineAccount, err := p.db.GetAccountByID(ctx, accountID) if err != nil { errors <- fmt.Errorf("timelineStatusForAccount: error getting account for timeline with id %s: %s", accountID, err) return } // make sure the status is timelineable - timelineable, err := p.filter.StatusHometimelineable(status, timelineAccount) + timelineable, err := p.filter.StatusHometimelineable(ctx, status, timelineAccount) if err != nil { errors <- fmt.Errorf("timelineStatusForAccount: error getting timelineability for status for timeline with id %s: %s", accountID, err) return @@ -387,7 +388,7 @@ func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID } // stick the status in the timeline for the account and then immediately prepare it so they can see it right away - inserted, err := p.timelineManager.IngestAndPrepare(status, timelineAccount.ID) + inserted, err := p.timelineManager.IngestAndPrepare(ctx, status, timelineAccount.ID) if err != nil { errors <- fmt.Errorf("timelineStatusForAccount: error ingesting status %s: %s", status.ID, err) return @@ -395,7 +396,7 @@ func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID // the status was inserted to stream it to the user if inserted { - mastoStatus, err := p.tc.StatusToMasto(status, timelineAccount) + mastoStatus, err := p.tc.StatusToMasto(ctx, status, timelineAccount) if err != nil { errors <- fmt.Errorf("timelineStatusForAccount: error converting status %s to frontend representation: %s", status.ID, err) } else { @@ -405,7 +406,7 @@ func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID } } - mastoStatus, err := p.tc.StatusToMasto(status, timelineAccount) + mastoStatus, err := p.tc.StatusToMasto(ctx, status, timelineAccount) if err != nil { errors <- fmt.Errorf("timelineStatusForAccount: error converting status %s to frontend representation: %s", status.ID, err) } else { @@ -415,8 +416,8 @@ func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID } } -func (p *processor) deleteStatusFromTimelines(status *gtsmodel.Status) error { - if err := p.timelineManager.WipeStatusFromAllTimelines(status.ID); err != nil { +func (p *processor) deleteStatusFromTimelines(ctx context.Context, status *gtsmodel.Status) error { + if err := p.timelineManager.WipeStatusFromAllTimelines(ctx, status.ID); err != nil { return err } diff --git a/internal/processing/fromfederator.go b/internal/processing/fromfederator.go index c95c27778..2bb74db34 100644 --- a/internal/processing/fromfederator.go +++ b/internal/processing/fromfederator.go @@ -19,6 +19,7 @@ package processing import ( + "context" "errors" "fmt" "net/url" @@ -29,7 +30,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/id" ) -func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) error { +func (p *processor) processFromFederator(ctx context.Context, federatorMsg gtsmodel.FromFederator) error { l := p.log.WithFields(logrus.Fields{ "func": "processFromFederator", "federatorMsg": fmt.Sprintf("%+v", federatorMsg), @@ -48,16 +49,16 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er return errors.New("note was not parseable as *gtsmodel.Status") } - status, err := p.federator.EnrichRemoteStatus(federatorMsg.ReceivingAccount.Username, incomingStatus) + status, err := p.federator.EnrichRemoteStatus(ctx, federatorMsg.ReceivingAccount.Username, incomingStatus) if err != nil { return err } - if err := p.timelineStatus(status); err != nil { + if err := p.timelineStatus(ctx, status); err != nil { return err } - if err := p.notifyStatus(status); err != nil { + if err := p.notifyStatus(ctx, status); err != nil { return err } case gtsmodel.ActivityStreamsProfile: @@ -70,7 +71,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er return errors.New("like was not parseable as *gtsmodel.StatusFave") } - if err := p.notifyFave(incomingFave, federatorMsg.ReceivingAccount); err != nil { + if err := p.notifyFave(ctx, incomingFave, federatorMsg.ReceivingAccount); err != nil { return err } case gtsmodel.ActivityStreamsFollow: @@ -80,7 +81,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er return errors.New("incomingFollowRequest was not parseable as *gtsmodel.FollowRequest") } - if err := p.notifyFollowRequest(incomingFollowRequest, federatorMsg.ReceivingAccount); err != nil { + if err := p.notifyFollowRequest(ctx, incomingFollowRequest, federatorMsg.ReceivingAccount); err != nil { return err } case gtsmodel.ActivityStreamsAnnounce: @@ -90,7 +91,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er return errors.New("announce was not parseable as *gtsmodel.Status") } - if err := p.federator.DereferenceAnnounce(incomingAnnounce, federatorMsg.ReceivingAccount.Username); err != nil { + if err := p.federator.DereferenceAnnounce(ctx, incomingAnnounce, federatorMsg.ReceivingAccount.Username); err != nil { return fmt.Errorf("error dereferencing announce from federator: %s", err) } @@ -100,17 +101,17 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er } incomingAnnounce.ID = incomingAnnounceID - if err := p.db.PutStatus(incomingAnnounce); err != nil { + if err := p.db.PutStatus(ctx, incomingAnnounce); err != nil { if err != db.ErrNoEntries { return fmt.Errorf("error adding dereferenced announce to the db: %s", err) } } - if err := p.timelineStatus(incomingAnnounce); err != nil { + if err := p.timelineStatus(ctx, incomingAnnounce); err != nil { return err } - if err := p.notifyAnnounce(incomingAnnounce); err != nil { + if err := p.notifyAnnounce(ctx, incomingAnnounce); err != nil { return err } case gtsmodel.ActivityStreamsBlock: @@ -121,10 +122,10 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er } // remove any of the blocking account's statuses from the blocked account's timeline, and vice versa - if err := p.timelineManager.WipeStatusesFromAccountID(block.AccountID, block.TargetAccountID); err != nil { + if err := p.timelineManager.WipeStatusesFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil { return err } - if err := p.timelineManager.WipeStatusesFromAccountID(block.TargetAccountID, block.AccountID); err != nil { + if err := p.timelineManager.WipeStatusesFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil { return err } // TODO: same with notifications @@ -145,7 +146,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er return err } - if _, _, err := p.federator.GetRemoteAccount(federatorMsg.ReceivingAccount.Username, incomingAccountURI, true); err != nil { + if _, _, err := p.federator.GetRemoteAccount(ctx, federatorMsg.ReceivingAccount.Username, incomingAccountURI, true); err != nil { return fmt.Errorf("error dereferencing account from federator: %s", err) } } @@ -165,25 +166,25 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er // delete all attachments for this status for _, a := range statusToDelete.AttachmentIDs { - if err := p.mediaProcessor.Delete(a); err != nil { + if err := p.mediaProcessor.Delete(ctx, a); err != nil { return err } } // delete all mentions for this status for _, m := range statusToDelete.MentionIDs { - if err := p.db.DeleteByID(m, >smodel.Mention{}); err != nil { + if err := p.db.DeleteByID(ctx, m, >smodel.Mention{}); err != nil { return err } } // delete all notifications for this status - if err := p.db.DeleteWhere([]db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil { return err } // remove this status from any and all timelines - return p.deleteStatusFromTimelines(statusToDelete) + return p.deleteStatusFromTimelines(ctx, statusToDelete) case gtsmodel.ActivityStreamsProfile: // DELETE A PROFILE/ACCOUNT // TODO: handle side effects of account deletion here: delete all objects, statuses, media etc associated with account @@ -198,7 +199,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er return errors.New("follow was not parseable as *gtsmodel.Follow") } - if err := p.notifyFollow(follow, federatorMsg.ReceivingAccount); err != nil { + if err := p.notifyFollow(ctx, follow, federatorMsg.ReceivingAccount); err != nil { return err } } diff --git a/internal/processing/instance.go b/internal/processing/instance.go index b151744ef..ced798c2e 100644 --- a/internal/processing/instance.go +++ b/internal/processing/instance.go @@ -19,6 +19,7 @@ package processing import ( + "context" "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -29,13 +30,13 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/util" ) -func (p *processor) InstanceGet(domain string) (*apimodel.Instance, gtserror.WithCode) { +func (p *processor) InstanceGet(ctx context.Context, domain string) (*apimodel.Instance, gtserror.WithCode) { i := >smodel.Instance{} - if err := p.db.GetWhere([]db.Where{{Key: "domain", Value: domain}}, i); err != nil { + if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: domain}}, i); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", p.config.Host, err)) } - ai, err := p.tc.InstanceToMasto(i) + ai, err := p.tc.InstanceToMasto(ctx, i) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting instance to api representation: %s", err)) } @@ -43,15 +44,15 @@ func (p *processor) InstanceGet(domain string) (*apimodel.Instance, gtserror.Wit return ai, nil } -func (p *processor) InstancePatch(form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode) { +func (p *processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode) { // fetch the instance entry from the db for processing i := >smodel.Instance{} - if err := p.db.GetWhere([]db.Where{{Key: "domain", Value: p.config.Host}}, i); err != nil { + if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: p.config.Host}}, i); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", p.config.Host, err)) } // fetch the instance account from the db for processing - ia, err := p.db.GetInstanceAccount("") + ia, err := p.db.GetInstanceAccount(ctx, "") if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance account %s: %s", p.config.Host, err)) } @@ -67,13 +68,13 @@ func (p *processor) InstancePatch(form *apimodel.InstanceSettingsUpdateRequest) // validate & update site contact account if it's set on the form if form.ContactUsername != nil { // make sure the account with the given username exists in the db - contactAccount, err := p.db.GetLocalAccountByUsername(*form.ContactUsername) + contactAccount, err := p.db.GetLocalAccountByUsername(ctx, *form.ContactUsername) if err != nil { return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("account with username %s not retrievable", *form.ContactUsername)) } // make sure it has a user associated with it contactUser := >smodel.User{} - if err := p.db.GetWhere([]db.Where{{Key: "account_id", Value: contactAccount.ID}}, contactUser); err != nil { + if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: contactAccount.ID}}, contactUser); err != nil { return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("user for account with username %s not retrievable", *form.ContactUsername)) } // suspended accounts cannot be contact accounts @@ -132,7 +133,7 @@ func (p *processor) InstancePatch(form *apimodel.InstanceSettingsUpdateRequest) // process avatar if provided if form.Avatar != nil && form.Avatar.Size != 0 { - _, err := p.accountProcessor.UpdateAvatar(form.Avatar, ia.ID) + _, err := p.accountProcessor.UpdateAvatar(ctx, form.Avatar, ia.ID) if err != nil { return nil, gtserror.NewErrorBadRequest(err, "error processing avatar") } @@ -140,17 +141,17 @@ func (p *processor) InstancePatch(form *apimodel.InstanceSettingsUpdateRequest) // process header if provided if form.Header != nil && form.Header.Size != 0 { - _, err := p.accountProcessor.UpdateHeader(form.Header, ia.ID) + _, err := p.accountProcessor.UpdateHeader(ctx, form.Header, ia.ID) if err != nil { return nil, gtserror.NewErrorBadRequest(err, "error processing header") } } - if err := p.db.UpdateByID(i.ID, i); err != nil { + if err := p.db.UpdateByID(ctx, i.ID, i); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", p.config.Host, err)) } - ai, err := p.tc.InstanceToMasto(i) + ai, err := p.tc.InstanceToMasto(ctx, i) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting instance to api representation: %s", err)) } diff --git a/internal/processing/media.go b/internal/processing/media.go index 6ca0eda5b..0b2443893 100644 --- a/internal/processing/media.go +++ b/internal/processing/media.go @@ -19,23 +19,25 @@ package processing import ( + "context" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -func (p *processor) MediaCreate(authed *oauth.Auth, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) { - return p.mediaProcessor.Create(authed.Account, form) +func (p *processor) MediaCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) { + return p.mediaProcessor.Create(ctx, authed.Account, form) } -func (p *processor) MediaGet(authed *oauth.Auth, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) { - return p.mediaProcessor.GetMedia(authed.Account, mediaAttachmentID) +func (p *processor) MediaGet(ctx context.Context, authed *oauth.Auth, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) { + return p.mediaProcessor.GetMedia(ctx, authed.Account, mediaAttachmentID) } -func (p *processor) MediaUpdate(authed *oauth.Auth, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) { - return p.mediaProcessor.Update(authed.Account, mediaAttachmentID, form) +func (p *processor) MediaUpdate(ctx context.Context, authed *oauth.Auth, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) { + return p.mediaProcessor.Update(ctx, authed.Account, mediaAttachmentID, form) } -func (p *processor) FileGet(authed *oauth.Auth, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) { - return p.mediaProcessor.GetFile(authed.Account, form) +func (p *processor) FileGet(ctx context.Context, authed *oauth.Auth, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) { + return p.mediaProcessor.GetFile(ctx, authed.Account, form) } diff --git a/internal/processing/media/create.go b/internal/processing/media/create.go index 5b8cdf604..648e4d46a 100644 --- a/internal/processing/media/create.go +++ b/internal/processing/media/create.go @@ -20,6 +20,7 @@ package media import ( "bytes" + "context" "errors" "fmt" "io" @@ -29,7 +30,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/text" ) -func (p *processor) Create(account *gtsmodel.Account, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) { +func (p *processor) Create(ctx context.Context, account *gtsmodel.Account, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) { // open the attachment and extract the bytes from it f, err := form.File.Open() if err != nil { @@ -45,7 +46,7 @@ func (p *processor) Create(account *gtsmodel.Account, form *apimodel.AttachmentR } // allow the mediaHandler to work its magic of processing the attachment bytes, and putting them in whatever storage backend we're using - attachment, err := p.mediaHandler.ProcessAttachment(buf.Bytes(), account.ID, "") + attachment, err := p.mediaHandler.ProcessAttachment(ctx, buf.Bytes(), account.ID, "") if err != nil { return nil, fmt.Errorf("error reading attachment: %s", err) } @@ -66,13 +67,13 @@ func (p *processor) Create(account *gtsmodel.Account, form *apimodel.AttachmentR // prepare the frontend representation now -- if there are any errors here at least we can bail without // having already put something in the database and then having to clean it up again (eugh) - mastoAttachment, err := p.tc.AttachmentToMasto(attachment) + mastoAttachment, err := p.tc.AttachmentToMasto(ctx, attachment) if err != nil { return nil, fmt.Errorf("error parsing media attachment to frontend type: %s", err) } // now we can confidently put the attachment in the database - if err := p.db.Put(attachment); err != nil { + if err := p.db.Put(ctx, attachment); err != nil { return nil, fmt.Errorf("error storing media attachment in db: %s", err) } diff --git a/internal/processing/media/delete.go b/internal/processing/media/delete.go index b5ea8c806..281ddba03 100644 --- a/internal/processing/media/delete.go +++ b/internal/processing/media/delete.go @@ -1,17 +1,17 @@ package media import ( + "context" "fmt" "strings" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) Delete(mediaAttachmentID string) gtserror.WithCode { - a := >smodel.MediaAttachment{} - if err := p.db.GetByID(mediaAttachmentID, a); err != nil { +func (p *processor) Delete(ctx context.Context, mediaAttachmentID string) gtserror.WithCode { + attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) + if err != nil { if err == db.ErrNoEntries { // attachment already gone return nil @@ -23,21 +23,21 @@ func (p *processor) Delete(mediaAttachmentID string) gtserror.WithCode { errs := []string{} // delete the thumbnail from storage - if a.Thumbnail.Path != "" { - if err := p.storage.RemoveFileAt(a.Thumbnail.Path); err != nil { - errs = append(errs, fmt.Sprintf("remove thumbnail at path %s: %s", a.Thumbnail.Path, err)) + if attachment.Thumbnail.Path != "" { + if err := p.storage.RemoveFileAt(attachment.Thumbnail.Path); err != nil { + errs = append(errs, fmt.Sprintf("remove thumbnail at path %s: %s", attachment.Thumbnail.Path, err)) } } // delete the file from storage - if a.File.Path != "" { - if err := p.storage.RemoveFileAt(a.File.Path); err != nil { - errs = append(errs, fmt.Sprintf("remove file at path %s: %s", a.File.Path, err)) + if attachment.File.Path != "" { + if err := p.storage.RemoveFileAt(attachment.File.Path); err != nil { + errs = append(errs, fmt.Sprintf("remove file at path %s: %s", attachment.File.Path, err)) } } // delete the attachment - if err := p.db.DeleteByID(mediaAttachmentID, a); err != nil { + if err := p.db.DeleteByID(ctx, mediaAttachmentID, attachment); err != nil { if err != db.ErrNoEntries { errs = append(errs, fmt.Sprintf("remove attachment: %s", err)) } diff --git a/internal/processing/media/getfile.go b/internal/processing/media/getfile.go index 01288c56d..c9c9b556d 100644 --- a/internal/processing/media/getfile.go +++ b/internal/processing/media/getfile.go @@ -19,6 +19,7 @@ package media import ( + "context" "fmt" "strings" @@ -28,7 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/media" ) -func (p *processor) GetFile(account *gtsmodel.Account, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) { +func (p *processor) GetFile(ctx context.Context, account *gtsmodel.Account, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) { // parse the form fields mediaSize, err := media.ParseMediaSize(form.MediaSize) if err != nil { @@ -47,8 +48,8 @@ func (p *processor) GetFile(account *gtsmodel.Account, form *apimodel.GetContent wantedMediaID := spl[0] // get the account that owns the media and make sure it's not suspended - acct := >smodel.Account{} - if err := p.db.GetByID(form.AccountID, acct); err != nil { + acct, err := p.db.GetAccountByID(ctx, form.AccountID) + if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("account with id %s could not be selected from the db: %s", form.AccountID, err)) } if !acct.SuspendedAt.IsZero() { @@ -57,7 +58,7 @@ func (p *processor) GetFile(account *gtsmodel.Account, form *apimodel.GetContent // make sure the requesting account and the media account don't block each other if account != nil { - blocked, err := p.db.IsBlocked(account.ID, form.AccountID, true) + blocked, err := p.db.IsBlocked(ctx, account.ID, form.AccountID, true) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("block status could not be established between accounts %s and %s: %s", form.AccountID, account.ID, err)) } @@ -73,7 +74,7 @@ func (p *processor) GetFile(account *gtsmodel.Account, form *apimodel.GetContent switch mediaType { case media.Emoji: e := >smodel.Emoji{} - if err := p.db.GetByID(wantedMediaID, e); err != nil { + if err := p.db.GetByID(ctx, wantedMediaID, e); err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("emoji %s could not be taken from the db: %s", wantedMediaID, err)) } if e.Disabled { @@ -90,8 +91,8 @@ func (p *processor) GetFile(account *gtsmodel.Account, form *apimodel.GetContent return nil, gtserror.NewErrorNotFound(fmt.Errorf("media size %s not recognized for emoji", mediaSize)) } case media.Attachment, media.Header, media.Avatar: - a := >smodel.MediaAttachment{} - if err := p.db.GetByID(wantedMediaID, a); err != nil { + a, err := p.db.GetAttachmentByID(ctx, wantedMediaID) + if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("attachment %s could not be taken from the db: %s", wantedMediaID, err)) } if a.AccountID != form.AccountID { diff --git a/internal/processing/media/getmedia.go b/internal/processing/media/getmedia.go index 380a54cc2..91608e90d 100644 --- a/internal/processing/media/getmedia.go +++ b/internal/processing/media/getmedia.go @@ -19,6 +19,7 @@ package media import ( + "context" "errors" "fmt" @@ -28,9 +29,9 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) GetMedia(account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) { - attachment := >smodel.MediaAttachment{} - if err := p.db.GetByID(mediaAttachmentID, attachment); err != nil { +func (p *processor) GetMedia(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) { + attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) + if err != nil { if err == db.ErrNoEntries { // attachment doesn't exist return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db")) @@ -42,7 +43,7 @@ func (p *processor) GetMedia(account *gtsmodel.Account, mediaAttachmentID string return nil, gtserror.NewErrorNotFound(errors.New("attachment not owned by requesting account")) } - a, err := p.tc.AttachmentToMasto(attachment) + a, err := p.tc.AttachmentToMasto(ctx, attachment) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error converting attachment: %s", err)) } diff --git a/internal/processing/media/media.go b/internal/processing/media/media.go index 79c9a7e18..6b88143e2 100644 --- a/internal/processing/media/media.go +++ b/internal/processing/media/media.go @@ -19,6 +19,8 @@ package media import ( + "context" + "github.com/sirupsen/logrus" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/blob" @@ -33,12 +35,12 @@ import ( // Processor wraps a bunch of functions for processing media actions. type Processor interface { // Create creates a new media attachment belonging to the given account, using the request form. - Create(account *gtsmodel.Account, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) + Create(ctx context.Context, account *gtsmodel.Account, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) // Delete deletes the media attachment with the given ID, including all files pertaining to that attachment. - Delete(mediaAttachmentID string) gtserror.WithCode - GetFile(account *gtsmodel.Account, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) - GetMedia(account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) - Update(account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) + Delete(ctx context.Context, mediaAttachmentID string) gtserror.WithCode + GetFile(ctx context.Context, account *gtsmodel.Account, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) + GetMedia(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) + Update(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) } type processor struct { diff --git a/internal/processing/media/update.go b/internal/processing/media/update.go index 89ed08ac1..6f15f2ace 100644 --- a/internal/processing/media/update.go +++ b/internal/processing/media/update.go @@ -19,6 +19,7 @@ package media import ( + "context" "errors" "fmt" @@ -29,9 +30,9 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/text" ) -func (p *processor) Update(account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) { - attachment := >smodel.MediaAttachment{} - if err := p.db.GetByID(mediaAttachmentID, attachment); err != nil { +func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) { + attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) + if err != nil { if err == db.ErrNoEntries { // attachment doesn't exist return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db")) @@ -45,7 +46,7 @@ func (p *processor) Update(account *gtsmodel.Account, mediaAttachmentID string, if form.Description != nil { attachment.Description = text.RemoveHTML(*form.Description) - if err := p.db.UpdateByID(mediaAttachmentID, attachment); err != nil { + if err := p.db.UpdateByID(ctx, mediaAttachmentID, attachment); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error updating description: %s", err)) } } @@ -57,12 +58,12 @@ func (p *processor) Update(account *gtsmodel.Account, mediaAttachmentID string, } attachment.FileMeta.Focus.X = focusx attachment.FileMeta.Focus.Y = focusy - if err := p.db.UpdateByID(mediaAttachmentID, attachment); err != nil { + if err := p.db.UpdateByID(ctx, mediaAttachmentID, attachment); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error updating focus: %s", err)) } } - a, err := p.tc.AttachmentToMasto(attachment) + a, err := p.tc.AttachmentToMasto(ctx, attachment) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error converting attachment: %s", err)) } diff --git a/internal/processing/notification.go b/internal/processing/notification.go index 7af74b04f..f91d2f2cd 100644 --- a/internal/processing/notification.go +++ b/internal/processing/notification.go @@ -19,22 +19,24 @@ package processing import ( + "context" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -func (p *processor) NotificationsGet(authed *oauth.Auth, limit int, maxID string, sinceID string) ([]*apimodel.Notification, gtserror.WithCode) { +func (p *processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, limit int, maxID string, sinceID string) ([]*apimodel.Notification, gtserror.WithCode) { l := p.log.WithField("func", "NotificationsGet") - notifs, err := p.db.GetNotifications(authed.Account.ID, limit, maxID, sinceID) + notifs, err := p.db.GetNotifications(ctx, authed.Account.ID, limit, maxID, sinceID) if err != nil { return nil, gtserror.NewErrorInternalError(err) } mastoNotifs := []*apimodel.Notification{} for _, n := range notifs { - mastoNotif, err := p.tc.NotificationToMasto(n) + mastoNotif, err := p.tc.NotificationToMasto(ctx, n) if err != nil { l.Debugf("got an error converting a notification to masto, will skip it: %s", err) continue diff --git a/internal/processing/processor.go b/internal/processing/processor.go index 48ed2a35f..8df464ce0 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -51,7 +51,7 @@ import ( // for clean distribution of messages without slowing down the client API and harming the user experience. type Processor interface { // Start starts the Processor, reading from its channels and passing messages back and forth. - Start() error + Start(ctx context.Context) error // Stop stops the processor cleanly, finishing handling any remaining messages before closing down. Stop() error @@ -64,108 +64,108 @@ type Processor interface { */ // AccountCreate processes the given form for creating a new account, returning an oauth token for that account if successful. - AccountCreate(authed *oauth.Auth, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) + AccountCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) // AccountGet processes the given request for account information. - AccountGet(authed *oauth.Auth, targetAccountID string) (*apimodel.Account, error) + AccountGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Account, error) // AccountUpdate processes the update of an account with the given form - AccountUpdate(authed *oauth.Auth, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) + AccountUpdate(ctx context.Context, authed *oauth.Auth, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) // AccountStatusesGet fetches a number of statuses (in time descending order) from the given account, filtered by visibility for // the account given in authed. - AccountStatusesGet(authed *oauth.Auth, targetAccountID string, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) + AccountStatusesGet(ctx context.Context, authed *oauth.Auth, targetAccountID string, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) // AccountFollowersGet fetches a list of the target account's followers. - AccountFollowersGet(authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) + AccountFollowersGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) // AccountFollowingGet fetches a list of the accounts that target account is following. - AccountFollowingGet(authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) + AccountFollowingGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) // AccountRelationshipGet returns a relationship model describing the relationship of the targetAccount to the Authed account. - AccountRelationshipGet(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) + AccountRelationshipGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) // AccountFollowCreate handles a follow request to an account, either remote or local. - AccountFollowCreate(authed *oauth.Auth, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) + AccountFollowCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) // AccountFollowRemove handles the removal of a follow/follow request to an account, either remote or local. - AccountFollowRemove(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) + AccountFollowRemove(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) // AccountBlockCreate handles the creation of a block from authed account to target account, either remote or local. - AccountBlockCreate(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) + AccountBlockCreate(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) // AccountBlockRemove handles the removal of a block from authed account to target account, either remote or local. - AccountBlockRemove(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) + AccountBlockRemove(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) // AdminEmojiCreate handles the creation of a new instance emoji by an admin, using the given form. - AdminEmojiCreate(authed *oauth.Auth, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) + AdminEmojiCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) // AdminDomainBlockCreate handles the creation of a new domain block by an admin, using the given form. - AdminDomainBlockCreate(authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) (*apimodel.DomainBlock, gtserror.WithCode) + AdminDomainBlockCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) (*apimodel.DomainBlock, gtserror.WithCode) // AdminDomainBlocksImport handles the import of multiple domain blocks by an admin, using the given form. - AdminDomainBlocksImport(authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) ([]*apimodel.DomainBlock, gtserror.WithCode) + AdminDomainBlocksImport(ctx context.Context, authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) ([]*apimodel.DomainBlock, gtserror.WithCode) // AdminDomainBlocksGet returns a list of currently blocked domains. - AdminDomainBlocksGet(authed *oauth.Auth, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) + AdminDomainBlocksGet(ctx context.Context, authed *oauth.Auth, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) // AdminDomainBlockGet returns one domain block, specified by ID. - AdminDomainBlockGet(authed *oauth.Auth, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) + AdminDomainBlockGet(ctx context.Context, authed *oauth.Auth, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) // AdminDomainBlockDelete deletes one domain block, specified by ID, returning the deleted domain block. - AdminDomainBlockDelete(authed *oauth.Auth, id string) (*apimodel.DomainBlock, gtserror.WithCode) + AdminDomainBlockDelete(ctx context.Context, authed *oauth.Auth, id string) (*apimodel.DomainBlock, gtserror.WithCode) // AppCreate processes the creation of a new API application - AppCreate(authed *oauth.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, error) + AppCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, error) // BlocksGet returns a list of accounts blocked by the requesting account. - BlocksGet(authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) + BlocksGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) // FileGet handles the fetching of a media attachment file via the fileserver. - FileGet(authed *oauth.Auth, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) + FileGet(ctx context.Context, authed *oauth.Auth, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) // FollowRequestsGet handles the getting of the authed account's incoming follow requests - FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) + FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) // FollowRequestAccept handles the acceptance of a follow request from the given account ID - FollowRequestAccept(auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) + FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) // InstanceGet retrieves instance information for serving at api/v1/instance - InstanceGet(domain string) (*apimodel.Instance, gtserror.WithCode) + InstanceGet(ctx context.Context, domain string) (*apimodel.Instance, gtserror.WithCode) // InstancePatch updates this instance according to the given form. // // It should already be ascertained that the requesting account is authenticated and an admin. - InstancePatch(form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode) + InstancePatch(ctx context.Context, form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode) // MediaCreate handles the creation of a media attachment, using the given form. - MediaCreate(authed *oauth.Auth, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) + MediaCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) // MediaGet handles the GET of a media attachment with the given ID - MediaGet(authed *oauth.Auth, attachmentID string) (*apimodel.Attachment, gtserror.WithCode) + MediaGet(ctx context.Context, authed *oauth.Auth, attachmentID string) (*apimodel.Attachment, gtserror.WithCode) // MediaUpdate handles the PUT of a media attachment with the given ID and form - MediaUpdate(authed *oauth.Auth, attachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) + MediaUpdate(ctx context.Context, authed *oauth.Auth, attachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) // NotificationsGet - NotificationsGet(authed *oauth.Auth, limit int, maxID string, sinceID string) ([]*apimodel.Notification, gtserror.WithCode) + NotificationsGet(ctx context.Context, authed *oauth.Auth, limit int, maxID string, sinceID string) ([]*apimodel.Notification, gtserror.WithCode) // SearchGet performs a search with the given params, resolving/dereferencing remotely as desired - SearchGet(authed *oauth.Auth, searchQuery *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode) + SearchGet(ctx context.Context, authed *oauth.Auth, searchQuery *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode) // StatusCreate processes the given form to create a new status, returning the api model representation of that status if it's OK. - StatusCreate(authed *oauth.Auth, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, error) + StatusCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, error) // StatusDelete processes the delete of a given status, returning the deleted status if the delete goes through. - StatusDelete(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) + StatusDelete(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) // StatusFave processes the faving of a given status, returning the updated status if the fave goes through. - StatusFave(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) + StatusFave(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) // StatusBoost processes the boost/reblog of a given status, returning the newly-created boost if all is well. - StatusBoost(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode) + StatusBoost(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode) // StatusUnboost processes the unboost/unreblog of a given status, returning the status if all is well. - StatusUnboost(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode) + StatusUnboost(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode) // StatusBoostedBy returns a slice of accounts that have boosted the given status, filtered according to privacy settings. - StatusBoostedBy(authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) + StatusBoostedBy(ctx context.Context, authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) // StatusFavedBy returns a slice of accounts that have liked the given status, filtered according to privacy settings. - StatusFavedBy(authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, error) + StatusFavedBy(ctx context.Context, authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, error) // StatusGet gets the given status, taking account of privacy settings and blocks etc. - StatusGet(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) + StatusGet(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) // StatusUnfave processes the unfaving of a given status, returning the updated status if the fave goes through. - StatusUnfave(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) + StatusUnfave(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) // StatusGetContext returns the context (previous and following posts) from the given status ID - StatusGetContext(authed *oauth.Auth, targetStatusID string) (*apimodel.Context, gtserror.WithCode) + StatusGetContext(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Context, gtserror.WithCode) // HomeTimelineGet returns statuses from the home timeline, with the given filters/parameters. - HomeTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) + HomeTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) // PublicTimelineGet returns statuses from the public/local timeline, with the given filters/parameters. - PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) + PublicTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) // FavedTimelineGet returns faved statuses, with the given filters/parameters. - FavedTimelineGet(authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.StatusTimelineResponse, gtserror.WithCode) + FavedTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.StatusTimelineResponse, gtserror.WithCode) // AuthorizeStreamingRequest returns a gotosocial account in exchange for an access token, or an error if the given token is not valid. - AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) + AuthorizeStreamingRequest(ctx context.Context, accessToken string) (*gtsmodel.Account, error) // OpenStreamForAccount opens a new stream for the given account, with the given stream type. - OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) + OpenStreamForAccount(ctx context.Context, account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) /* FEDERATION API-FACING PROCESSING FUNCTIONS @@ -199,10 +199,10 @@ type Processor interface { GetWebfingerAccount(ctx context.Context, requestedUsername string, requestURL *url.URL) (*apimodel.WellKnownResponse, gtserror.WithCode) // GetNodeInfoRel returns a well known response giving the path to node info. - GetNodeInfoRel(request *http.Request) (*apimodel.WellKnownResponse, gtserror.WithCode) + GetNodeInfoRel(ctx context.Context, request *http.Request) (*apimodel.WellKnownResponse, gtserror.WithCode) // GetNodeInfo returns a node info struct in response to a node info request. - GetNodeInfo(request *http.Request) (*apimodel.Nodeinfo, gtserror.WithCode) + GetNodeInfo(ctx context.Context, request *http.Request) (*apimodel.Nodeinfo, gtserror.WithCode) // InboxPost handles POST requests to a user's inbox for new activitypub messages. // @@ -280,7 +280,7 @@ func NewProcessor(config *config.Config, tc typeutils.TypeConverter, federator f } // Start starts the Processor, reading from its channels and passing messages back and forth. -func (p *processor) Start() error { +func (p *processor) Start(ctx context.Context) error { go func() { DistLoop: for { @@ -288,14 +288,14 @@ func (p *processor) Start() error { case clientMsg := <-p.fromClientAPI: p.log.Tracef("received message FROM client API: %+v", clientMsg) go func() { - if err := p.processFromClientAPI(clientMsg); err != nil { + if err := p.processFromClientAPI(ctx, clientMsg); err != nil { p.log.Error(err) } }() case federatorMsg := <-p.fromFederator: p.log.Tracef("received message FROM federator: %+v", federatorMsg) go func() { - if err := p.processFromFederator(federatorMsg); err != nil { + if err := p.processFromFederator(ctx, federatorMsg); err != nil { p.log.Error(err) } }() diff --git a/internal/processing/search.go b/internal/processing/search.go index f2ae721ae..768fceacd 100644 --- a/internal/processing/search.go +++ b/internal/processing/search.go @@ -19,6 +19,7 @@ package processing import ( + "context" "fmt" "net/url" "strings" @@ -32,7 +33,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/util" ) -func (p *processor) SearchGet(authed *oauth.Auth, searchQuery *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode) { +func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, searchQuery *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode) { l := p.log.WithFields(logrus.Fields{ "func": "SearchGet", "query": searchQuery.Query, @@ -54,7 +55,7 @@ func (p *processor) SearchGet(authed *oauth.Auth, searchQuery *apimodel.SearchQu // check if the query is something like @whatever_username@example.org -- this means it's a remote account if !foundOne && util.IsMention(searchQuery.Query) { l.Debug("search term is a mention, looking it up...") - foundAccount, err := p.searchAccountByMention(authed, searchQuery.Query, searchQuery.Resolve) + foundAccount, err := p.searchAccountByMention(ctx, authed, searchQuery.Query, searchQuery.Resolve) if err == nil && foundAccount != nil { foundAccounts = append(foundAccounts, foundAccount) foundOne = true @@ -65,14 +66,14 @@ func (p *processor) SearchGet(authed *oauth.Auth, searchQuery *apimodel.SearchQu // check if the query is a URI and just do a lookup for that, straight up if uri, err := url.Parse(query); err == nil && !foundOne { // 1. check if it's a status - if foundStatus, err := p.searchStatusByURI(authed, uri, searchQuery.Resolve); err == nil && foundStatus != nil { + if foundStatus, err := p.searchStatusByURI(ctx, authed, uri, searchQuery.Resolve); err == nil && foundStatus != nil { foundStatuses = append(foundStatuses, foundStatus) foundOne = true l.Debug("got a status by searching by URI") } // 2. check if it's an account - if foundAccount, err := p.searchAccountByURI(authed, uri, searchQuery.Resolve); err == nil && foundAccount != nil { + if foundAccount, err := p.searchAccountByURI(ctx, authed, uri, searchQuery.Resolve); err == nil && foundAccount != nil { foundAccounts = append(foundAccounts, foundAccount) foundOne = true l.Debug("got an account by searching by URI") @@ -90,20 +91,20 @@ func (p *processor) SearchGet(authed *oauth.Auth, searchQuery *apimodel.SearchQu */ for _, foundAccount := range foundAccounts { // make sure there's no block in either direction between the account and the requester - if blocked, err := p.db.IsBlocked(authed.Account.ID, foundAccount.ID, true); err == nil && !blocked { + if blocked, err := p.db.IsBlocked(ctx, authed.Account.ID, foundAccount.ID, true); err == nil && !blocked { // all good, convert it and add it to the results - if acctMasto, err := p.tc.AccountToMastoPublic(foundAccount); err == nil && acctMasto != nil { + if acctMasto, err := p.tc.AccountToMastoPublic(ctx, foundAccount); err == nil && acctMasto != nil { results.Accounts = append(results.Accounts, *acctMasto) } } } for _, foundStatus := range foundStatuses { - if visible, err := p.filter.StatusVisible(foundStatus, authed.Account); !visible || err != nil { + if visible, err := p.filter.StatusVisible(ctx, foundStatus, authed.Account); !visible || err != nil { continue } - statusMasto, err := p.tc.StatusToMasto(foundStatus, authed.Account) + statusMasto, err := p.tc.StatusToMasto(ctx, foundStatus, authed.Account) if err != nil { continue } @@ -114,24 +115,24 @@ func (p *processor) SearchGet(authed *oauth.Auth, searchQuery *apimodel.SearchQu return results, nil } -func (p *processor) searchStatusByURI(authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Status, error) { +func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Status, error) { l := p.log.WithFields(logrus.Fields{ "func": "searchStatusByURI", "uri": uri.String(), "resolve": resolve, }) - if maybeStatus, err := p.db.GetStatusByURI(uri.String()); err == nil { + if maybeStatus, err := p.db.GetStatusByURI(ctx, uri.String()); err == nil { return maybeStatus, nil - } else if maybeStatus, err := p.db.GetStatusByURL(uri.String()); err == nil { + } else if maybeStatus, err := p.db.GetStatusByURL(ctx, uri.String()); err == nil { return maybeStatus, nil } // we don't have it locally so dereference it if we're allowed to if resolve { - status, _, _, err := p.federator.GetRemoteStatus(authed.Account.Username, uri, true) + status, _, _, err := p.federator.GetRemoteStatus(ctx, authed.Account.Username, uri, true) if err == nil { - if err := p.federator.DereferenceRemoteThread(authed.Account.Username, uri); err != nil { + if err := p.federator.DereferenceRemoteThread(ctx, authed.Account.Username, uri); err != nil { // try to deref the thread while we're here l.Debugf("searchStatusByURI: error dereferencing remote thread: %s", err) } @@ -141,16 +142,16 @@ func (p *processor) searchStatusByURI(authed *oauth.Auth, uri *url.URL, resolve return nil, nil } -func (p *processor) searchAccountByURI(authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Account, error) { - if maybeAccount, err := p.db.GetAccountByURI(uri.String()); err == nil { +func (p *processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Account, error) { + if maybeAccount, err := p.db.GetAccountByURI(ctx, uri.String()); err == nil { return maybeAccount, nil - } else if maybeAccount, err := p.db.GetAccountByURL(uri.String()); err == nil { + } else if maybeAccount, err := p.db.GetAccountByURL(ctx, uri.String()); err == nil { return maybeAccount, nil } if resolve { // we don't have it locally so try and dereference it - account, _, err := p.federator.GetRemoteAccount(authed.Account.Username, uri, true) + account, _, err := p.federator.GetRemoteAccount(ctx, authed.Account.Username, uri, true) if err != nil { return nil, fmt.Errorf("searchAccountByURI: error dereferencing account with uri %s: %s", uri.String(), err) } @@ -159,7 +160,7 @@ func (p *processor) searchAccountByURI(authed *oauth.Auth, uri *url.URL, resolve return nil, nil } -func (p *processor) searchAccountByMention(authed *oauth.Auth, mention string, resolve bool) (*gtsmodel.Account, error) { +func (p *processor) searchAccountByMention(ctx context.Context, authed *oauth.Auth, mention string, resolve bool) (*gtsmodel.Account, error) { // query is for a remote account username, domain, err := util.ExtractMentionParts(mention) if err != nil { @@ -169,7 +170,7 @@ func (p *processor) searchAccountByMention(authed *oauth.Auth, mention string, r // if it's a local account we can skip a whole bunch of stuff maybeAcct := >smodel.Account{} if domain == p.config.Host { - maybeAcct, err = p.db.GetLocalAccountByUsername(username) + maybeAcct, err = p.db.GetLocalAccountByUsername(ctx, username) if err != nil { return nil, fmt.Errorf("searchAccountByMention: error getting local account by username: %s", err) } @@ -181,7 +182,7 @@ func (p *processor) searchAccountByMention(authed *oauth.Auth, mention string, r {Key: "username", Value: username, CaseInsensitive: true}, {Key: "domain", Value: domain, CaseInsensitive: true}, } - err = p.db.GetWhere(where, maybeAcct) + err = p.db.GetWhere(ctx, where, maybeAcct) if err == nil { // we've got it stored locally already! return maybeAcct, nil @@ -197,14 +198,14 @@ func (p *processor) searchAccountByMention(authed *oauth.Auth, mention string, r // we're allowed to resolve it so let's try // first we need to webfinger the remote account to convert the username and domain into the activitypub URI for the account - acctURI, err := p.federator.FingerRemoteAccount(authed.Account.Username, username, domain) + acctURI, err := p.federator.FingerRemoteAccount(ctx, authed.Account.Username, username, domain) if err != nil { // something went wrong doing the webfinger lookup so we can't process the request return nil, fmt.Errorf("searchAccountByMention: error fingering remote account with username %s and domain %s: %s", username, domain, err) } // we don't have it locally so try and dereference it - account, _, err := p.federator.GetRemoteAccount(authed.Account.Username, acctURI, true) + account, _, err := p.federator.GetRemoteAccount(ctx, authed.Account.Username, acctURI, true) if err != nil { return nil, fmt.Errorf("searchAccountByMention: error dereferencing account with uri %s: %s", acctURI.String(), err) } diff --git a/internal/processing/status.go b/internal/processing/status.go index ab3843ded..c31c20628 100644 --- a/internal/processing/status.go +++ b/internal/processing/status.go @@ -19,47 +19,49 @@ package processing import ( + "context" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -func (p *processor) StatusCreate(authed *oauth.Auth, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, error) { - return p.statusProcessor.Create(authed.Account, authed.Application, form) +func (p *processor) StatusCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, error) { + return p.statusProcessor.Create(ctx, authed.Account, authed.Application, form) } -func (p *processor) StatusDelete(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) { - return p.statusProcessor.Delete(authed.Account, targetStatusID) +func (p *processor) StatusDelete(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) { + return p.statusProcessor.Delete(ctx, authed.Account, targetStatusID) } -func (p *processor) StatusFave(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) { - return p.statusProcessor.Fave(authed.Account, targetStatusID) +func (p *processor) StatusFave(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) { + return p.statusProcessor.Fave(ctx, authed.Account, targetStatusID) } -func (p *processor) StatusBoost(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - return p.statusProcessor.Boost(authed.Account, authed.Application, targetStatusID) +func (p *processor) StatusBoost(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { + return p.statusProcessor.Boost(ctx, authed.Account, authed.Application, targetStatusID) } -func (p *processor) StatusUnboost(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - return p.statusProcessor.Unboost(authed.Account, authed.Application, targetStatusID) +func (p *processor) StatusUnboost(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { + return p.statusProcessor.Unboost(ctx, authed.Account, authed.Application, targetStatusID) } -func (p *processor) StatusBoostedBy(authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) { - return p.statusProcessor.BoostedBy(authed.Account, targetStatusID) +func (p *processor) StatusBoostedBy(ctx context.Context, authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) { + return p.statusProcessor.BoostedBy(ctx, authed.Account, targetStatusID) } -func (p *processor) StatusFavedBy(authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, error) { - return p.statusProcessor.FavedBy(authed.Account, targetStatusID) +func (p *processor) StatusFavedBy(ctx context.Context, authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, error) { + return p.statusProcessor.FavedBy(ctx, authed.Account, targetStatusID) } -func (p *processor) StatusGet(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) { - return p.statusProcessor.Get(authed.Account, targetStatusID) +func (p *processor) StatusGet(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) { + return p.statusProcessor.Get(ctx, authed.Account, targetStatusID) } -func (p *processor) StatusUnfave(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) { - return p.statusProcessor.Unfave(authed.Account, targetStatusID) +func (p *processor) StatusUnfave(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) { + return p.statusProcessor.Unfave(ctx, authed.Account, targetStatusID) } -func (p *processor) StatusGetContext(authed *oauth.Auth, targetStatusID string) (*apimodel.Context, gtserror.WithCode) { - return p.statusProcessor.Context(authed.Account, targetStatusID) +func (p *processor) StatusGetContext(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Context, gtserror.WithCode) { + return p.statusProcessor.Context(ctx, authed.Account, targetStatusID) } diff --git a/internal/processing/status/boost.go b/internal/processing/status/boost.go index d7a62beb1..948d57a48 100644 --- a/internal/processing/status/boost.go +++ b/internal/processing/status/boost.go @@ -1,6 +1,25 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package status import ( + "context" "errors" "fmt" @@ -9,8 +28,8 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) Boost(requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(targetStatusID) +func (p *processor) Boost(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { + targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -18,7 +37,7 @@ func (p *processor) Boost(requestingAccount *gtsmodel.Account, application *gtsm return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID)) } - visible, err := p.filter.StatusVisible(targetStatus, requestingAccount) + visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err)) } @@ -32,7 +51,7 @@ func (p *processor) Boost(requestingAccount *gtsmodel.Account, application *gtsm } // it's visible! it's boostable! so let's boost the FUCK out of it - boostWrapperStatus, err := p.tc.StatusToBoost(targetStatus, requestingAccount) + boostWrapperStatus, err := p.tc.StatusToBoost(ctx, targetStatus, requestingAccount) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -41,7 +60,7 @@ func (p *processor) Boost(requestingAccount *gtsmodel.Account, application *gtsm boostWrapperStatus.BoostOfAccount = targetStatus.Account // put the boost in the database - if err := p.db.PutStatus(boostWrapperStatus); err != nil { + if err := p.db.PutStatus(ctx, boostWrapperStatus); err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -55,7 +74,7 @@ func (p *processor) Boost(requestingAccount *gtsmodel.Account, application *gtsm } // return the frontend representation of the new status to the submitter - mastoStatus, err := p.tc.StatusToMasto(boostWrapperStatus, requestingAccount) + mastoStatus, err := p.tc.StatusToMasto(ctx, boostWrapperStatus, requestingAccount) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status %s to frontend representation: %s", targetStatus.ID, err)) } diff --git a/internal/processing/status/boostedby.go b/internal/processing/status/boostedby.go index 1bde6b5ae..46f41039f 100644 --- a/internal/processing/status/boostedby.go +++ b/internal/processing/status/boostedby.go @@ -1,6 +1,25 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package status import ( + "context" "errors" "fmt" @@ -9,8 +28,8 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) BoostedBy(requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(targetStatusID) +func (p *processor) BoostedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) { + targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -18,7 +37,7 @@ func (p *processor) BoostedBy(requestingAccount *gtsmodel.Account, targetStatusI return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID)) } - visible, err := p.filter.StatusVisible(targetStatus, requestingAccount) + visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err)) } @@ -26,7 +45,7 @@ func (p *processor) BoostedBy(requestingAccount *gtsmodel.Account, targetStatusI return nil, gtserror.NewErrorNotFound(errors.New("status is not visible")) } - statusReblogs, err := p.db.GetStatusReblogs(targetStatus) + statusReblogs, err := p.db.GetStatusReblogs(ctx, targetStatus) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("StatusBoostedBy: error seeing who boosted status: %s", err)) } @@ -34,7 +53,7 @@ func (p *processor) BoostedBy(requestingAccount *gtsmodel.Account, targetStatusI // filter the list so the user doesn't see accounts they blocked or which blocked them filteredAccounts := []*gtsmodel.Account{} for _, s := range statusReblogs { - blocked, err := p.db.IsBlocked(requestingAccount.ID, s.AccountID, true) + blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, s.AccountID, true) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("StatusBoostedBy: error checking blocks: %s", err)) } @@ -48,7 +67,7 @@ func (p *processor) BoostedBy(requestingAccount *gtsmodel.Account, targetStatusI // now we can return the masto representation of those accounts mastoAccounts := []*apimodel.Account{} for _, acc := range filteredAccounts { - mastoAccount, err := p.tc.AccountToMastoPublic(acc) + mastoAccount, err := p.tc.AccountToMastoPublic(ctx, acc) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("StatusFavedBy: error converting account to api model: %s", err)) } diff --git a/internal/processing/status/context.go b/internal/processing/status/context.go index 43002545e..3e8e93d09 100644 --- a/internal/processing/status/context.go +++ b/internal/processing/status/context.go @@ -1,6 +1,25 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package status import ( + "context" "errors" "fmt" "sort" @@ -10,8 +29,8 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) Context(requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Context, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(targetStatusID) +func (p *processor) Context(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Context, gtserror.WithCode) { + targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -19,7 +38,7 @@ func (p *processor) Context(requestingAccount *gtsmodel.Account, targetStatusID return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID)) } - visible, err := p.filter.StatusVisible(targetStatus, requestingAccount) + visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err)) } @@ -32,14 +51,14 @@ func (p *processor) Context(requestingAccount *gtsmodel.Account, targetStatusID Descendants: []apimodel.Status{}, } - parents, err := p.db.GetStatusParents(targetStatus, false) + parents, err := p.db.GetStatusParents(ctx, targetStatus, false) if err != nil { return nil, gtserror.NewErrorInternalError(err) } for _, status := range parents { - if v, err := p.filter.StatusVisible(status, requestingAccount); err == nil && v { - mastoStatus, err := p.tc.StatusToMasto(status, requestingAccount) + if v, err := p.filter.StatusVisible(ctx, status, requestingAccount); err == nil && v { + mastoStatus, err := p.tc.StatusToMasto(ctx, status, requestingAccount) if err == nil { context.Ancestors = append(context.Ancestors, *mastoStatus) } @@ -50,14 +69,14 @@ func (p *processor) Context(requestingAccount *gtsmodel.Account, targetStatusID return context.Ancestors[i].ID < context.Ancestors[j].ID }) - children, err := p.db.GetStatusChildren(targetStatus, false, "") + children, err := p.db.GetStatusChildren(ctx, targetStatus, false, "") if err != nil { return nil, gtserror.NewErrorInternalError(err) } for _, status := range children { - if v, err := p.filter.StatusVisible(status, requestingAccount); err == nil && v { - mastoStatus, err := p.tc.StatusToMasto(status, requestingAccount) + if v, err := p.filter.StatusVisible(ctx, status, requestingAccount); err == nil && v { + mastoStatus, err := p.tc.StatusToMasto(ctx, status, requestingAccount) if err == nil { context.Descendants = append(context.Descendants, *mastoStatus) } diff --git a/internal/processing/status/create.go b/internal/processing/status/create.go index fc112ed8b..2e0b30ad8 100644 --- a/internal/processing/status/create.go +++ b/internal/processing/status/create.go @@ -1,6 +1,25 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package status import ( + "context" "fmt" "time" @@ -12,7 +31,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/util" ) -func (p *processor) Create(account *gtsmodel.Account, application *gtsmodel.Application, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, gtserror.WithCode) { +func (p *processor) Create(ctx context.Context, account *gtsmodel.Account, application *gtsmodel.Application, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, gtserror.WithCode) { uris := util.GenerateURIsForAccount(account.Username, p.config.Protocol, p.config.Host) thisStatusID, err := id.NewULID() if err != nil { @@ -38,40 +57,40 @@ func (p *processor) Create(account *gtsmodel.Account, application *gtsmodel.Appl Text: form.Status, } - if err := p.ProcessReplyToID(form, account.ID, newStatus); err != nil { + if err := p.ProcessReplyToID(ctx, form, account.ID, newStatus); err != nil { return nil, gtserror.NewErrorInternalError(err) } - if err := p.ProcessMediaIDs(form, account.ID, newStatus); err != nil { + if err := p.ProcessMediaIDs(ctx, form, account.ID, newStatus); err != nil { return nil, gtserror.NewErrorInternalError(err) } - if err := p.ProcessVisibility(form, account.Privacy, newStatus); err != nil { + if err := p.ProcessVisibility(ctx, form, account.Privacy, newStatus); err != nil { return nil, gtserror.NewErrorInternalError(err) } - if err := p.ProcessLanguage(form, account.Language, newStatus); err != nil { + if err := p.ProcessLanguage(ctx, form, account.Language, newStatus); err != nil { return nil, gtserror.NewErrorInternalError(err) } - if err := p.ProcessMentions(form, account.ID, newStatus); err != nil { + if err := p.ProcessMentions(ctx, form, account.ID, newStatus); err != nil { return nil, gtserror.NewErrorInternalError(err) } - if err := p.ProcessTags(form, account.ID, newStatus); err != nil { + if err := p.ProcessTags(ctx, form, account.ID, newStatus); err != nil { return nil, gtserror.NewErrorInternalError(err) } - if err := p.ProcessEmojis(form, account.ID, newStatus); err != nil { + if err := p.ProcessEmojis(ctx, form, account.ID, newStatus); err != nil { return nil, gtserror.NewErrorInternalError(err) } - if err := p.ProcessContent(form, account.ID, newStatus); err != nil { + if err := p.ProcessContent(ctx, form, account.ID, newStatus); err != nil { return nil, gtserror.NewErrorInternalError(err) } // put the new status in the database - if err := p.db.PutStatus(newStatus); err != nil { + if err := p.db.PutStatus(ctx, newStatus); err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -84,7 +103,7 @@ func (p *processor) Create(account *gtsmodel.Account, application *gtsmodel.Appl } // return the frontend representation of the new status to the submitter - mastoStatus, err := p.tc.StatusToMasto(newStatus, account) + mastoStatus, err := p.tc.StatusToMasto(ctx, newStatus, account) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status %s to frontend representation: %s", newStatus.ID, err)) } diff --git a/internal/processing/status/delete.go b/internal/processing/status/delete.go index 4c5dfd744..daa7a934f 100644 --- a/internal/processing/status/delete.go +++ b/internal/processing/status/delete.go @@ -1,6 +1,25 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package status import ( + "context" "errors" "fmt" @@ -9,8 +28,8 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) Delete(requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(targetStatusID) +func (p *processor) Delete(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { + targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -22,12 +41,12 @@ func (p *processor) Delete(requestingAccount *gtsmodel.Account, targetStatusID s return nil, gtserror.NewErrorForbidden(errors.New("status doesn't belong to requesting account")) } - mastoStatus, err := p.tc.StatusToMasto(targetStatus, requestingAccount) + mastoStatus, err := p.tc.StatusToMasto(ctx, targetStatus, requestingAccount) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status %s to frontend representation: %s", targetStatus.ID, err)) } - if err := p.db.DeleteByID(targetStatus.ID, >smodel.Status{}); err != nil { + if err := p.db.DeleteByID(ctx, targetStatus.ID, >smodel.Status{}); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error deleting status from the database: %s", err)) } diff --git a/internal/processing/status/fave.go b/internal/processing/status/fave.go index 7ba8c8fe8..2badf83b3 100644 --- a/internal/processing/status/fave.go +++ b/internal/processing/status/fave.go @@ -1,6 +1,25 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package status import ( + "context" "errors" "fmt" @@ -12,8 +31,8 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/util" ) -func (p *processor) Fave(requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(targetStatusID) +func (p *processor) Fave(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { + targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -21,7 +40,7 @@ func (p *processor) Fave(requestingAccount *gtsmodel.Account, targetStatusID str return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID)) } - visible, err := p.filter.StatusVisible(targetStatus, requestingAccount) + visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err)) } @@ -37,7 +56,7 @@ func (p *processor) Fave(requestingAccount *gtsmodel.Account, targetStatusID str // first check if the status is already faved, if so we don't need to do anything newFave := true gtsFave := >smodel.StatusFave{} - if err := p.db.GetWhere([]db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err == nil { + if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err == nil { // we already have a fave for this status newFave = false } @@ -60,7 +79,7 @@ func (p *processor) Fave(requestingAccount *gtsmodel.Account, targetStatusID str URI: util.GenerateURIForLike(requestingAccount.Username, p.config.Protocol, p.config.Host, thisFaveID), } - if err := p.db.Put(gtsFave); err != nil { + if err := p.db.Put(ctx, gtsFave); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error putting fave in database: %s", err)) } @@ -75,7 +94,7 @@ func (p *processor) Fave(requestingAccount *gtsmodel.Account, targetStatusID str } // return the mastodon representation of the target status - mastoStatus, err := p.tc.StatusToMasto(targetStatus, requestingAccount) + mastoStatus, err := p.tc.StatusToMasto(ctx, targetStatus, requestingAccount) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status %s to frontend representation: %s", targetStatus.ID, err)) } diff --git a/internal/processing/status/favedby.go b/internal/processing/status/favedby.go index dffe6bba9..227fb669d 100644 --- a/internal/processing/status/favedby.go +++ b/internal/processing/status/favedby.go @@ -1,6 +1,25 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package status import ( + "context" "errors" "fmt" @@ -9,8 +28,8 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) FavedBy(requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(targetStatusID) +func (p *processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) { + targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -18,7 +37,7 @@ func (p *processor) FavedBy(requestingAccount *gtsmodel.Account, targetStatusID return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID)) } - visible, err := p.filter.StatusVisible(targetStatus, requestingAccount) + visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err)) } @@ -26,7 +45,7 @@ func (p *processor) FavedBy(requestingAccount *gtsmodel.Account, targetStatusID return nil, gtserror.NewErrorNotFound(errors.New("status is not visible")) } - statusFaves, err := p.db.GetStatusFaves(targetStatus) + statusFaves, err := p.db.GetStatusFaves(ctx, targetStatus) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing who faved status: %s", err)) } @@ -34,7 +53,7 @@ func (p *processor) FavedBy(requestingAccount *gtsmodel.Account, targetStatusID // filter the list so the user doesn't see accounts they blocked or which blocked them filteredAccounts := []*gtsmodel.Account{} for _, fave := range statusFaves { - blocked, err := p.db.IsBlocked(requestingAccount.ID, fave.AccountID, true) + blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, fave.AccountID, true) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking blocks: %s", err)) } @@ -46,7 +65,7 @@ func (p *processor) FavedBy(requestingAccount *gtsmodel.Account, targetStatusID // now we can return the masto representation of those accounts mastoAccounts := []*apimodel.Account{} for _, acc := range filteredAccounts { - mastoAccount, err := p.tc.AccountToMastoPublic(acc) + mastoAccount, err := p.tc.AccountToMastoPublic(ctx, acc) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status %s to frontend representation: %s", targetStatus.ID, err)) } diff --git a/internal/processing/status/get.go b/internal/processing/status/get.go index 9d403b901..258210faf 100644 --- a/internal/processing/status/get.go +++ b/internal/processing/status/get.go @@ -1,6 +1,25 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package status import ( + "context" "errors" "fmt" @@ -9,8 +28,8 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) Get(requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(targetStatusID) +func (p *processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { + targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -18,7 +37,7 @@ func (p *processor) Get(requestingAccount *gtsmodel.Account, targetStatusID stri return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID)) } - visible, err := p.filter.StatusVisible(targetStatus, requestingAccount) + visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err)) } @@ -26,7 +45,7 @@ func (p *processor) Get(requestingAccount *gtsmodel.Account, targetStatusID stri return nil, gtserror.NewErrorNotFound(errors.New("status is not visible")) } - mastoStatus, err := p.tc.StatusToMasto(targetStatus, requestingAccount) + mastoStatus, err := p.tc.StatusToMasto(ctx, targetStatus, requestingAccount) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status %s to frontend representation: %s", targetStatus.ID, err)) } diff --git a/internal/processing/status/status.go b/internal/processing/status/status.go index 038ca005e..37790d062 100644 --- a/internal/processing/status/status.go +++ b/internal/processing/status/status.go @@ -1,6 +1,26 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package status import ( + "context" + "github.com/sirupsen/logrus" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/config" @@ -15,38 +35,38 @@ import ( // Processor wraps a bunch of functions for processing statuses. type Processor interface { // Create processes the given form to create a new status, returning the api model representation of that status if it's OK. - Create(account *gtsmodel.Account, application *gtsmodel.Application, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, gtserror.WithCode) + Create(ctx context.Context, account *gtsmodel.Account, application *gtsmodel.Application, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, gtserror.WithCode) // Delete processes the delete of a given status, returning the deleted status if the delete goes through. - Delete(account *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) + Delete(ctx context.Context, account *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) // Fave processes the faving of a given status, returning the updated status if the fave goes through. - Fave(account *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) + Fave(ctx context.Context, account *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) // Boost processes the boost/reblog of a given status, returning the newly-created boost if all is well. - Boost(account *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) + Boost(ctx context.Context, account *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) // Unboost processes the unboost/unreblog of a given status, returning the status if all is well. - Unboost(account *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) + Unboost(ctx context.Context, account *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) // BoostedBy returns a slice of accounts that have boosted the given status, filtered according to privacy settings. - BoostedBy(account *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) + BoostedBy(ctx context.Context, account *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) // FavedBy returns a slice of accounts that have liked the given status, filtered according to privacy settings. - FavedBy(account *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) + FavedBy(ctx context.Context, account *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) // Get gets the given status, taking account of privacy settings and blocks etc. - Get(account *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) + Get(ctx context.Context, account *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) // Unfave processes the unfaving of a given status, returning the updated status if the fave goes through. - Unfave(account *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) + Unfave(ctx context.Context, account *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) // Context returns the context (previous and following posts) from the given status ID - Context(account *gtsmodel.Account, targetStatusID string) (*apimodel.Context, gtserror.WithCode) + Context(ctx context.Context, account *gtsmodel.Account, targetStatusID string) (*apimodel.Context, gtserror.WithCode) /* PROCESSING UTILS */ - ProcessVisibility(form *apimodel.AdvancedStatusCreateForm, accountDefaultVis gtsmodel.Visibility, status *gtsmodel.Status) error - ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, thisAccountID string, status *gtsmodel.Status) error - ProcessMediaIDs(form *apimodel.AdvancedStatusCreateForm, thisAccountID string, status *gtsmodel.Status) error - ProcessLanguage(form *apimodel.AdvancedStatusCreateForm, accountDefaultLanguage string, status *gtsmodel.Status) error - ProcessMentions(form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error - ProcessTags(form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error - ProcessEmojis(form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error - ProcessContent(form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error + ProcessVisibility(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountDefaultVis gtsmodel.Visibility, status *gtsmodel.Status) error + ProcessReplyToID(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, thisAccountID string, status *gtsmodel.Status) error + ProcessMediaIDs(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, thisAccountID string, status *gtsmodel.Status) error + ProcessLanguage(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountDefaultLanguage string, status *gtsmodel.Status) error + ProcessMentions(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error + ProcessTags(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error + ProcessEmojis(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error + ProcessContent(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error } type processor struct { diff --git a/internal/processing/status/unboost.go b/internal/processing/status/unboost.go index 254cfe11f..c3c667a71 100644 --- a/internal/processing/status/unboost.go +++ b/internal/processing/status/unboost.go @@ -1,6 +1,25 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package status import ( + "context" "errors" "fmt" @@ -10,8 +29,8 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) Unboost(requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(targetStatusID) +func (p *processor) Unboost(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { + targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -19,7 +38,7 @@ func (p *processor) Unboost(requestingAccount *gtsmodel.Account, application *gt return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID)) } - visible, err := p.filter.StatusVisible(targetStatus, requestingAccount) + visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err)) } @@ -41,7 +60,7 @@ func (p *processor) Unboost(requestingAccount *gtsmodel.Account, application *gt Value: requestingAccount.ID, }, } - err = p.db.GetWhere(where, gtsBoost) + err = p.db.GetWhere(ctx, where, gtsBoost) if err == nil { // we have a boost toUnboost = true @@ -58,7 +77,7 @@ func (p *processor) Unboost(requestingAccount *gtsmodel.Account, application *gt if toUnboost { // we had a boost, so take some action to get rid of it - if err := p.db.DeleteWhere(where, >smodel.Status{}); err != nil { + if err := p.db.DeleteWhere(ctx, where, >smodel.Status{}); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error unboosting status: %s", err)) } @@ -79,7 +98,7 @@ func (p *processor) Unboost(requestingAccount *gtsmodel.Account, application *gt } } - mastoStatus, err := p.tc.StatusToMasto(targetStatus, requestingAccount) + mastoStatus, err := p.tc.StatusToMasto(ctx, targetStatus, requestingAccount) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status %s to frontend representation: %s", targetStatus.ID, err)) } diff --git a/internal/processing/status/unfave.go b/internal/processing/status/unfave.go index d6e5320db..3d079e2ff 100644 --- a/internal/processing/status/unfave.go +++ b/internal/processing/status/unfave.go @@ -1,6 +1,25 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package status import ( + "context" "errors" "fmt" @@ -10,8 +29,8 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) Unfave(requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(targetStatusID) +func (p *processor) Unfave(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { + targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -19,7 +38,7 @@ func (p *processor) Unfave(requestingAccount *gtsmodel.Account, targetStatusID s return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID)) } - visible, err := p.filter.StatusVisible(targetStatus, requestingAccount) + visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err)) } @@ -31,7 +50,7 @@ func (p *processor) Unfave(requestingAccount *gtsmodel.Account, targetStatusID s var toUnfave bool gtsFave := >smodel.StatusFave{} - err = p.db.GetWhere([]db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave) + err = p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave) if err == nil { // we have a fave toUnfave = true @@ -47,7 +66,7 @@ func (p *processor) Unfave(requestingAccount *gtsmodel.Account, targetStatusID s if toUnfave { // we had a fave, so take some action to get rid of it - if err := p.db.DeleteWhere([]db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err != nil { + if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error unfaveing status: %s", err)) } @@ -61,7 +80,7 @@ func (p *processor) Unfave(requestingAccount *gtsmodel.Account, targetStatusID s } } - mastoStatus, err := p.tc.StatusToMasto(targetStatus, requestingAccount) + mastoStatus, err := p.tc.StatusToMasto(ctx, targetStatus, requestingAccount) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status %s to frontend representation: %s", targetStatus.ID, err)) } diff --git a/internal/processing/status/util.go b/internal/processing/status/util.go index 025607f4a..26ee5d4f7 100644 --- a/internal/processing/status/util.go +++ b/internal/processing/status/util.go @@ -1,6 +1,25 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package status import ( + "context" "errors" "fmt" @@ -12,7 +31,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/util" ) -func (p *processor) ProcessVisibility(form *apimodel.AdvancedStatusCreateForm, accountDefaultVis gtsmodel.Visibility, status *gtsmodel.Status) error { +func (p *processor) ProcessVisibility(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountDefaultVis gtsmodel.Visibility, status *gtsmodel.Status) error { // by default all flags are set to true gtsAdvancedVis := >smodel.VisibilityAdvanced{ Federated: true, @@ -83,7 +102,7 @@ func (p *processor) ProcessVisibility(form *apimodel.AdvancedStatusCreateForm, a return nil } -func (p *processor) ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, thisAccountID string, status *gtsmodel.Status) error { +func (p *processor) ProcessReplyToID(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, thisAccountID string, status *gtsmodel.Status) error { if form.InReplyToID == "" { return nil } @@ -98,7 +117,7 @@ func (p *processor) ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, th repliedStatus := >smodel.Status{} repliedAccount := >smodel.Account{} // check replied status exists + is replyable - if err := p.db.GetByID(form.InReplyToID, repliedStatus); err != nil { + if err := p.db.GetByID(ctx, form.InReplyToID, repliedStatus); err != nil { if err == db.ErrNoEntries { return fmt.Errorf("status with id %s not replyable because it doesn't exist", form.InReplyToID) } @@ -112,14 +131,14 @@ func (p *processor) ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, th } // check replied account is known to us - if err := p.db.GetByID(repliedStatus.AccountID, repliedAccount); err != nil { + if err := p.db.GetByID(ctx, repliedStatus.AccountID, repliedAccount); err != nil { if err == db.ErrNoEntries { return fmt.Errorf("status with id %s not replyable because account id %s is not known", form.InReplyToID, repliedStatus.AccountID) } return fmt.Errorf("status with id %s not replyable: %s", form.InReplyToID, err) } // check if a block exists - if blocked, err := p.db.IsBlocked(thisAccountID, repliedAccount.ID, true); err != nil { + if blocked, err := p.db.IsBlocked(ctx, thisAccountID, repliedAccount.ID, true); err != nil { if err != db.ErrNoEntries { return fmt.Errorf("status with id %s not replyable: %s", form.InReplyToID, err) } @@ -132,7 +151,7 @@ func (p *processor) ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, th return nil } -func (p *processor) ProcessMediaIDs(form *apimodel.AdvancedStatusCreateForm, thisAccountID string, status *gtsmodel.Status) error { +func (p *processor) ProcessMediaIDs(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, thisAccountID string, status *gtsmodel.Status) error { if form.MediaIDs == nil { return nil } @@ -142,7 +161,7 @@ func (p *processor) ProcessMediaIDs(form *apimodel.AdvancedStatusCreateForm, thi for _, mediaID := range form.MediaIDs { // check these attachments exist a := >smodel.MediaAttachment{} - if err := p.db.GetByID(mediaID, a); err != nil { + if err := p.db.GetByID(ctx, mediaID, a); err != nil { return fmt.Errorf("invalid media type or media not found for media id %s", mediaID) } // check they belong to the requesting account id @@ -161,7 +180,7 @@ func (p *processor) ProcessMediaIDs(form *apimodel.AdvancedStatusCreateForm, thi return nil } -func (p *processor) ProcessLanguage(form *apimodel.AdvancedStatusCreateForm, accountDefaultLanguage string, status *gtsmodel.Status) error { +func (p *processor) ProcessLanguage(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountDefaultLanguage string, status *gtsmodel.Status) error { if form.Language != "" { status.Language = form.Language } else { @@ -173,9 +192,9 @@ func (p *processor) ProcessLanguage(form *apimodel.AdvancedStatusCreateForm, acc return nil } -func (p *processor) ProcessMentions(form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error { +func (p *processor) ProcessMentions(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error { menchies := []string{} - gtsMenchies, err := p.db.MentionStringsToMentions(util.DeriveMentionsFromStatus(form.Status), accountID, status.ID) + gtsMenchies, err := p.db.MentionStringsToMentions(ctx, util.DeriveMentionsFromStatus(form.Status), accountID, status.ID) if err != nil { return fmt.Errorf("error generating mentions from status: %s", err) } @@ -186,7 +205,7 @@ func (p *processor) ProcessMentions(form *apimodel.AdvancedStatusCreateForm, acc } menchie.ID = menchieID - if err := p.db.Put(menchie); err != nil { + if err := p.db.Put(ctx, menchie); err != nil { return fmt.Errorf("error putting mentions in db: %s", err) } menchies = append(menchies, menchie.ID) @@ -198,14 +217,14 @@ func (p *processor) ProcessMentions(form *apimodel.AdvancedStatusCreateForm, acc return nil } -func (p *processor) ProcessTags(form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error { +func (p *processor) ProcessTags(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error { tags := []string{} - gtsTags, err := p.db.TagStringsToTags(util.DeriveHashtagsFromStatus(form.Status), accountID, status.ID) + gtsTags, err := p.db.TagStringsToTags(ctx, util.DeriveHashtagsFromStatus(form.Status), accountID, status.ID) if err != nil { return fmt.Errorf("error generating hashtags from status: %s", err) } for _, tag := range gtsTags { - if err := p.db.Upsert(tag, "name"); err != nil { + if err := p.db.Put(ctx, tag); err != nil && err != db.ErrAlreadyExists { return fmt.Errorf("error putting tags in db: %s", err) } tags = append(tags, tag.ID) @@ -217,9 +236,9 @@ func (p *processor) ProcessTags(form *apimodel.AdvancedStatusCreateForm, account return nil } -func (p *processor) ProcessEmojis(form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error { +func (p *processor) ProcessEmojis(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error { emojis := []string{} - gtsEmojis, err := p.db.EmojiStringsToEmojis(util.DeriveEmojisFromStatus(form.Status), accountID, status.ID) + gtsEmojis, err := p.db.EmojiStringsToEmojis(ctx, util.DeriveEmojisFromStatus(form.Status), accountID, status.ID) if err != nil { return fmt.Errorf("error generating emojis from status: %s", err) } @@ -233,7 +252,7 @@ func (p *processor) ProcessEmojis(form *apimodel.AdvancedStatusCreateForm, accou return nil } -func (p *processor) ProcessContent(form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error { +func (p *processor) ProcessContent(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error { // if there's nothing in the status at all we can just return early if form.Status == "" { status.Content = "" @@ -252,9 +271,9 @@ func (p *processor) ProcessContent(form *apimodel.AdvancedStatusCreateForm, acco var formatted string switch form.Format { case apimodel.StatusFormatPlain: - formatted = p.formatter.FromPlain(content, status.Mentions, status.Tags) + formatted = p.formatter.FromPlain(ctx, content, status.Mentions, status.Tags) case apimodel.StatusFormatMarkdown: - formatted = p.formatter.FromMarkdown(content, status.Mentions, status.Tags) + formatted = p.formatter.FromMarkdown(ctx, content, status.Mentions, status.Tags) default: return fmt.Errorf("format %s not recognised as a valid status format", form.Format) } diff --git a/internal/processing/status/util_test.go b/internal/processing/status/util_test.go index 9c282eb52..1ec2076b1 100644 --- a/internal/processing/status/util_test.go +++ b/internal/processing/status/util_test.go @@ -1,6 +1,25 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package status_test import ( + "context" "fmt" "testing" @@ -88,7 +107,7 @@ func (suite *UtilTestSuite) TestProcessMentions1() { ID: "01FCTDD78JJMX3K9KPXQ7ZQ8BJ", } - err := suite.status.ProcessMentions(form, creatingAccount.ID, status) + err := suite.status.ProcessMentions(context.Background(), form, creatingAccount.ID, status) assert.NoError(suite.T(), err) assert.Len(suite.T(), status.Mentions, 1) @@ -138,11 +157,11 @@ func (suite *UtilTestSuite) TestProcessContentFull1() { ID: "01FCTDD78JJMX3K9KPXQ7ZQ8BJ", } - err := suite.status.ProcessMentions(form, creatingAccount.ID, status) + err := suite.status.ProcessMentions(context.Background(), form, creatingAccount.ID, status) assert.NoError(suite.T(), err) assert.Empty(suite.T(), status.Content) // shouldn't be set yet - err = suite.status.ProcessTags(form, creatingAccount.ID, status) + err = suite.status.ProcessTags(context.Background(), form, creatingAccount.ID, status) assert.NoError(suite.T(), err) assert.Empty(suite.T(), status.Content) // shouldn't be set yet @@ -150,7 +169,7 @@ func (suite *UtilTestSuite) TestProcessContentFull1() { ACTUAL TEST */ - err = suite.status.ProcessContent(form, creatingAccount.ID, status) + err = suite.status.ProcessContent(context.Background(), form, creatingAccount.ID, status) assert.NoError(suite.T(), err) assert.Equal(suite.T(), statusText1ExpectedFull, status.Content) } @@ -187,7 +206,7 @@ func (suite *UtilTestSuite) TestProcessContentPartial1() { ID: "01FCTDD78JJMX3K9KPXQ7ZQ8BJ", } - err := suite.status.ProcessMentions(form, creatingAccount.ID, status) + err := suite.status.ProcessMentions(context.Background(), form, creatingAccount.ID, status) assert.NoError(suite.T(), err) assert.Empty(suite.T(), status.Content) // shouldn't be set yet @@ -195,7 +214,7 @@ func (suite *UtilTestSuite) TestProcessContentPartial1() { ACTUAL TEST */ - err = suite.status.ProcessContent(form, creatingAccount.ID, status) + err = suite.status.ProcessContent(context.Background(), form, creatingAccount.ID, status) assert.NoError(suite.T(), err) assert.Equal(suite.T(), statusText1ExpectedPartial, status.Content) } @@ -229,7 +248,7 @@ func (suite *UtilTestSuite) TestProcessMentions2() { ID: "01FCTDD78JJMX3K9KPXQ7ZQ8BJ", } - err := suite.status.ProcessMentions(form, creatingAccount.ID, status) + err := suite.status.ProcessMentions(context.Background(), form, creatingAccount.ID, status) assert.NoError(suite.T(), err) assert.Len(suite.T(), status.Mentions, 1) @@ -279,11 +298,11 @@ func (suite *UtilTestSuite) TestProcessContentFull2() { ID: "01FCTDD78JJMX3K9KPXQ7ZQ8BJ", } - err := suite.status.ProcessMentions(form, creatingAccount.ID, status) + err := suite.status.ProcessMentions(context.Background(), form, creatingAccount.ID, status) assert.NoError(suite.T(), err) assert.Empty(suite.T(), status.Content) // shouldn't be set yet - err = suite.status.ProcessTags(form, creatingAccount.ID, status) + err = suite.status.ProcessTags(context.Background(), form, creatingAccount.ID, status) assert.NoError(suite.T(), err) assert.Empty(suite.T(), status.Content) // shouldn't be set yet @@ -291,7 +310,7 @@ func (suite *UtilTestSuite) TestProcessContentFull2() { ACTUAL TEST */ - err = suite.status.ProcessContent(form, creatingAccount.ID, status) + err = suite.status.ProcessContent(context.Background(), form, creatingAccount.ID, status) assert.NoError(suite.T(), err) assert.Equal(suite.T(), status2TextExpectedFull, status.Content) @@ -329,7 +348,7 @@ func (suite *UtilTestSuite) TestProcessContentPartial2() { ID: "01FCTDD78JJMX3K9KPXQ7ZQ8BJ", } - err := suite.status.ProcessMentions(form, creatingAccount.ID, status) + err := suite.status.ProcessMentions(context.Background(), form, creatingAccount.ID, status) assert.NoError(suite.T(), err) assert.Empty(suite.T(), status.Content) // shouldn't be set yet @@ -337,7 +356,7 @@ func (suite *UtilTestSuite) TestProcessContentPartial2() { ACTUAL TEST */ - err = suite.status.ProcessContent(form, creatingAccount.ID, status) + err = suite.status.ProcessContent(context.Background(), form, creatingAccount.ID, status) assert.NoError(suite.T(), err) fmt.Println(status.Content) diff --git a/internal/processing/streaming.go b/internal/processing/streaming.go index 457db0576..e1c134d00 100644 --- a/internal/processing/streaming.go +++ b/internal/processing/streaming.go @@ -19,14 +19,16 @@ package processing import ( + "context" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) { - return p.streamingProcessor.AuthorizeStreamingRequest(accessToken) +func (p *processor) AuthorizeStreamingRequest(ctx context.Context, accessToken string) (*gtsmodel.Account, error) { + return p.streamingProcessor.AuthorizeStreamingRequest(ctx, accessToken) } -func (p *processor) OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) { - return p.streamingProcessor.OpenStreamForAccount(account, streamType) +func (p *processor) OpenStreamForAccount(ctx context.Context, account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) { + return p.streamingProcessor.OpenStreamForAccount(ctx, account, streamType) } diff --git a/internal/processing/streaming/authorize.go b/internal/processing/streaming/authorize.go index 8bbf1856d..f938a0c0c 100644 --- a/internal/processing/streaming/authorize.go +++ b/internal/processing/streaming/authorize.go @@ -7,7 +7,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) { +func (p *processor) AuthorizeStreamingRequest(ctx context.Context, accessToken string) (*gtsmodel.Account, error) { ti, err := p.oauthServer.LoadAccessToken(context.Background(), accessToken) if err != nil { return nil, fmt.Errorf("AuthorizeStreamingRequest: error loading access token: %s", err) @@ -20,12 +20,12 @@ func (p *processor) AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Acc // fetch user's and account for this user id user := >smodel.User{} - if err := p.db.GetByID(uid, user); err != nil || user == nil { + if err := p.db.GetByID(ctx, uid, user); err != nil || user == nil { return nil, fmt.Errorf("AuthorizeStreamingRequest: no user found for validated uid %s", uid) } - acct := >smodel.Account{} - if err := p.db.GetByID(user.AccountID, acct); err != nil || acct == nil { + acct, err := p.db.GetAccountByID(ctx, user.AccountID) + if err != nil || acct == nil { return nil, fmt.Errorf("AuthorizeStreamingRequest: no account retrieved for user with id %s", uid) } diff --git a/internal/processing/streaming/openstream.go b/internal/processing/streaming/openstream.go index 68446bac6..dfad5398e 100644 --- a/internal/processing/streaming/openstream.go +++ b/internal/processing/streaming/openstream.go @@ -1,6 +1,7 @@ package streaming import ( + "context" "errors" "fmt" @@ -10,7 +11,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/id" ) -func (p *processor) OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) { +func (p *processor) OpenStreamForAccount(ctx context.Context, account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) { l := p.log.WithFields(logrus.Fields{ "func": "OpenStreamForAccount", "account": account.ID, diff --git a/internal/processing/streaming/streaming.go b/internal/processing/streaming/streaming.go index de75b8f27..f349a655a 100644 --- a/internal/processing/streaming/streaming.go +++ b/internal/processing/streaming/streaming.go @@ -1,6 +1,7 @@ package streaming import ( + "context" "sync" "github.com/sirupsen/logrus" @@ -17,9 +18,9 @@ import ( // Processor wraps a bunch of functions for processing streaming. type Processor interface { // AuthorizeStreamingRequest returns an oauth2 token info in response to an access token query from the streaming API - AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) + AuthorizeStreamingRequest(ctx context.Context, accessToken string) (*gtsmodel.Account, error) // OpenStreamForAccount returns a new Stream for the given account, which will contain a channel for passing messages back to the caller. - OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) + OpenStreamForAccount(ctx context.Context, account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) // StreamStatusToAccount streams the given status to any open, appropriate streams belonging to the given account. StreamStatusToAccount(s *apimodel.Status, account *gtsmodel.Account) error // StreamNotificationToAccount streams the given notification to any open, appropriate streams belonging to the given account. diff --git a/internal/processing/timeline.go b/internal/processing/timeline.go index afddd3e6c..6a409a6cc 100644 --- a/internal/processing/timeline.go +++ b/internal/processing/timeline.go @@ -19,6 +19,7 @@ package processing import ( + "context" "fmt" "net/url" @@ -58,8 +59,8 @@ func (p *processor) packageStatusResponse(statuses []*apimodel.Status, path stri return resp, nil } -func (p *processor) HomeTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) { - statuses, err := p.timelineManager.HomeTimeline(authed.Account.ID, maxID, sinceID, minID, limit, local) +func (p *processor) HomeTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) { + statuses, err := p.timelineManager.HomeTimeline(ctx, authed.Account.ID, maxID, sinceID, minID, limit, local) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -73,8 +74,8 @@ func (p *processor) HomeTimelineGet(authed *oauth.Auth, maxID string, sinceID st return p.packageStatusResponse(statuses, "api/v1/timelines/home", statuses[len(statuses)-1].ID, statuses[0].ID, limit) } -func (p *processor) PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) { - statuses, err := p.db.GetPublicTimeline(authed.Account.ID, maxID, sinceID, minID, limit, local) +func (p *processor) PublicTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) { + statuses, err := p.db.GetPublicTimeline(ctx, authed.Account.ID, maxID, sinceID, minID, limit, local) if err != nil { if err == db.ErrNoEntries { // there are just no entries left @@ -86,16 +87,22 @@ func (p *processor) PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID return nil, gtserror.NewErrorInternalError(err) } - s, err := p.filterPublicStatuses(authed, statuses) + s, err := p.filterPublicStatuses(ctx, authed, statuses) if err != nil { return nil, gtserror.NewErrorInternalError(err) } + if len(s) == 0 { + return &apimodel.StatusTimelineResponse{ + Statuses: []*apimodel.Status{}, + }, nil + } + return p.packageStatusResponse(s, "api/v1/timelines/public", s[len(s)-1].ID, s[0].ID, limit) } -func (p *processor) FavedTimelineGet(authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.StatusTimelineResponse, gtserror.WithCode) { - statuses, nextMaxID, prevMinID, err := p.db.GetFavedTimeline(authed.Account.ID, maxID, minID, limit) +func (p *processor) FavedTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.StatusTimelineResponse, gtserror.WithCode) { + statuses, nextMaxID, prevMinID, err := p.db.GetFavedTimeline(ctx, authed.Account.ID, maxID, minID, limit) if err != nil { if err == db.ErrNoEntries { // there are just no entries left @@ -107,21 +114,27 @@ func (p *processor) FavedTimelineGet(authed *oauth.Auth, maxID string, minID str return nil, gtserror.NewErrorInternalError(err) } - s, err := p.filterFavedStatuses(authed, statuses) + s, err := p.filterFavedStatuses(ctx, authed, statuses) if err != nil { return nil, gtserror.NewErrorInternalError(err) } + if len(s) == 0 { + return &apimodel.StatusTimelineResponse{ + Statuses: []*apimodel.Status{}, + }, nil + } + return p.packageStatusResponse(s, "api/v1/favourites", nextMaxID, prevMinID, limit) } -func (p *processor) filterPublicStatuses(authed *oauth.Auth, statuses []*gtsmodel.Status) ([]*apimodel.Status, error) { +func (p *processor) filterPublicStatuses(ctx context.Context, authed *oauth.Auth, statuses []*gtsmodel.Status) ([]*apimodel.Status, error) { l := p.log.WithField("func", "filterPublicStatuses") apiStatuses := []*apimodel.Status{} for _, s := range statuses { targetAccount := >smodel.Account{} - if err := p.db.GetByID(s.AccountID, targetAccount); err != nil { + if err := p.db.GetByID(ctx, s.AccountID, targetAccount); err != nil { if err == db.ErrNoEntries { l.Debugf("filterPublicStatuses: skipping status %s because account %s can't be found in the db", s.ID, s.AccountID) continue @@ -129,7 +142,7 @@ func (p *processor) filterPublicStatuses(authed *oauth.Auth, statuses []*gtsmode return nil, gtserror.NewErrorInternalError(fmt.Errorf("filterPublicStatuses: error getting status author: %s", err)) } - timelineable, err := p.filter.StatusPublictimelineable(s, authed.Account) + timelineable, err := p.filter.StatusPublictimelineable(ctx, s, authed.Account) if err != nil { l.Debugf("filterPublicStatuses: skipping status %s because of an error checking status visibility: %s", s.ID, err) continue @@ -138,7 +151,7 @@ func (p *processor) filterPublicStatuses(authed *oauth.Auth, statuses []*gtsmode continue } - apiStatus, err := p.tc.StatusToMasto(s, authed.Account) + apiStatus, err := p.tc.StatusToMasto(ctx, s, authed.Account) if err != nil { l.Debugf("filterPublicStatuses: skipping status %s because it couldn't be converted to its mastodon representation: %s", s.ID, err) continue @@ -150,13 +163,13 @@ func (p *processor) filterPublicStatuses(authed *oauth.Auth, statuses []*gtsmode return apiStatuses, nil } -func (p *processor) filterFavedStatuses(authed *oauth.Auth, statuses []*gtsmodel.Status) ([]*apimodel.Status, error) { +func (p *processor) filterFavedStatuses(ctx context.Context, authed *oauth.Auth, statuses []*gtsmodel.Status) ([]*apimodel.Status, error) { l := p.log.WithField("func", "filterFavedStatuses") apiStatuses := []*apimodel.Status{} for _, s := range statuses { targetAccount := >smodel.Account{} - if err := p.db.GetByID(s.AccountID, targetAccount); err != nil { + if err := p.db.GetByID(ctx, s.AccountID, targetAccount); err != nil { if err == db.ErrNoEntries { l.Debugf("filterFavedStatuses: skipping status %s because account %s can't be found in the db", s.ID, s.AccountID) continue @@ -164,7 +177,7 @@ func (p *processor) filterFavedStatuses(authed *oauth.Auth, statuses []*gtsmodel return nil, gtserror.NewErrorInternalError(fmt.Errorf("filterPublicStatuses: error getting status author: %s", err)) } - timelineable, err := p.filter.StatusVisible(s, authed.Account) + timelineable, err := p.filter.StatusVisible(ctx, s, authed.Account) if err != nil { l.Debugf("filterFavedStatuses: skipping status %s because of an error checking status visibility: %s", s.ID, err) continue @@ -173,7 +186,7 @@ func (p *processor) filterFavedStatuses(authed *oauth.Auth, statuses []*gtsmodel continue } - apiStatus, err := p.tc.StatusToMasto(s, authed.Account) + apiStatus, err := p.tc.StatusToMasto(ctx, s, authed.Account) if err != nil { l.Debugf("filterFavedStatuses: skipping status %s because it couldn't be converted to its mastodon representation: %s", s.ID, err) continue diff --git a/internal/router/router.go b/internal/router/router.go index c5f105448..621d93ff5 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -103,7 +103,7 @@ func (r *router) Stop(ctx context.Context) error { // // The given DB is only used in the New function for parsing config values, and is not otherwise // pinned to the router. -func New(cfg *config.Config, db db.DB, logger *logrus.Logger) (Router, error) { +func New(ctx context.Context, cfg *config.Config, db db.DB, logger *logrus.Logger) (Router, error) { // gin has different log modes; for convenience, we match the gin log mode to // whatever log mode has been set for logrus @@ -141,7 +141,7 @@ func New(cfg *config.Config, db db.DB, logger *logrus.Logger) (Router, error) { } // enable session store middleware on the engine - if err := useSession(cfg, db, engine); err != nil { + if err := useSession(ctx, cfg, db, engine); err != nil { return nil, err } diff --git a/internal/router/session.go b/internal/router/session.go index 38810572f..4359a8a60 100644 --- a/internal/router/session.go +++ b/internal/router/session.go @@ -19,7 +19,7 @@ package router import ( - "crypto/rand" + "context" "errors" "fmt" "net/http" @@ -29,8 +29,6 @@ import ( "github.com/gin-gonic/gin" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/id" ) // SessionOptions returns the standard set of options to use for each session. @@ -45,34 +43,23 @@ func SessionOptions(cfg *config.Config) sessions.Options { } } -func useSession(cfg *config.Config, dbService db.DB, engine *gin.Engine) error { +func useSession(ctx context.Context, cfg *config.Config, sessionDB db.Session, engine *gin.Engine) error { // check if we have a saved router session already - routerSessions := []*gtsmodel.RouterSession{} - if err := dbService.GetAll(&routerSessions); err != nil { + rs, err := sessionDB.GetSession(ctx) + if err != nil { if err != db.ErrNoEntries { // proper error occurred return err } - } - - var rs *gtsmodel.RouterSession - if len(routerSessions) == 1 { - // we have a router session stored - rs = routerSessions[0] - } else if len(routerSessions) == 0 { - // we have no router sessions so we need to create a new one - var err error - rs, err = routerSession(dbService) + // no session saved so create a new one + rs, err = sessionDB.CreateSession(ctx) if err != nil { - return fmt.Errorf("error creating new router session: %s", err) + return err } - } else { - // we should only have one router session stored ever - return errors.New("we had more than one router session in the db") } if rs == nil { - return errors.New("error getting or creating router session: router session was nil") + return errors.New("router session was nil") } store := memstore.NewStore(rs.Auth, rs.Crypt) @@ -81,34 +68,3 @@ func useSession(cfg *config.Config, dbService db.DB, engine *gin.Engine) error { engine.Use(sessions.Sessions(sessionName, store)) return nil } - -// routerSession generates a new router session with random auth and crypt bytes, -// puts it in the database for persistence, and returns it for use. -func routerSession(dbService db.DB) (*gtsmodel.RouterSession, error) { - auth := make([]byte, 32) - crypt := make([]byte, 32) - - if _, err := rand.Read(auth); err != nil { - return nil, err - } - if _, err := rand.Read(crypt); err != nil { - return nil, err - } - - rid, err := id.NewULID() - if err != nil { - return nil, err - } - - rs := >smodel.RouterSession{ - ID: rid, - Auth: auth, - Crypt: crypt, - } - - if err := dbService.Put(rs); err != nil { - return nil, err - } - - return rs, nil -} diff --git a/internal/text/common.go b/internal/text/common.go index af77521dd..a8d585a09 100644 --- a/internal/text/common.go +++ b/internal/text/common.go @@ -19,6 +19,7 @@ package text import ( + "context" "fmt" "html" "strings" @@ -59,7 +60,7 @@ func postformat(in string) string { return mini } -func (f *formatter) ReplaceTags(in string, tags []*gtsmodel.Tag) string { +func (f *formatter) ReplaceTags(ctx context.Context, in string, tags []*gtsmodel.Tag) string { return util.HashtagFinderRegex.ReplaceAllStringFunc(in, func(match string) string { // we have a match matchTrimmed := strings.TrimSpace(match) @@ -88,7 +89,7 @@ func (f *formatter) ReplaceTags(in string, tags []*gtsmodel.Tag) string { }) } -func (f *formatter) ReplaceMentions(in string, mentions []*gtsmodel.Mention) string { +func (f *formatter) ReplaceMentions(ctx context.Context, in string, mentions []*gtsmodel.Mention) string { for _, menchie := range mentions { // make sure we have a target account, either by getting one pinned on the mention, // or by pulling it from the database @@ -97,8 +98,8 @@ func (f *formatter) ReplaceMentions(in string, mentions []*gtsmodel.Mention) str // got it from the mention targetAccount = menchie.OriginAccount } else { - a := >smodel.Account{} - if err := f.db.GetByID(menchie.TargetAccountID, a); err == nil { + a, err := f.db.GetAccountByID(ctx, menchie.TargetAccountID) + if err == nil { // got it from the db targetAccount = a } else { diff --git a/internal/text/common_test.go b/internal/text/common_test.go index 69fe7d446..174b79177 100644 --- a/internal/text/common_test.go +++ b/internal/text/common_test.go @@ -19,6 +19,7 @@ package text_test import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -87,7 +88,7 @@ func (suite *CommonTestSuite) TestReplaceMentions() { suite.testMentions["zork_mention_foss_satan"], } - f := suite.formatter.ReplaceMentions(replaceMentionsString, foundMentions) + f := suite.formatter.ReplaceMentions(context.Background(), replaceMentionsString, foundMentions) assert.Equal(suite.T(), replaceMentionsExpected, f) } @@ -96,7 +97,7 @@ func (suite *CommonTestSuite) TestReplaceHashtags() { suite.testTags["Hashtag"], } - f := suite.formatter.ReplaceTags(replaceMentionsString, foundTags) + f := suite.formatter.ReplaceTags(context.Background(), replaceMentionsString, foundTags) assert.Equal(suite.T(), replaceHashtagsExpected, f) } @@ -106,7 +107,7 @@ func (suite *CommonTestSuite) TestReplaceHashtagsAfterReplaceMentions() { suite.testTags["Hashtag"], } - f := suite.formatter.ReplaceTags(replaceMentionsExpected, foundTags) + f := suite.formatter.ReplaceTags(context.Background(), replaceMentionsExpected, foundTags) assert.Equal(suite.T(), replaceHashtagsAfterMentionsExpected, f) } diff --git a/internal/text/formatter.go b/internal/text/formatter.go index 39aaae559..769ecafbb 100644 --- a/internal/text/formatter.go +++ b/internal/text/formatter.go @@ -19,6 +19,8 @@ package text import ( + "context" + "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -28,16 +30,16 @@ import ( // Formatter wraps some logic and functions for parsing statuses and other text input into nice html. type Formatter interface { // FromMarkdown parses an HTML text from a markdown-formatted text. - FromMarkdown(md string, mentions []*gtsmodel.Mention, tags []*gtsmodel.Tag) string + FromMarkdown(ctx context.Context, md string, mentions []*gtsmodel.Mention, tags []*gtsmodel.Tag) string // FromPlain parses an HTML text from a plaintext. - FromPlain(plain string, mentions []*gtsmodel.Mention, tags []*gtsmodel.Tag) string + FromPlain(ctx context.Context, plain string, mentions []*gtsmodel.Mention, tags []*gtsmodel.Tag) string // ReplaceTags takes a piece of text and a slice of tags, and returns the same text with the tags nicely formatted as hrefs. - ReplaceTags(in string, tags []*gtsmodel.Tag) string + ReplaceTags(ctx context.Context, in string, tags []*gtsmodel.Tag) string // ReplaceMentions takes a piece of text and a slice of mentions, and returns the same text with the mentions nicely formatted as hrefs. - ReplaceMentions(in string, mentions []*gtsmodel.Mention) string + ReplaceMentions(ctx context.Context, in string, mentions []*gtsmodel.Mention) string // ReplaceLinks takes a piece of text, finds all recognizable links in that text, and replaces them with hrefs. - ReplaceLinks(in string) string + ReplaceLinks(ctx context.Context, in string) string } type formatter struct { diff --git a/internal/text/link.go b/internal/text/link.go index d42cc3b68..0a0f0c60d 100644 --- a/internal/text/link.go +++ b/internal/text/link.go @@ -19,6 +19,7 @@ package text import ( + "context" "fmt" "net/url" @@ -82,7 +83,7 @@ func contains(urls []*url.URL, url *url.URL) bool { // Note: because Go doesn't allow negative lookbehinds in regex, it's possible that an already-formatted // href will end up double-formatted, if the text you pass here contains one or more hrefs already. // To avoid this, you should sanitize any HTML out of text before you pass it into this function. -func (f *formatter) ReplaceLinks(in string) string { +func (f *formatter) ReplaceLinks(ctx context.Context, in string) string { rxStrict, err := xurls.StrictMatchingScheme(schemes) if err != nil { panic(err) diff --git a/internal/text/link_test.go b/internal/text/link_test.go index 83c42f045..f8d6a1adc 100644 --- a/internal/text/link_test.go +++ b/internal/text/link_test.go @@ -19,6 +19,7 @@ package text_test import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -94,7 +95,7 @@ func (suite *LinkTestSuite) TearDownTest() { } func (suite *LinkTestSuite) TestParseSimple() { - f := suite.formatter.FromPlain(simple, nil, nil) + f := suite.formatter.FromPlain(context.Background(), simple, nil, nil) assert.Equal(suite.T(), simpleExpected, f) } @@ -126,7 +127,7 @@ func (suite *LinkTestSuite) TestParseURLsFromText3() { } func (suite *LinkTestSuite) TestReplaceLinksFromText1() { - replaced := suite.formatter.ReplaceLinks(text1) + replaced := suite.formatter.ReplaceLinks(context.Background(), text1) assert.Equal(suite.T(), ` This is a text with some links in it. Here's link number one: <a href="https://example.org/link/to/something#fragment" rel="noopener">example.org/link/to/something#fragment</a> @@ -141,7 +142,7 @@ really.cool.website <-- this one shouldn't be parsed as a link because it doesn' } func (suite *LinkTestSuite) TestReplaceLinksFromText2() { - replaced := suite.formatter.ReplaceLinks(text2) + replaced := suite.formatter.ReplaceLinks(context.Background(), text2) assert.Equal(suite.T(), ` this is one link: <a href="https://example.org" rel="noopener">example.org</a> @@ -153,14 +154,14 @@ these should be deduplicated func (suite *LinkTestSuite) TestReplaceLinksFromText3() { // we know mailto links won't be replaced with hrefs -- we only accept https and http - replaced := suite.formatter.ReplaceLinks(text3) + replaced := suite.formatter.ReplaceLinks(context.Background(), text3) assert.Equal(suite.T(), ` here's a mailto link: mailto:whatever@test.org `, replaced) } func (suite *LinkTestSuite) TestReplaceLinksFromText4() { - replaced := suite.formatter.ReplaceLinks(text4) + replaced := suite.formatter.ReplaceLinks(context.Background(), text4) assert.Equal(suite.T(), ` two similar links: @@ -172,7 +173,7 @@ two similar links: func (suite *LinkTestSuite) TestReplaceLinksFromText5() { // we know this one doesn't work properly, which is why html should always be sanitized before being passed into the ReplaceLinks function - replaced := suite.formatter.ReplaceLinks(text5) + replaced := suite.formatter.ReplaceLinks(context.Background(), text5) assert.Equal(suite.T(), ` what happens when we already have a link within an href? diff --git a/internal/text/markdown.go b/internal/text/markdown.go index 5a7603615..eeeae0edf 100644 --- a/internal/text/markdown.go +++ b/internal/text/markdown.go @@ -19,21 +19,23 @@ package text import ( + "context" + "github.com/russross/blackfriday/v2" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (f *formatter) FromMarkdown(md string, mentions []*gtsmodel.Mention, tags []*gtsmodel.Tag) string { +func (f *formatter) FromMarkdown(ctx context.Context, md string, mentions []*gtsmodel.Mention, tags []*gtsmodel.Tag) string { content := preformat(md) // do the markdown parsing *first* contentBytes := blackfriday.Run([]byte(content)) // format tags nicely - content = f.ReplaceTags(string(contentBytes), tags) + content = f.ReplaceTags(ctx, string(contentBytes), tags) // format mentions nicely - content = f.ReplaceMentions(content, mentions) + content = f.ReplaceMentions(ctx, content, mentions) return postformat(content) } diff --git a/internal/text/markdown_test.go b/internal/text/markdown_test.go index 432e9a4ec..085f211d2 100644 --- a/internal/text/markdown_test.go +++ b/internal/text/markdown_test.go @@ -19,6 +19,7 @@ package text_test import ( + "context" "fmt" "testing" @@ -92,13 +93,13 @@ func (suite *MarkdownTestSuite) TearDownTest() { } func (suite *MarkdownTestSuite) TestParseSimple() { - s := suite.formatter.FromMarkdown(simpleMarkdown, nil, nil) + s := suite.formatter.FromMarkdown(context.Background(), simpleMarkdown, nil, nil) suite.Equal(simpleMarkdownExpected, s) } func (suite *MarkdownTestSuite) TestParseWithCodeBlock() { fmt.Println(withCodeBlock) - s := suite.formatter.FromMarkdown(withCodeBlock, nil, nil) + s := suite.formatter.FromMarkdown(context.Background(), withCodeBlock, nil, nil) suite.Equal(withCodeBlockExpected, s) } @@ -107,7 +108,7 @@ func (suite *MarkdownTestSuite) TestParseWithHashtag() { suite.testTags["Hashtag"], } - s := suite.formatter.FromMarkdown(withHashtag, nil, foundTags) + s := suite.formatter.FromMarkdown(context.Background(), withHashtag, nil, foundTags) suite.Equal(withHashtagExpected, s) } diff --git a/internal/text/plain.go b/internal/text/plain.go index a44e02c80..34cc3fa06 100644 --- a/internal/text/plain.go +++ b/internal/text/plain.go @@ -19,26 +19,27 @@ package text import ( + "context" "fmt" "strings" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (f *formatter) FromPlain(plain string, mentions []*gtsmodel.Mention, tags []*gtsmodel.Tag) string { +func (f *formatter) FromPlain(ctx context.Context, plain string, mentions []*gtsmodel.Mention, tags []*gtsmodel.Tag) string { content := preformat(plain) // sanitize any html elements content = RemoveHTML(content) // format links nicely - content = f.ReplaceLinks(content) + content = f.ReplaceLinks(ctx, content) // format tags nicely - content = f.ReplaceTags(content, tags) + content = f.ReplaceTags(ctx, content, tags) // format mentions nicely - content = f.ReplaceMentions(content, mentions) + content = f.ReplaceMentions(ctx, content, mentions) // replace newlines with breaks content = strings.ReplaceAll(content, "\n", "<br />") diff --git a/internal/text/plain_test.go b/internal/text/plain_test.go index 33c95234c..62c43406d 100644 --- a/internal/text/plain_test.go +++ b/internal/text/plain_test.go @@ -19,6 +19,7 @@ package text_test import ( + "context" "fmt" "testing" @@ -74,7 +75,7 @@ func (suite *PlainTestSuite) TearDownTest() { } func (suite *PlainTestSuite) TestParseSimple() { - f := suite.formatter.FromPlain(simple, nil, nil) + f := suite.formatter.FromPlain(context.Background(), simple, nil, nil) assert.Equal(suite.T(), simpleExpected, f) } @@ -84,7 +85,7 @@ func (suite *PlainTestSuite) TestParseWithTag() { suite.testTags["welcome"], } - f := suite.formatter.FromPlain(withTag, nil, foundTags) + f := suite.formatter.FromPlain(context.Background(), withTag, nil, foundTags) assert.Equal(suite.T(), withTagExpected, f) } @@ -98,7 +99,7 @@ func (suite *PlainTestSuite) TestParseMoreComplex() { suite.testMentions["zork_mention_foss_satan"], } - f := suite.formatter.FromPlain(moreComplex, foundMentions, foundTags) + f := suite.formatter.FromPlain(context.Background(), moreComplex, foundMentions, foundTags) fmt.Println(f) diff --git a/internal/timeline/get.go b/internal/timeline/get.go index d800da4e3..a00613dc0 100644 --- a/internal/timeline/get.go +++ b/internal/timeline/get.go @@ -20,6 +20,7 @@ package timeline import ( "container/list" + "context" "errors" "fmt" @@ -29,7 +30,7 @@ import ( const retries = 5 -func (t *timeline) Get(amount int, maxID string, sinceID string, minID string, prepareNext bool) ([]*apimodel.Status, error) { +func (t *timeline) Get(ctx context.Context, amount int, maxID string, sinceID string, minID string, prepareNext bool) ([]*apimodel.Status, error) { l := t.log.WithFields(logrus.Fields{ "func": "Get", "accountID": t.accountID, @@ -46,14 +47,15 @@ func (t *timeline) Get(amount int, maxID string, sinceID string, minID string, p // no params are defined to just fetch from the top // this is equivalent to a user asking for the top x posts from their timeline if maxID == "" && sinceID == "" && minID == "" { - statuses, err = t.GetXFromTop(amount) + statuses, err = t.GetXFromTop(ctx, amount) // aysnchronously prepare the next predicted query so it's ready when the user asks for it if len(statuses) != 0 { nextMaxID := statuses[len(statuses)-1].ID if prepareNext { // already cache the next query to speed up scrolling go func() { - if err := t.prepareNextQuery(amount, nextMaxID, "", ""); err != nil { + // use context.Background() because we don't want the query to abort when the request finishes + if err := t.prepareNextQuery(context.Background(), amount, nextMaxID, "", ""); err != nil { l.Errorf("error preparing next query: %s", err) } }() @@ -65,14 +67,15 @@ func (t *timeline) Get(amount int, maxID string, sinceID string, minID string, p // this is equivalent to a user asking for the next x posts from their timeline, starting from maxID if maxID != "" && sinceID == "" { attempts := 0 - statuses, err = t.GetXBehindID(amount, maxID, &attempts) + statuses, err = t.GetXBehindID(ctx, amount, maxID, &attempts) // aysnchronously prepare the next predicted query so it's ready when the user asks for it if len(statuses) != 0 { nextMaxID := statuses[len(statuses)-1].ID if prepareNext { // already cache the next query to speed up scrolling go func() { - if err := t.prepareNextQuery(amount, nextMaxID, "", ""); err != nil { + // use context.Background() because we don't want the query to abort when the request finishes + if err := t.prepareNextQuery(context.Background(), amount, nextMaxID, "", ""); err != nil { l.Errorf("error preparing next query: %s", err) } }() @@ -83,25 +86,25 @@ func (t *timeline) Get(amount int, maxID string, sinceID string, minID string, p // maxID is defined and sinceID || minID are as well, so take a slice between them // this is equivalent to a user asking for posts older than x but newer than y if maxID != "" && sinceID != "" { - statuses, err = t.GetXBetweenID(amount, maxID, minID) + statuses, err = t.GetXBetweenID(ctx, amount, maxID, minID) } if maxID != "" && minID != "" { - statuses, err = t.GetXBetweenID(amount, maxID, minID) + statuses, err = t.GetXBetweenID(ctx, amount, maxID, minID) } // maxID isn't defined, but sinceID || minID are, so take x before // this is equivalent to a user asking for posts newer than x (eg., refreshing the top of their timeline) if maxID == "" && sinceID != "" { - statuses, err = t.GetXBeforeID(amount, sinceID, true) + statuses, err = t.GetXBeforeID(ctx, amount, sinceID, true) } if maxID == "" && minID != "" { - statuses, err = t.GetXBeforeID(amount, minID, true) + statuses, err = t.GetXBeforeID(ctx, amount, minID, true) } return statuses, err } -func (t *timeline) GetXFromTop(amount int) ([]*apimodel.Status, error) { +func (t *timeline) GetXFromTop(ctx context.Context, amount int) ([]*apimodel.Status, error) { // make a slice of statuses with the length we need to return statuses := make([]*apimodel.Status, 0, amount) @@ -111,7 +114,7 @@ func (t *timeline) GetXFromTop(amount int) ([]*apimodel.Status, error) { // make sure we have enough posts prepared to return if t.preparedPosts.data.Len() < amount { - if err := t.PrepareFromTop(amount); err != nil { + if err := t.PrepareFromTop(ctx, amount); err != nil { return nil, err } } @@ -133,7 +136,7 @@ func (t *timeline) GetXFromTop(amount int) ([]*apimodel.Status, error) { return statuses, nil } -func (t *timeline) GetXBehindID(amount int, behindID string, attempts *int) ([]*apimodel.Status, error) { +func (t *timeline) GetXBehindID(ctx context.Context, amount int, behindID string, attempts *int) ([]*apimodel.Status, error) { l := t.log.WithFields(logrus.Fields{ "func": "GetXBehindID", "amount": amount, @@ -174,10 +177,10 @@ findMarkLoop: // we didn't find it, so we need to make sure it's indexed and prepared and then try again // this can happen when a user asks for really old posts if behindIDMark == nil { - if err := t.PrepareBehind(behindID, amount); err != nil { + if err := t.PrepareBehind(ctx, behindID, amount); err != nil { return nil, fmt.Errorf("GetXBehindID: error preparing behind and including ID %s", behindID) } - oldestID, err := t.OldestPreparedPostID() + oldestID, err := t.OldestPreparedPostID(ctx) if err != nil { return nil, err } @@ -194,12 +197,12 @@ findMarkLoop: return statuses, nil } l.Trace("trying GetXBehindID again") - return t.GetXBehindID(amount, behindID, attempts) + return t.GetXBehindID(ctx, amount, behindID, attempts) } // make sure we have enough posts prepared behind it to return what we're being asked for if t.preparedPosts.data.Len() < amount+position { - if err := t.PrepareBehind(behindID, amount); err != nil { + if err := t.PrepareBehind(ctx, behindID, amount); err != nil { return nil, err } } @@ -224,7 +227,7 @@ serveloop: return statuses, nil } -func (t *timeline) GetXBeforeID(amount int, beforeID string, startFromTop bool) ([]*apimodel.Status, error) { +func (t *timeline) GetXBeforeID(ctx context.Context, amount int, beforeID string, startFromTop bool) ([]*apimodel.Status, error) { // make a slice of statuses with the length we need to return statuses := make([]*apimodel.Status, 0, amount) @@ -295,7 +298,7 @@ findMarkLoop: return statuses, nil } -func (t *timeline) GetXBetweenID(amount int, behindID string, beforeID string) ([]*apimodel.Status, error) { +func (t *timeline) GetXBetweenID(ctx context.Context, amount int, behindID string, beforeID string) ([]*apimodel.Status, error) { // make a slice of statuses with the length we need to return statuses := make([]*apimodel.Status, 0, amount) @@ -327,7 +330,7 @@ findMarkLoop: // make sure we have enough posts prepared behind it to return what we're being asked for if t.preparedPosts.data.Len() < amount+position { - if err := t.PrepareBehind(behindID, amount); err != nil { + if err := t.PrepareBehind(ctx, behindID, amount); err != nil { return nil, err } } diff --git a/internal/timeline/get_test.go b/internal/timeline/get_test.go index 0866f3bdd..96c333c5f 100644 --- a/internal/timeline/get_test.go +++ b/internal/timeline/get_test.go @@ -19,6 +19,7 @@ package timeline_test import ( + "context" "testing" "time" @@ -45,14 +46,14 @@ func (suite *GetTestSuite) SetupTest() { testrig.StandardDBSetup(suite.db, nil) // let's take local_account_1 as the timeline owner - tl, err := timeline.NewTimeline(suite.testAccounts["local_account_1"].ID, suite.db, suite.tc, suite.log) + tl, err := timeline.NewTimeline(context.Background(), suite.testAccounts["local_account_1"].ID, suite.db, suite.tc, suite.log) if err != nil { suite.FailNow(err.Error()) } // prepare the timeline by just shoving all test statuses in it -- let's not be fussy about who sees what for _, s := range suite.testStatuses { - _, err := tl.IndexAndPrepareOne(s.CreatedAt, s.ID, s.BoostOfID, s.AccountID, s.BoostOfAccountID) + _, err := tl.IndexAndPrepareOne(context.Background(), s.CreatedAt, s.ID, s.BoostOfID, s.AccountID, s.BoostOfAccountID) if err != nil { suite.FailNow(err.Error()) } @@ -67,7 +68,7 @@ func (suite *GetTestSuite) TearDownTest() { func (suite *GetTestSuite) TestGetDefault() { // get 10 20 the top and don't prepare the next query - statuses, err := suite.timeline.Get(20, "", "", "", false) + statuses, err := suite.timeline.Get(context.Background(), 20, "", "", "", false) if err != nil { suite.FailNow(err.Error()) } @@ -89,7 +90,7 @@ func (suite *GetTestSuite) TestGetDefault() { func (suite *GetTestSuite) TestGetDefaultPrepareNext() { // get 10 from the top and prepare the next query - statuses, err := suite.timeline.Get(10, "", "", "", true) + statuses, err := suite.timeline.Get(context.Background(), 10, "", "", "", true) if err != nil { suite.FailNow(err.Error()) } @@ -113,7 +114,7 @@ func (suite *GetTestSuite) TestGetDefaultPrepareNext() { func (suite *GetTestSuite) TestGetMaxID() { // ask for 10 with a max ID somewhere in the middle of the stack - statuses, err := suite.timeline.Get(10, "01F8MHBQCBTDKN6X5VHGMMN4MA", "", "", false) + statuses, err := suite.timeline.Get(context.Background(), 10, "01F8MHBQCBTDKN6X5VHGMMN4MA", "", "", false) if err != nil { suite.FailNow(err.Error()) } @@ -135,7 +136,7 @@ func (suite *GetTestSuite) TestGetMaxID() { func (suite *GetTestSuite) TestGetMaxIDPrepareNext() { // ask for 10 with a max ID somewhere in the middle of the stack - statuses, err := suite.timeline.Get(10, "01F8MHBQCBTDKN6X5VHGMMN4MA", "", "", true) + statuses, err := suite.timeline.Get(context.Background(), 10, "01F8MHBQCBTDKN6X5VHGMMN4MA", "", "", true) if err != nil { suite.FailNow(err.Error()) } @@ -160,7 +161,7 @@ func (suite *GetTestSuite) TestGetMaxIDPrepareNext() { func (suite *GetTestSuite) TestGetMinID() { // ask for 10 with a min ID somewhere in the middle of the stack - statuses, err := suite.timeline.Get(10, "", "01F8MHBQCBTDKN6X5VHGMMN4MA", "", false) + statuses, err := suite.timeline.Get(context.Background(), 10, "", "01F8MHBQCBTDKN6X5VHGMMN4MA", "", false) if err != nil { suite.FailNow(err.Error()) } @@ -182,7 +183,7 @@ func (suite *GetTestSuite) TestGetMinID() { func (suite *GetTestSuite) TestGetSinceID() { // ask for 10 with a since ID somewhere in the middle of the stack - statuses, err := suite.timeline.Get(10, "", "", "01F8MHBQCBTDKN6X5VHGMMN4MA", false) + statuses, err := suite.timeline.Get(context.Background(), 10, "", "", "01F8MHBQCBTDKN6X5VHGMMN4MA", false) if err != nil { suite.FailNow(err.Error()) } @@ -204,7 +205,7 @@ func (suite *GetTestSuite) TestGetSinceID() { func (suite *GetTestSuite) TestGetSinceIDPrepareNext() { // ask for 10 with a since ID somewhere in the middle of the stack - statuses, err := suite.timeline.Get(10, "", "", "01F8MHBQCBTDKN6X5VHGMMN4MA", true) + statuses, err := suite.timeline.Get(context.Background(), 10, "", "", "01F8MHBQCBTDKN6X5VHGMMN4MA", true) if err != nil { suite.FailNow(err.Error()) } @@ -229,7 +230,7 @@ func (suite *GetTestSuite) TestGetSinceIDPrepareNext() { func (suite *GetTestSuite) TestGetBetweenID() { // ask for 10 between these two IDs - statuses, err := suite.timeline.Get(10, "01F8MHCP5P2NWYQ416SBA0XSEV", "", "01F8MHBQCBTDKN6X5VHGMMN4MA", false) + statuses, err := suite.timeline.Get(context.Background(), 10, "01F8MHCP5P2NWYQ416SBA0XSEV", "", "01F8MHBQCBTDKN6X5VHGMMN4MA", false) if err != nil { suite.FailNow(err.Error()) } @@ -251,7 +252,7 @@ func (suite *GetTestSuite) TestGetBetweenID() { func (suite *GetTestSuite) TestGetBetweenIDPrepareNext() { // ask for 10 between these two IDs - statuses, err := suite.timeline.Get(10, "01F8MHCP5P2NWYQ416SBA0XSEV", "", "01F8MHBQCBTDKN6X5VHGMMN4MA", true) + statuses, err := suite.timeline.Get(context.Background(), 10, "01F8MHCP5P2NWYQ416SBA0XSEV", "", "01F8MHBQCBTDKN6X5VHGMMN4MA", true) if err != nil { suite.FailNow(err.Error()) } @@ -276,7 +277,7 @@ func (suite *GetTestSuite) TestGetBetweenIDPrepareNext() { func (suite *GetTestSuite) TestGetXFromTop() { // get 5 from the top - statuses, err := suite.timeline.GetXFromTop(5) + statuses, err := suite.timeline.GetXFromTop(context.Background(), 5) if err != nil { suite.FailNow(err.Error()) } @@ -300,7 +301,7 @@ func (suite *GetTestSuite) TestGetXBehindID() { var attempts *int a := 0 attempts = &a - statuses, err := suite.timeline.GetXBehindID(3, "01F8MHBQCBTDKN6X5VHGMMN4MA", attempts) + statuses, err := suite.timeline.GetXBehindID(context.Background(), 3, "01F8MHBQCBTDKN6X5VHGMMN4MA", attempts) if err != nil { suite.FailNow(err.Error()) } @@ -326,7 +327,7 @@ func (suite *GetTestSuite) TestGetXBehindID0() { var attempts *int a := 0 attempts = &a - statuses, err := suite.timeline.GetXBehindID(3, "0", attempts) + statuses, err := suite.timeline.GetXBehindID(context.Background(), 3, "0", attempts) if err != nil { suite.FailNow(err.Error()) } @@ -340,7 +341,7 @@ func (suite *GetTestSuite) TestGetXBehindNonexistentReasonableID() { var attempts *int a := 0 attempts = &a - statuses, err := suite.timeline.GetXBehindID(3, "01F8MHBQCBTDKN6X5VHGMMN4MB", attempts) // change the last A to a B + statuses, err := suite.timeline.GetXBehindID(context.Background(), 3, "01F8MHBQCBTDKN6X5VHGMMN4MB", attempts) // change the last A to a B if err != nil { suite.FailNow(err.Error()) } @@ -365,7 +366,7 @@ func (suite *GetTestSuite) TestGetXBehindVeryHighID() { var attempts *int a := 0 attempts = &a - statuses, err := suite.timeline.GetXBehindID(7, "9998MHBQCBTDKN6X5VHGMMN4MA", attempts) + statuses, err := suite.timeline.GetXBehindID(context.Background(), 7, "9998MHBQCBTDKN6X5VHGMMN4MA", attempts) if err != nil { suite.FailNow(err.Error()) } @@ -389,7 +390,7 @@ func (suite *GetTestSuite) TestGetXBehindVeryHighID() { func (suite *GetTestSuite) TestGetXBeforeID() { // get 3 before the 'middle' id - statuses, err := suite.timeline.GetXBeforeID(3, "01F8MHBQCBTDKN6X5VHGMMN4MA", true) + statuses, err := suite.timeline.GetXBeforeID(context.Background(), 3, "01F8MHBQCBTDKN6X5VHGMMN4MA", true) if err != nil { suite.FailNow(err.Error()) } @@ -412,7 +413,7 @@ func (suite *GetTestSuite) TestGetXBeforeID() { func (suite *GetTestSuite) TestGetXBeforeIDNoStartFromTop() { // get 3 before the 'middle' id - statuses, err := suite.timeline.GetXBeforeID(3, "01F8MHBQCBTDKN6X5VHGMMN4MA", false) + statuses, err := suite.timeline.GetXBeforeID(context.Background(), 3, "01F8MHBQCBTDKN6X5VHGMMN4MA", false) if err != nil { suite.FailNow(err.Error()) } diff --git a/internal/timeline/index.go b/internal/timeline/index.go index 7cffe7ab9..7d7dc8873 100644 --- a/internal/timeline/index.go +++ b/internal/timeline/index.go @@ -20,6 +20,7 @@ package timeline import ( "container/list" + "context" "errors" "fmt" "time" @@ -29,7 +30,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (t *timeline) IndexBefore(statusID string, include bool, amount int) error { +func (t *timeline) IndexBefore(ctx context.Context, statusID string, include bool, amount int) error { // lazily initialize index if it hasn't been done already if t.postIndex.data == nil { t.postIndex.data = &list.List{} @@ -42,7 +43,7 @@ func (t *timeline) IndexBefore(statusID string, include bool, amount int) error if include { // if we have the status with given statusID in the database, include it in the results set as well s := >smodel.Status{} - if err := t.db.GetByID(statusID, s); err == nil { + if err := t.db.GetByID(ctx, statusID, s); err == nil { filtered = append(filtered, s) } } @@ -50,7 +51,7 @@ func (t *timeline) IndexBefore(statusID string, include bool, amount int) error i := 0 grabloop: for ; len(filtered) < amount && i < 5; i = i + 1 { // try the grabloop 5 times only - statuses, err := t.db.GetHomeTimeline(t.accountID, "", "", offsetStatus, amount, false) + statuses, err := t.db.GetHomeTimeline(ctx, t.accountID, "", "", offsetStatus, amount, false) if err != nil { if err == db.ErrNoEntries { break grabloop // we just don't have enough statuses left in the db so index what we've got and then bail @@ -59,7 +60,7 @@ grabloop: } for _, s := range statuses { - timelineable, err := t.filter.StatusHometimelineable(s, t.account) + timelineable, err := t.filter.StatusHometimelineable(ctx, s, t.account) if err != nil { continue } @@ -71,7 +72,7 @@ grabloop: } for _, s := range filtered { - if _, err := t.IndexOne(s.CreatedAt, s.ID, s.BoostOfID, s.AccountID, s.BoostOfAccountID); err != nil { + if _, err := t.IndexOne(ctx, s.CreatedAt, s.ID, s.BoostOfID, s.AccountID, s.BoostOfAccountID); err != nil { return fmt.Errorf("IndexBefore: error indexing status with id %s: %s", s.ID, err) } } @@ -79,7 +80,7 @@ grabloop: return nil } -func (t *timeline) IndexBehind(statusID string, include bool, amount int) error { +func (t *timeline) IndexBehind(ctx context.Context, statusID string, include bool, amount int) error { l := t.log.WithFields(logrus.Fields{ "func": "IndexBehind", "include": include, @@ -121,7 +122,7 @@ positionLoop: if include { // if we have the status with given statusID in the database, include it in the results set as well s := >smodel.Status{} - if err := t.db.GetByID(statusID, s); err == nil { + if err := t.db.GetByID(ctx, statusID, s); err == nil { filtered = append(filtered, s) } } @@ -130,7 +131,7 @@ positionLoop: grabloop: for ; len(filtered) < amount && i < 5; i = i + 1 { // try the grabloop 5 times only l.Tracef("entering grabloop; i is %d; len(filtered) is %d", i, len(filtered)) - statuses, err := t.db.GetHomeTimeline(t.accountID, offsetStatus, "", "", amount, false) + statuses, err := t.db.GetHomeTimeline(ctx, t.accountID, offsetStatus, "", "", amount, false) if err != nil { if err == db.ErrNoEntries { break grabloop // we just don't have enough statuses left in the db so index what we've got and then bail @@ -140,7 +141,7 @@ grabloop: l.Tracef("got %d statuses", len(statuses)) for _, s := range statuses { - timelineable, err := t.filter.StatusHometimelineable(s, t.account) + timelineable, err := t.filter.StatusHometimelineable(ctx, s, t.account) if err != nil { l.Tracef("status was not hometimelineable: %s", err) continue @@ -154,7 +155,7 @@ grabloop: l.Trace("left grabloop") for _, s := range filtered { - if _, err := t.IndexOne(s.CreatedAt, s.ID, s.BoostOfID, s.AccountID, s.BoostOfAccountID); err != nil { + if _, err := t.IndexOne(ctx, s.CreatedAt, s.ID, s.BoostOfID, s.AccountID, s.BoostOfAccountID); err != nil { return fmt.Errorf("IndexBehind: error indexing status with id %s: %s", s.ID, err) } } @@ -163,7 +164,7 @@ grabloop: return nil } -func (t *timeline) IndexOne(statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) { +func (t *timeline) IndexOne(ctx context.Context, statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) { t.Lock() defer t.Unlock() @@ -177,7 +178,7 @@ func (t *timeline) IndexOne(statusCreatedAt time.Time, statusID string, boostOfI return t.postIndex.insertIndexed(postIndexEntry) } -func (t *timeline) IndexAndPrepareOne(statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) { +func (t *timeline) IndexAndPrepareOne(ctx context.Context, statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) { t.Lock() defer t.Unlock() @@ -194,7 +195,7 @@ func (t *timeline) IndexAndPrepareOne(statusCreatedAt time.Time, statusID string } if inserted { - if err := t.prepare(statusID); err != nil { + if err := t.prepare(ctx, statusID); err != nil { return inserted, fmt.Errorf("IndexAndPrepareOne: error preparing: %s", err) } } @@ -202,7 +203,7 @@ func (t *timeline) IndexAndPrepareOne(statusCreatedAt time.Time, statusID string return inserted, nil } -func (t *timeline) OldestIndexedPostID() (string, error) { +func (t *timeline) OldestIndexedPostID(ctx context.Context) (string, error) { var id string if t.postIndex == nil || t.postIndex.data == nil || t.postIndex.data.Back() == nil { // return an empty string if postindex hasn't been initialized yet @@ -217,7 +218,7 @@ func (t *timeline) OldestIndexedPostID() (string, error) { return entry.statusID, nil } -func (t *timeline) NewestIndexedPostID() (string, error) { +func (t *timeline) NewestIndexedPostID(ctx context.Context) (string, error) { var id string if t.postIndex == nil || t.postIndex.data == nil || t.postIndex.data.Front() == nil { // return an empty string if postindex hasn't been initialized yet diff --git a/internal/timeline/index_test.go b/internal/timeline/index_test.go index 4201a27dd..25565a1de 100644 --- a/internal/timeline/index_test.go +++ b/internal/timeline/index_test.go @@ -19,6 +19,7 @@ package timeline_test import ( + "context" "testing" "time" @@ -46,7 +47,7 @@ func (suite *IndexTestSuite) SetupTest() { testrig.StandardDBSetup(suite.db, nil) // let's take local_account_1 as the timeline owner, and start with an empty timeline - tl, err := timeline.NewTimeline(suite.testAccounts["local_account_1"].ID, suite.db, suite.tc, suite.log) + tl, err := timeline.NewTimeline(context.Background(), suite.testAccounts["local_account_1"].ID, suite.db, suite.tc, suite.log) if err != nil { suite.FailNow(err.Error()) } @@ -59,82 +60,82 @@ func (suite *IndexTestSuite) TearDownTest() { func (suite *IndexTestSuite) TestIndexBeforeLowID() { // index 10 before the lowest status ID possible - err := suite.timeline.IndexBefore("00000000000000000000000000", true, 10) + err := suite.timeline.IndexBefore(context.Background(), "00000000000000000000000000", true, 10) suite.NoError(err) // the oldest indexed post should be the lowest one we have in our testrig - postID, err := suite.timeline.OldestIndexedPostID() + postID, err := suite.timeline.OldestIndexedPostID(context.Background()) suite.NoError(err) suite.Equal("01F8MHAAY43M6RJ473VQFCVH37", postID) - indexLength := suite.timeline.PostIndexLength() + indexLength := suite.timeline.PostIndexLength(context.Background()) suite.Equal(10, indexLength) } func (suite *IndexTestSuite) TestIndexBeforeHighID() { // index 10 before the highest status ID possible - err := suite.timeline.IndexBefore("ZZZZZZZZZZZZZZZZZZZZZZZZZZ", true, 10) + err := suite.timeline.IndexBefore(context.Background(), "ZZZZZZZZZZZZZZZZZZZZZZZZZZ", true, 10) suite.NoError(err) // the oldest indexed post should be empty - postID, err := suite.timeline.OldestIndexedPostID() + postID, err := suite.timeline.OldestIndexedPostID(context.Background()) suite.NoError(err) suite.Empty(postID) // indexLength should be 0 - indexLength := suite.timeline.PostIndexLength() + indexLength := suite.timeline.PostIndexLength(context.Background()) suite.Equal(0, indexLength) } func (suite *IndexTestSuite) TestIndexBehindHighID() { // index 10 behind the highest status ID possible - err := suite.timeline.IndexBehind("ZZZZZZZZZZZZZZZZZZZZZZZZZZ", true, 10) + err := suite.timeline.IndexBehind(context.Background(), "ZZZZZZZZZZZZZZZZZZZZZZZZZZ", true, 10) suite.NoError(err) // the newest indexed post should be the highest one we have in our testrig - postID, err := suite.timeline.NewestIndexedPostID() + postID, err := suite.timeline.NewestIndexedPostID(context.Background()) suite.NoError(err) suite.Equal("01FCTA44PW9H1TB328S9AQXKDS", postID) // indexLength should be 10 because that's all this user has hometimelineable - indexLength := suite.timeline.PostIndexLength() + indexLength := suite.timeline.PostIndexLength(context.Background()) suite.Equal(10, indexLength) } func (suite *IndexTestSuite) TestIndexBehindLowID() { // index 10 behind the lowest status ID possible - err := suite.timeline.IndexBehind("00000000000000000000000000", true, 10) + err := suite.timeline.IndexBehind(context.Background(), "00000000000000000000000000", true, 10) suite.NoError(err) // the newest indexed post should be empty - postID, err := suite.timeline.NewestIndexedPostID() + postID, err := suite.timeline.NewestIndexedPostID(context.Background()) suite.NoError(err) suite.Empty(postID) // indexLength should be 0 - indexLength := suite.timeline.PostIndexLength() + indexLength := suite.timeline.PostIndexLength(context.Background()) suite.Equal(0, indexLength) } func (suite *IndexTestSuite) TestOldestIndexedPostIDEmpty() { // the oldest indexed post should be an empty string since there's nothing indexed yet - postID, err := suite.timeline.OldestIndexedPostID() + postID, err := suite.timeline.OldestIndexedPostID(context.Background()) suite.NoError(err) suite.Empty(postID) // indexLength should be 0 - indexLength := suite.timeline.PostIndexLength() + indexLength := suite.timeline.PostIndexLength(context.Background()) suite.Equal(0, indexLength) } func (suite *IndexTestSuite) TestNewestIndexedPostIDEmpty() { // the newest indexed post should be an empty string since there's nothing indexed yet - postID, err := suite.timeline.NewestIndexedPostID() + postID, err := suite.timeline.NewestIndexedPostID(context.Background()) suite.NoError(err) suite.Empty(postID) // indexLength should be 0 - indexLength := suite.timeline.PostIndexLength() + indexLength := suite.timeline.PostIndexLength(context.Background()) suite.Equal(0, indexLength) } @@ -142,12 +143,12 @@ func (suite *IndexTestSuite) TestIndexAlreadyIndexed() { testStatus := suite.testStatuses["local_account_1_status_1"] // index one post -- it should be indexed - indexed, err := suite.timeline.IndexOne(testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) + indexed, err := suite.timeline.IndexOne(context.Background(), testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) suite.NoError(err) suite.True(indexed) // try to index the same post again -- it should not be indexed - indexed, err = suite.timeline.IndexOne(testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) + indexed, err = suite.timeline.IndexOne(context.Background(), testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) suite.NoError(err) suite.False(indexed) } @@ -156,12 +157,12 @@ func (suite *IndexTestSuite) TestIndexAndPrepareAlreadyIndexedAndPrepared() { testStatus := suite.testStatuses["local_account_1_status_1"] // index and prepare one post -- it should be indexed - indexed, err := suite.timeline.IndexAndPrepareOne(testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) + indexed, err := suite.timeline.IndexAndPrepareOne(context.Background(), testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) suite.NoError(err) suite.True(indexed) // try to index and prepare the same post again -- it should not be indexed - indexed, err = suite.timeline.IndexAndPrepareOne(testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) + indexed, err = suite.timeline.IndexAndPrepareOne(context.Background(), testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) suite.NoError(err) suite.False(indexed) } @@ -177,12 +178,12 @@ func (suite *IndexTestSuite) TestIndexBoostOfAlreadyIndexed() { } // index one post -- it should be indexed - indexed, err := suite.timeline.IndexOne(testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) + indexed, err := suite.timeline.IndexOne(context.Background(), testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID) suite.NoError(err) suite.True(indexed) // try to index the a boost of that post -- it should not be indexed - indexed, err = suite.timeline.IndexOne(boostOfTestStatus.CreatedAt, boostOfTestStatus.ID, boostOfTestStatus.BoostOfID, boostOfTestStatus.AccountID, boostOfTestStatus.BoostOfAccountID) + indexed, err = suite.timeline.IndexOne(context.Background(), boostOfTestStatus.CreatedAt, boostOfTestStatus.ID, boostOfTestStatus.BoostOfID, boostOfTestStatus.AccountID, boostOfTestStatus.BoostOfAccountID) suite.NoError(err) suite.False(indexed) } diff --git a/internal/timeline/manager.go b/internal/timeline/manager.go index a592670a8..7f42e2f51 100644 --- a/internal/timeline/manager.go +++ b/internal/timeline/manager.go @@ -19,6 +19,7 @@ package timeline import ( + "context" "fmt" "strings" "sync" @@ -54,7 +55,7 @@ type Manager interface { // // The returned bool indicates whether the status was actually put in the timeline. This could be false in cases where // the status is a boost, but a boost of the original post or the post itself already exists recently in the timeline. - Ingest(status *gtsmodel.Status, timelineAccountID string) (bool, error) + Ingest(ctx context.Context, status *gtsmodel.Status, timelineAccountID string) (bool, error) // IngestAndPrepare takes one status and indexes it into the timeline for the given account ID, and then immediately prepares it for serving. // This is useful in cases where we know the status will need to be shown at the top of a user's timeline immediately (eg., a new status is created). // @@ -62,24 +63,24 @@ type Manager interface { // // The returned bool indicates whether the status was actually put in the timeline. This could be false in cases where // the status is a boost, but a boost of the original post or the post itself already exists recently in the timeline. - IngestAndPrepare(status *gtsmodel.Status, timelineAccountID string) (bool, error) + IngestAndPrepare(ctx context.Context, status *gtsmodel.Status, timelineAccountID string) (bool, error) // HomeTimeline returns limit n amount of entries from the home timeline of the given account ID, in descending chronological order. // If maxID is provided, it will return entries from that maxID onwards, inclusive. - HomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*apimodel.Status, error) + HomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*apimodel.Status, error) // GetIndexedLength returns the amount of posts/statuses that have been *indexed* for the given account ID. - GetIndexedLength(timelineAccountID string) int + GetIndexedLength(ctx context.Context, timelineAccountID string) int // GetDesiredIndexLength returns the amount of posts that we, ideally, index for each user. - GetDesiredIndexLength() int + GetDesiredIndexLength(ctx context.Context) int // GetOldestIndexedID returns the status ID for the oldest post that we have indexed for the given account. - GetOldestIndexedID(timelineAccountID string) (string, error) + GetOldestIndexedID(ctx context.Context, timelineAccountID string) (string, error) // PrepareXFromTop prepares limit n amount of posts, based on their indexed representations, from the top of the index. - PrepareXFromTop(timelineAccountID string, limit int) error + PrepareXFromTop(ctx context.Context, timelineAccountID string, limit int) error // Remove removes one status from the timeline of the given timelineAccountID - Remove(timelineAccountID string, statusID string) (int, error) + Remove(ctx context.Context, timelineAccountID string, statusID string) (int, error) // WipeStatusFromAllTimelines removes one status from the index and prepared posts of all timelines - WipeStatusFromAllTimelines(statusID string) error + WipeStatusFromAllTimelines(ctx context.Context, statusID string) error // WipeStatusesFromAccountID removes all statuses by the given accountID from the timelineAccountID's timelines. - WipeStatusesFromAccountID(timelineAccountID string, accountID string) error + WipeStatusesFromAccountID(ctx context.Context, timelineAccountID string, accountID string) error } // NewManager returns a new timeline manager with the given database, typeconverter, config, and log. @@ -101,104 +102,104 @@ type manager struct { log *logrus.Logger } -func (m *manager) Ingest(status *gtsmodel.Status, timelineAccountID string) (bool, error) { +func (m *manager) Ingest(ctx context.Context, status *gtsmodel.Status, timelineAccountID string) (bool, error) { l := m.log.WithFields(logrus.Fields{ "func": "Ingest", "timelineAccountID": timelineAccountID, "statusID": status.ID, }) - t, err := m.getOrCreateTimeline(timelineAccountID) + t, err := m.getOrCreateTimeline(ctx, timelineAccountID) if err != nil { return false, err } l.Trace("ingesting status") - return t.IndexOne(status.CreatedAt, status.ID, status.BoostOfID, status.AccountID, status.BoostOfAccountID) + return t.IndexOne(ctx, status.CreatedAt, status.ID, status.BoostOfID, status.AccountID, status.BoostOfAccountID) } -func (m *manager) IngestAndPrepare(status *gtsmodel.Status, timelineAccountID string) (bool, error) { +func (m *manager) IngestAndPrepare(ctx context.Context, status *gtsmodel.Status, timelineAccountID string) (bool, error) { l := m.log.WithFields(logrus.Fields{ "func": "IngestAndPrepare", "timelineAccountID": timelineAccountID, "statusID": status.ID, }) - t, err := m.getOrCreateTimeline(timelineAccountID) + t, err := m.getOrCreateTimeline(ctx, timelineAccountID) if err != nil { return false, err } l.Trace("ingesting status") - return t.IndexAndPrepareOne(status.CreatedAt, status.ID, status.BoostOfID, status.AccountID, status.BoostOfAccountID) + return t.IndexAndPrepareOne(ctx, status.CreatedAt, status.ID, status.BoostOfID, status.AccountID, status.BoostOfAccountID) } -func (m *manager) Remove(timelineAccountID string, statusID string) (int, error) { +func (m *manager) Remove(ctx context.Context, timelineAccountID string, statusID string) (int, error) { l := m.log.WithFields(logrus.Fields{ "func": "Remove", "timelineAccountID": timelineAccountID, "statusID": statusID, }) - t, err := m.getOrCreateTimeline(timelineAccountID) + t, err := m.getOrCreateTimeline(ctx, timelineAccountID) if err != nil { return 0, err } l.Trace("removing status") - return t.Remove(statusID) + return t.Remove(ctx, statusID) } -func (m *manager) HomeTimeline(timelineAccountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*apimodel.Status, error) { +func (m *manager) HomeTimeline(ctx context.Context, timelineAccountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*apimodel.Status, error) { l := m.log.WithFields(logrus.Fields{ "func": "HomeTimelineGet", "timelineAccountID": timelineAccountID, }) - t, err := m.getOrCreateTimeline(timelineAccountID) + t, err := m.getOrCreateTimeline(ctx, timelineAccountID) if err != nil { return nil, err } - statuses, err := t.Get(limit, maxID, sinceID, minID, true) + statuses, err := t.Get(ctx, limit, maxID, sinceID, minID, true) if err != nil { l.Errorf("error getting statuses: %s", err) } return statuses, nil } -func (m *manager) GetIndexedLength(timelineAccountID string) int { - t, err := m.getOrCreateTimeline(timelineAccountID) +func (m *manager) GetIndexedLength(ctx context.Context, timelineAccountID string) int { + t, err := m.getOrCreateTimeline(ctx, timelineAccountID) if err != nil { return 0 } - return t.PostIndexLength() + return t.PostIndexLength(ctx) } -func (m *manager) GetDesiredIndexLength() int { +func (m *manager) GetDesiredIndexLength(ctx context.Context) int { return desiredPostIndexLength } -func (m *manager) GetOldestIndexedID(timelineAccountID string) (string, error) { - t, err := m.getOrCreateTimeline(timelineAccountID) +func (m *manager) GetOldestIndexedID(ctx context.Context, timelineAccountID string) (string, error) { + t, err := m.getOrCreateTimeline(ctx, timelineAccountID) if err != nil { return "", err } - return t.OldestIndexedPostID() + return t.OldestIndexedPostID(ctx) } -func (m *manager) PrepareXFromTop(timelineAccountID string, limit int) error { - t, err := m.getOrCreateTimeline(timelineAccountID) +func (m *manager) PrepareXFromTop(ctx context.Context, timelineAccountID string, limit int) error { + t, err := m.getOrCreateTimeline(ctx, timelineAccountID) if err != nil { return err } - return t.PrepareFromTop(limit) + return t.PrepareFromTop(ctx, limit) } -func (m *manager) WipeStatusFromAllTimelines(statusID string) error { +func (m *manager) WipeStatusFromAllTimelines(ctx context.Context, statusID string) error { errors := []string{} m.accountTimelines.Range(func(k interface{}, i interface{}) bool { t, ok := i.(Timeline) @@ -206,7 +207,7 @@ func (m *manager) WipeStatusFromAllTimelines(statusID string) error { panic("couldn't parse entry as Timeline, this should never happen so panic") } - if _, err := t.Remove(statusID); err != nil { + if _, err := t.Remove(ctx, statusID); err != nil { errors = append(errors, err.Error()) } @@ -221,22 +222,22 @@ func (m *manager) WipeStatusFromAllTimelines(statusID string) error { return err } -func (m *manager) WipeStatusesFromAccountID(timelineAccountID string, accountID string) error { - t, err := m.getOrCreateTimeline(timelineAccountID) +func (m *manager) WipeStatusesFromAccountID(ctx context.Context, timelineAccountID string, accountID string) error { + t, err := m.getOrCreateTimeline(ctx, timelineAccountID) if err != nil { return err } - _, err = t.RemoveAllBy(accountID) + _, err = t.RemoveAllBy(ctx, accountID) return err } -func (m *manager) getOrCreateTimeline(timelineAccountID string) (Timeline, error) { +func (m *manager) getOrCreateTimeline(ctx context.Context, timelineAccountID string) (Timeline, error) { var t Timeline i, ok := m.accountTimelines.Load(timelineAccountID) if !ok { var err error - t, err = NewTimeline(timelineAccountID, m.db, m.tc, m.log) + t, err = NewTimeline(ctx, timelineAccountID, m.db, m.tc, m.log) if err != nil { return nil, err } diff --git a/internal/timeline/manager_test.go b/internal/timeline/manager_test.go index 00c6dcb4a..ea4dc4c12 100644 --- a/internal/timeline/manager_test.go +++ b/internal/timeline/manager_test.go @@ -19,6 +19,7 @@ package timeline_test import ( + "context" "testing" "github.com/stretchr/testify/suite" @@ -54,85 +55,85 @@ func (suite *ManagerTestSuite) TestManagerIntegration() { testAccount := suite.testAccounts["local_account_1"] // should start at 0 - indexedLen := suite.manager.GetIndexedLength(testAccount.ID) + indexedLen := suite.manager.GetIndexedLength(context.Background(), testAccount.ID) suite.Equal(0, indexedLen) // oldestIndexed should be empty string since there's nothing indexed - oldestIndexed, err := suite.manager.GetOldestIndexedID(testAccount.ID) + oldestIndexed, err := suite.manager.GetOldestIndexedID(context.Background(), testAccount.ID) suite.NoError(err) suite.Empty(oldestIndexed) // trigger status preparation - err = suite.manager.PrepareXFromTop(testAccount.ID, 20) + err = suite.manager.PrepareXFromTop(context.Background(), testAccount.ID, 20) suite.NoError(err) // local_account_1 can see 12 statuses out of the testrig statuses in its home timeline - indexedLen = suite.manager.GetIndexedLength(testAccount.ID) + indexedLen = suite.manager.GetIndexedLength(context.Background(), testAccount.ID) suite.Equal(12, indexedLen) // oldest should now be set - oldestIndexed, err = suite.manager.GetOldestIndexedID(testAccount.ID) + oldestIndexed, err = suite.manager.GetOldestIndexedID(context.Background(), testAccount.ID) suite.NoError(err) suite.Equal("01F8MH75CBF9JFX4ZAD54N0W0R", oldestIndexed) // get hometimeline - statuses, err := suite.manager.HomeTimeline(testAccount.ID, "", "", "", 20, false) + statuses, err := suite.manager.HomeTimeline(context.Background(), testAccount.ID, "", "", "", 20, false) suite.NoError(err) suite.Len(statuses, 12) // now wipe the last status from all timelines, as though it had been deleted by the owner - err = suite.manager.WipeStatusFromAllTimelines("01F8MH75CBF9JFX4ZAD54N0W0R") + err = suite.manager.WipeStatusFromAllTimelines(context.Background(), "01F8MH75CBF9JFX4ZAD54N0W0R") suite.NoError(err) // timeline should be shorter - indexedLen = suite.manager.GetIndexedLength(testAccount.ID) + indexedLen = suite.manager.GetIndexedLength(context.Background(), testAccount.ID) suite.Equal(11, indexedLen) // oldest should now be different - oldestIndexed, err = suite.manager.GetOldestIndexedID(testAccount.ID) + oldestIndexed, err = suite.manager.GetOldestIndexedID(context.Background(), testAccount.ID) suite.NoError(err) suite.Equal("01F8MH82FYRXD2RC6108DAJ5HB", oldestIndexed) // delete the new oldest status specifically from this timeline, as though local_account_1 had muted or blocked it - removed, err := suite.manager.Remove(testAccount.ID, "01F8MH82FYRXD2RC6108DAJ5HB") + removed, err := suite.manager.Remove(context.Background(), testAccount.ID, "01F8MH82FYRXD2RC6108DAJ5HB") suite.NoError(err) suite.Equal(2, removed) // 1 status should be removed, but from both indexed and prepared, so 2 removals total // timeline should be shorter - indexedLen = suite.manager.GetIndexedLength(testAccount.ID) + indexedLen = suite.manager.GetIndexedLength(context.Background(), testAccount.ID) suite.Equal(10, indexedLen) // oldest should now be different - oldestIndexed, err = suite.manager.GetOldestIndexedID(testAccount.ID) + oldestIndexed, err = suite.manager.GetOldestIndexedID(context.Background(), testAccount.ID) suite.NoError(err) suite.Equal("01F8MHAAY43M6RJ473VQFCVH37", oldestIndexed) // now remove all entries by local_account_2 from the timeline - err = suite.manager.WipeStatusesFromAccountID(testAccount.ID, suite.testAccounts["local_account_2"].ID) + err = suite.manager.WipeStatusesFromAccountID(context.Background(), testAccount.ID, suite.testAccounts["local_account_2"].ID) suite.NoError(err) // timeline should be empty now - indexedLen = suite.manager.GetIndexedLength(testAccount.ID) + indexedLen = suite.manager.GetIndexedLength(context.Background(), testAccount.ID) suite.Equal(5, indexedLen) // ingest 1 into the timeline status1 := suite.testStatuses["admin_account_status_1"] - ingested, err := suite.manager.Ingest(status1, testAccount.ID) + ingested, err := suite.manager.Ingest(context.Background(), status1, testAccount.ID) suite.NoError(err) suite.True(ingested) // ingest and prepare another one into the timeline status2 := suite.testStatuses["local_account_2_status_1"] - ingested, err = suite.manager.IngestAndPrepare(status2, testAccount.ID) + ingested, err = suite.manager.IngestAndPrepare(context.Background(), status2, testAccount.ID) suite.NoError(err) suite.True(ingested) // timeline should be longer now - indexedLen = suite.manager.GetIndexedLength(testAccount.ID) + indexedLen = suite.manager.GetIndexedLength(context.Background(), testAccount.ID) suite.Equal(7, indexedLen) // try to ingest status 2 again - ingested, err = suite.manager.IngestAndPrepare(status2, testAccount.ID) + ingested, err = suite.manager.IngestAndPrepare(context.Background(), status2, testAccount.ID) suite.NoError(err) suite.False(ingested) // should be false since it's a duplicate } diff --git a/internal/timeline/prepare.go b/internal/timeline/prepare.go index 20000b4e9..d57222ee8 100644 --- a/internal/timeline/prepare.go +++ b/internal/timeline/prepare.go @@ -20,6 +20,7 @@ package timeline import ( "container/list" + "context" "errors" "fmt" @@ -28,7 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (t *timeline) prepareNextQuery(amount int, maxID string, sinceID string, minID string) error { +func (t *timeline) prepareNextQuery(ctx context.Context, amount int, maxID string, sinceID string, minID string) error { l := t.log.WithFields(logrus.Fields{ "func": "prepareNextQuery", "amount": amount, @@ -42,30 +43,30 @@ func (t *timeline) prepareNextQuery(amount int, maxID string, sinceID string, mi // maxID is defined but sinceID isn't so take from behind if maxID != "" && sinceID == "" { l.Debug("preparing behind maxID") - err = t.PrepareBehind(maxID, amount) + err = t.PrepareBehind(ctx, maxID, amount) } // maxID isn't defined, but sinceID || minID are, so take x before if maxID == "" && sinceID != "" { l.Debug("preparing before sinceID") - err = t.PrepareBefore(sinceID, false, amount) + err = t.PrepareBefore(ctx, sinceID, false, amount) } if maxID == "" && minID != "" { l.Debug("preparing before minID") - err = t.PrepareBefore(minID, false, amount) + err = t.PrepareBefore(ctx, minID, false, amount) } return err } -func (t *timeline) PrepareBehind(statusID string, amount int) error { +func (t *timeline) PrepareBehind(ctx context.Context, statusID string, amount int) error { // lazily initialize prepared posts if it hasn't been done already if t.preparedPosts.data == nil { t.preparedPosts.data = &list.List{} t.preparedPosts.data.Init() } - if err := t.IndexBehind(statusID, true, amount); err != nil { + if err := t.IndexBehind(ctx, statusID, true, amount); err != nil { return fmt.Errorf("PrepareBehind: error indexing behind id %s: %s", statusID, err) } @@ -93,7 +94,7 @@ prepareloop: } if preparing { - if err := t.prepare(entry.statusID); err != nil { + if err := t.prepare(ctx, entry.statusID); err != nil { // there's been an error if err != db.ErrNoEntries { // it's a real error @@ -113,7 +114,7 @@ prepareloop: return nil } -func (t *timeline) PrepareBefore(statusID string, include bool, amount int) error { +func (t *timeline) PrepareBefore(ctx context.Context, statusID string, include bool, amount int) error { t.Lock() defer t.Unlock() @@ -148,7 +149,7 @@ prepareloop: } if preparing { - if err := t.prepare(entry.statusID); err != nil { + if err := t.prepare(ctx, entry.statusID); err != nil { // there's been an error if err != db.ErrNoEntries { // it's a real error @@ -168,7 +169,7 @@ prepareloop: return nil } -func (t *timeline) PrepareFromTop(amount int) error { +func (t *timeline) PrepareFromTop(ctx context.Context, amount int) error { l := t.log.WithFields(logrus.Fields{ "func": "PrepareFromTop", "amount": amount, @@ -183,7 +184,7 @@ func (t *timeline) PrepareFromTop(amount int) error { // if the postindex is nil, nothing has been indexed yet so index from the highest ID possible if t.postIndex.data == nil { l.Debug("postindex.data was nil, indexing behind highest possible ID") - if err := t.IndexBehind("ZZZZZZZZZZZZZZZZZZZZZZZZZZ", false, amount); err != nil { + if err := t.IndexBehind(ctx, "ZZZZZZZZZZZZZZZZZZZZZZZZZZ", false, amount); err != nil { return fmt.Errorf("PrepareFromTop: error indexing behind id %s: %s", "ZZZZZZZZZZZZZZZZZZZZZZZZZZ", err) } } @@ -203,7 +204,7 @@ prepareloop: return errors.New("PrepareFromTop: could not parse e as a postIndexEntry") } - if err := t.prepare(entry.statusID); err != nil { + if err := t.prepare(ctx, entry.statusID); err != nil { // there's been an error if err != db.ErrNoEntries { // it's a real error @@ -225,25 +226,25 @@ prepareloop: return nil } -func (t *timeline) prepare(statusID string) error { +func (t *timeline) prepare(ctx context.Context, statusID string) error { // start by getting the status out of the database according to its indexed ID gtsStatus := >smodel.Status{} - if err := t.db.GetByID(statusID, gtsStatus); err != nil { + if err := t.db.GetByID(ctx, statusID, gtsStatus); err != nil { return err } // if the account pointer hasn't been set on this timeline already, set it lazily here if t.account == nil { timelineOwnerAccount := >smodel.Account{} - if err := t.db.GetByID(t.accountID, timelineOwnerAccount); err != nil { + if err := t.db.GetByID(ctx, t.accountID, timelineOwnerAccount); err != nil { return err } t.account = timelineOwnerAccount } // serialize the status (or, at least, convert it to a form that's ready to be serialized) - apiModelStatus, err := t.tc.StatusToMasto(gtsStatus, t.account) + apiModelStatus, err := t.tc.StatusToMasto(ctx, gtsStatus, t.account) if err != nil { return err } @@ -260,7 +261,7 @@ func (t *timeline) prepare(statusID string) error { return t.preparedPosts.insertPrepared(preparedPostsEntry) } -func (t *timeline) OldestPreparedPostID() (string, error) { +func (t *timeline) OldestPreparedPostID(ctx context.Context) (string, error) { var id string if t.preparedPosts == nil || t.preparedPosts.data == nil { // return an empty string if prepared posts hasn't been initialized yet diff --git a/internal/timeline/remove.go b/internal/timeline/remove.go index cf0b0b617..031dace1f 100644 --- a/internal/timeline/remove.go +++ b/internal/timeline/remove.go @@ -20,12 +20,13 @@ package timeline import ( "container/list" + "context" "errors" "github.com/sirupsen/logrus" ) -func (t *timeline) Remove(statusID string) (int, error) { +func (t *timeline) Remove(ctx context.Context, statusID string) (int, error) { l := t.log.WithFields(logrus.Fields{ "func": "Remove", "accountTimeline": t.accountID, @@ -77,7 +78,7 @@ func (t *timeline) Remove(statusID string) (int, error) { return removed, nil } -func (t *timeline) RemoveAllBy(accountID string) (int, error) { +func (t *timeline) RemoveAllBy(ctx context.Context, accountID string) (int, error) { l := t.log.WithFields(logrus.Fields{ "func": "RemoveAllBy", "accountTimeline": t.accountID, diff --git a/internal/timeline/timeline.go b/internal/timeline/timeline.go index 6274a86ac..5f5fa1b4f 100644 --- a/internal/timeline/timeline.go +++ b/internal/timeline/timeline.go @@ -19,6 +19,7 @@ package timeline import ( + "context" "sync" "time" @@ -41,24 +42,24 @@ type Timeline interface { // Get returns an amount of statuses with the given parameters. // If prepareNext is true, then the next predicted query will be prepared already in a goroutine, // to make the next call to Get faster. - Get(amount int, maxID string, sinceID string, minID string, prepareNext bool) ([]*apimodel.Status, error) + Get(ctx context.Context, amount int, maxID string, sinceID string, minID string, prepareNext bool) ([]*apimodel.Status, error) // GetXFromTop returns x amount of posts from the top of the timeline, from newest to oldest. - GetXFromTop(amount int) ([]*apimodel.Status, error) + GetXFromTop(ctx context.Context, amount int) ([]*apimodel.Status, error) // GetXBehindID returns x amount of posts from the given id onwards, from newest to oldest. // This will NOT include the status with the given ID. // // This corresponds to an api call to /timelines/home?max_id=WHATEVER - GetXBehindID(amount int, fromID string, attempts *int) ([]*apimodel.Status, error) + GetXBehindID(ctx context.Context, amount int, fromID string, attempts *int) ([]*apimodel.Status, error) // GetXBeforeID returns x amount of posts up to the given id, from newest to oldest. // This will NOT include the status with the given ID. // // This corresponds to an api call to /timelines/home?since_id=WHATEVER - GetXBeforeID(amount int, sinceID string, startFromTop bool) ([]*apimodel.Status, error) + GetXBeforeID(ctx context.Context, amount int, sinceID string, startFromTop bool) ([]*apimodel.Status, error) // GetXBetweenID returns x amount of posts from the given maxID, up to the given id, from newest to oldest. // This will NOT include the status with the given IDs. // // This corresponds to an api call to /timelines/home?since_id=WHATEVER&max_id=WHATEVER_ELSE - GetXBetweenID(amount int, maxID string, sinceID string) ([]*apimodel.Status, error) + GetXBetweenID(ctx context.Context, amount int, maxID string, sinceID string) ([]*apimodel.Status, error) /* INDEXING FUNCTIONS @@ -68,43 +69,43 @@ type Timeline interface { // // The returned bool indicates whether or not the status was actually inserted into the timeline. This will be false // if the status is a boost and the original post or another boost of it already exists < boostReinsertionDepth back in the timeline. - IndexOne(statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) + IndexOne(ctx context.Context, statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) // OldestIndexedPostID returns the id of the rearmost (ie., the oldest) indexed post, or an error if something goes wrong. // If nothing goes wrong but there's no oldest post, an empty string will be returned so make sure to check for this. - OldestIndexedPostID() (string, error) + OldestIndexedPostID(ctx context.Context) (string, error) // NewestIndexedPostID returns the id of the frontmost (ie., the newest) indexed post, or an error if something goes wrong. // If nothing goes wrong but there's no newest post, an empty string will be returned so make sure to check for this. - NewestIndexedPostID() (string, error) + NewestIndexedPostID(ctx context.Context) (string, error) - IndexBefore(statusID string, include bool, amount int) error - IndexBehind(statusID string, include bool, amount int) error + IndexBefore(ctx context.Context, statusID string, include bool, amount int) error + IndexBehind(ctx context.Context, statusID string, include bool, amount int) error /* PREPARATION FUNCTIONS */ // PrepareXFromTop instructs the timeline to prepare x amount of posts from the top of the timeline. - PrepareFromTop(amount int) error + PrepareFromTop(ctx context.Context, amount int) error // PrepareBehind instructs the timeline to prepare the next amount of entries for serialization, from position onwards. // If include is true, then the given status ID will also be prepared, otherwise only entries behind it will be prepared. - PrepareBehind(statusID string, amount int) error + PrepareBehind(ctx context.Context, statusID string, amount int) error // IndexOne puts a status into the timeline at the appropriate place according to its 'createdAt' property, // and then immediately prepares it. // // The returned bool indicates whether or not the status was actually inserted into the timeline. This will be false // if the status is a boost and the original post or another boost of it already exists < boostReinsertionDepth back in the timeline. - IndexAndPrepareOne(statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) + IndexAndPrepareOne(ctx context.Context, statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) // OldestPreparedPostID returns the id of the rearmost (ie., the oldest) prepared post, or an error if something goes wrong. // If nothing goes wrong but there's no oldest post, an empty string will be returned so make sure to check for this. - OldestPreparedPostID() (string, error) + OldestPreparedPostID(ctx context.Context) (string, error) /* INFO FUNCTIONS */ // ActualPostIndexLength returns the actual length of the post index at this point in time. - PostIndexLength() int + PostIndexLength(ctx context.Context) int /* UTILITY FUNCTIONS @@ -117,11 +118,11 @@ type Timeline interface { // If a status has multiple entries in a timeline, they will all be removed. // // The returned int indicates the amount of entries that were removed. - Remove(statusID string) (int, error) + Remove(ctx context.Context, statusID string) (int, error) // RemoveAllBy removes all statuses by the given accountID, from both the index and prepared posts. // // The returned int indicates the amount of entries that were removed. - RemoveAllBy(accountID string) (int, error) + RemoveAllBy(ctx context.Context, accountID string) (int, error) } // timeline fulfils the Timeline interface @@ -138,9 +139,9 @@ type timeline struct { } // NewTimeline returns a new Timeline for the given account ID -func NewTimeline(accountID string, db db.DB, typeConverter typeutils.TypeConverter, log *logrus.Logger) (Timeline, error) { +func NewTimeline(ctx context.Context, accountID string, db db.DB, typeConverter typeutils.TypeConverter, log *logrus.Logger) (Timeline, error) { timelineOwnerAccount := >smodel.Account{} - if err := db.GetByID(accountID, timelineOwnerAccount); err != nil { + if err := db.GetByID(ctx, accountID, timelineOwnerAccount); err != nil { return nil, err } @@ -160,7 +161,7 @@ func (t *timeline) Reset() error { return nil } -func (t *timeline) PostIndexLength() int { +func (t *timeline) PostIndexLength(ctx context.Context) int { if t.postIndex == nil || t.postIndex.data == nil { return 0 } diff --git a/internal/transport/controller.go b/internal/transport/controller.go index 4eb6b5658..c2f5026e0 100644 --- a/internal/transport/controller.go +++ b/internal/transport/controller.go @@ -19,6 +19,7 @@ package transport import ( + "context" "crypto" "fmt" "sync" @@ -33,7 +34,7 @@ import ( // Controller generates transports for use in making federation requests to other servers. type Controller interface { NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error) - NewTransportForUsername(username string) (Transport, error) + NewTransportForUsername(ctx context.Context, username string) (Transport, error) } type controller struct { @@ -90,7 +91,7 @@ func (c *controller) NewTransport(pubKeyID string, privkey crypto.PrivateKey) (T }, nil } -func (c *controller) NewTransportForUsername(username string) (Transport, error) { +func (c *controller) NewTransportForUsername(ctx context.Context, username string) (Transport, error) { // We need an account to use to create a transport for dereferecing something. // If a username has been given, we can fetch the account with that username and use it. // Otherwise, we can take the instance account and use those credentials to make the request. @@ -101,7 +102,7 @@ func (c *controller) NewTransportForUsername(username string) (Transport, error) u = username } - ourAccount, err := c.db.GetLocalAccountByUsername(u) + ourAccount, err := c.db.GetLocalAccountByUsername(ctx, u) if err != nil { return nil, fmt.Errorf("error getting account %s from db: %s", username, err) } diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go index 844cb6bea..fd0fb576f 100644 --- a/internal/transport/deliver.go +++ b/internal/transport/deliver.go @@ -1,3 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package transport import ( @@ -5,12 +23,12 @@ import ( "net/url" ) -func (t *transport) BatchDeliver(c context.Context, b []byte, recipients []*url.URL) error { - return t.sigTransport.BatchDeliver(c, b, recipients) +func (t *transport) BatchDeliver(ctx context.Context, b []byte, recipients []*url.URL) error { + return t.sigTransport.BatchDeliver(ctx, b, recipients) } -func (t *transport) Deliver(c context.Context, b []byte, to *url.URL) error { +func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error { l := t.log.WithField("func", "Deliver") l.Debugf("performing POST to %s", to.String()) - return t.sigTransport.Deliver(c, b, to) + return t.sigTransport.Deliver(ctx, b, to) } diff --git a/internal/transport/dereference.go b/internal/transport/dereference.go index d7a28fe17..85fa370ee 100644 --- a/internal/transport/dereference.go +++ b/internal/transport/dereference.go @@ -1,3 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package transport import ( @@ -5,8 +23,8 @@ import ( "net/url" ) -func (t *transport) Dereference(c context.Context, iri *url.URL) ([]byte, error) { +func (t *transport) Dereference(ctx context.Context, iri *url.URL) ([]byte, error) { l := t.log.WithField("func", "Dereference") l.Debugf("performing GET to %s", iri.String()) - return t.sigTransport.Dereference(c, iri) + return t.sigTransport.Dereference(ctx, iri) } diff --git a/internal/transport/derefinstance.go b/internal/transport/derefinstance.go index a8b2ddfc7..3d72d7581 100644 --- a/internal/transport/derefinstance.go +++ b/internal/transport/derefinstance.go @@ -1,3 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package transport import ( @@ -16,7 +34,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/util" ) -func (t *transport) DereferenceInstance(c context.Context, iri *url.URL) (*gtsmodel.Instance, error) { +func (t *transport) DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error) { l := t.log.WithField("func", "DereferenceInstance") var i *gtsmodel.Instance @@ -27,7 +45,7 @@ func (t *transport) DereferenceInstance(c context.Context, iri *url.URL) (*gtsmo // // This will only work with Mastodon-api compatible instances: Mastodon, some Pleroma instances, GoToSocial. l.Debugf("trying to dereference instance %s by /api/v1/instance", iri.Host) - i, err = dereferenceByAPIV1Instance(c, t, iri) + i, err = dereferenceByAPIV1Instance(ctx, t, iri) if err == nil { l.Debugf("successfully dereferenced instance using /api/v1/instance") return i, nil @@ -37,7 +55,7 @@ func (t *transport) DereferenceInstance(c context.Context, iri *url.URL) (*gtsmo // If that doesn't work, try to dereference using /.well-known/nodeinfo. // This will involve two API calls and return less info overall, but should be more widely compatible. l.Debugf("trying to dereference instance %s by /.well-known/nodeinfo", iri.Host) - i, err = dereferenceByNodeInfo(c, t, iri) + i, err = dereferenceByNodeInfo(ctx, t, iri) if err == nil { l.Debugf("successfully dereferenced instance using /.well-known/nodeinfo") return i, nil @@ -58,7 +76,7 @@ func (t *transport) DereferenceInstance(c context.Context, iri *url.URL) (*gtsmo }, nil } -func dereferenceByAPIV1Instance(c context.Context, t *transport, iri *url.URL) (*gtsmodel.Instance, error) { +func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL) (*gtsmodel.Instance, error) { l := t.log.WithField("func", "dereferenceByAPIV1Instance") cleanIRI := &url.URL{ @@ -68,11 +86,10 @@ func dereferenceByAPIV1Instance(c context.Context, t *transport, iri *url.URL) ( } l.Debugf("performing GET to %s", cleanIRI.String()) - req, err := http.NewRequest("GET", cleanIRI.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil) if err != nil { return nil, err } - req = req.WithContext(c) req.Header.Add("Accept", "application/json") req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent)) @@ -216,7 +233,7 @@ func dereferenceByNodeInfo(c context.Context, t *transport, iri *url.URL) (*gtsm return i, nil } -func callNodeInfoWellKnown(c context.Context, t *transport, iri *url.URL) (*url.URL, error) { +func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*url.URL, error) { l := t.log.WithField("func", "callNodeInfoWellKnown") cleanIRI := &url.URL{ @@ -226,11 +243,11 @@ func callNodeInfoWellKnown(c context.Context, t *transport, iri *url.URL) (*url. } l.Debugf("performing GET to %s", cleanIRI.String()) - req, err := http.NewRequest("GET", cleanIRI.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil) if err != nil { return nil, err } - req = req.WithContext(c) + req.Header.Add("Accept", "application/json") req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent)) @@ -281,15 +298,15 @@ func callNodeInfoWellKnown(c context.Context, t *transport, iri *url.URL) (*url. return nodeinfoHref, nil } -func callNodeInfo(c context.Context, t *transport, iri *url.URL) (*apimodel.Nodeinfo, error) { +func callNodeInfo(ctx context.Context, t *transport, iri *url.URL) (*apimodel.Nodeinfo, error) { l := t.log.WithField("func", "callNodeInfo") l.Debugf("performing GET to %s", iri.String()) - req, err := http.NewRequest("GET", iri.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil) if err != nil { return nil, err } - req = req.WithContext(c) + req.Header.Add("Accept", "application/json") req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent)) diff --git a/internal/transport/derefmedia.go b/internal/transport/derefmedia.go index 5fa901100..e265bfdd4 100644 --- a/internal/transport/derefmedia.go +++ b/internal/transport/derefmedia.go @@ -1,3 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package transport import ( @@ -8,14 +26,13 @@ import ( "net/url" ) -func (t *transport) DereferenceMedia(c context.Context, iri *url.URL, expectedContentType string) ([]byte, error) { +func (t *transport) DereferenceMedia(ctx context.Context, iri *url.URL, expectedContentType string) ([]byte, error) { l := t.log.WithField("func", "DereferenceMedia") l.Debugf("performing GET to %s", iri.String()) - req, err := http.NewRequest("GET", iri.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil) if err != nil { return nil, err } - req = req.WithContext(c) if expectedContentType == "" { req.Header.Add("Accept", "*/*") } else { diff --git a/internal/transport/finger.go b/internal/transport/finger.go index 12cd2fb64..ce092e83f 100644 --- a/internal/transport/finger.go +++ b/internal/transport/finger.go @@ -1,3 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package transport import ( @@ -8,7 +26,7 @@ import ( "net/url" ) -func (t *transport) Finger(c context.Context, targetUsername string, targetDomain string) ([]byte, error) { +func (t *transport) Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) { l := t.log.WithField("func", "Finger") urlString := fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s@%s", targetDomain, targetUsername, targetDomain) l.Debugf("performing GET to %s", urlString) @@ -20,11 +38,11 @@ func (t *transport) Finger(c context.Context, targetUsername string, targetDomai l.Debugf("performing GET to %s", iri.String()) - req, err := http.NewRequest("GET", iri.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil) if err != nil { return nil, err } - req = req.WithContext(c) + req.Header.Add("Accept", "application/json") req.Header.Add("Accept", "application/jrd+json") req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 04c72de5c..8d8262834 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -1,3 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + package transport import ( @@ -17,11 +35,11 @@ import ( type Transport interface { pub.Transport // DereferenceMedia fetches the bytes of the given media attachment IRI, with the expectedContentType. - DereferenceMedia(c context.Context, iri *url.URL, expectedContentType string) ([]byte, error) + DereferenceMedia(ctx context.Context, iri *url.URL, expectedContentType string) ([]byte, error) // DereferenceInstance dereferences remote instance information, first by checking /api/v1/instance, and then by checking /.well-known/nodeinfo. - DereferenceInstance(c context.Context, iri *url.URL) (*gtsmodel.Instance, error) + DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error) // Finger performs a webfinger request with the given username and domain, and returns the bytes from the response body. - Finger(c context.Context, targetUsername string, targetDomains string) ([]byte, error) + Finger(ctx context.Context, targetUsername string, targetDomains string) ([]byte, error) } // transport implements the Transport interface diff --git a/internal/typeutils/astointernal.go b/internal/typeutils/astointernal.go index 887716a69..46132233b 100644 --- a/internal/typeutils/astointernal.go +++ b/internal/typeutils/astointernal.go @@ -19,6 +19,7 @@ package typeutils import ( + "context" "errors" "fmt" "net/url" @@ -29,7 +30,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (c *converter) ASRepresentationToAccount(accountable ap.Accountable, update bool) (*gtsmodel.Account, error) { +func (c *converter) ASRepresentationToAccount(ctx context.Context, accountable ap.Accountable, update bool) (*gtsmodel.Account, error) { // first check if we actually already know this account uriProp := accountable.GetJSONLDId() if uriProp == nil || !uriProp.IsIRI() { @@ -38,7 +39,7 @@ func (c *converter) ASRepresentationToAccount(accountable ap.Accountable, update uri := uriProp.GetIRI() if !update { - acct, err := c.db.GetAccountByURI(uri.String()) + acct, err := c.db.GetAccountByURI(ctx, uri.String()) if err == nil { // we already know this account so we can skip generating it return acct, nil @@ -170,7 +171,7 @@ func (c *converter) ASRepresentationToAccount(accountable ap.Accountable, update return acct, nil } -func (c *converter) ASStatusToStatus(statusable ap.Statusable) (*gtsmodel.Status, error) { +func (c *converter) ASStatusToStatus(ctx context.Context, statusable ap.Statusable) (*gtsmodel.Status, error) { status := >smodel.Status{} // uri at which this status is reachable @@ -219,6 +220,7 @@ func (c *converter) ASStatusToStatus(statusable ap.Statusable) (*gtsmodel.Status published, err := ap.ExtractPublished(statusable) if err == nil { status.CreatedAt = published + status.UpdatedAt = published } // which account posted this status? @@ -229,7 +231,7 @@ func (c *converter) ASStatusToStatus(statusable ap.Statusable) (*gtsmodel.Status } status.AccountURI = attributedTo.String() - statusOwner, err := c.db.GetAccountByURI(attributedTo.String()) + statusOwner, err := c.db.GetAccountByURI(ctx, attributedTo.String()) if err != nil { return nil, fmt.Errorf("couldn't get status owner from db: %s", err) } @@ -245,14 +247,14 @@ func (c *converter) ASStatusToStatus(statusable ap.Statusable) (*gtsmodel.Status status.InReplyToURI = inReplyToURI.String() // now we can check if we have the replied-to status in our db already - if inReplyToStatus, err := c.db.GetStatusByURI(inReplyToURI.String()); err == nil { + if inReplyToStatus, err := c.db.GetStatusByURI(ctx, inReplyToURI.String()); err == nil { // we have the status in our database already // so we can set these fields here and now... status.InReplyToID = inReplyToStatus.ID status.InReplyToAccountID = inReplyToStatus.AccountID status.InReplyTo = inReplyToStatus if status.InReplyToAccount == nil { - if inReplyToAccount, err := c.db.GetAccountByID(inReplyToStatus.AccountID); err == nil { + if inReplyToAccount, err := c.db.GetAccountByID(ctx, inReplyToStatus.AccountID); err == nil { status.InReplyToAccount = inReplyToAccount } } @@ -318,7 +320,7 @@ func (c *converter) ASStatusToStatus(statusable ap.Statusable) (*gtsmodel.Status return status, nil } -func (c *converter) ASFollowToFollowRequest(followable ap.Followable) (*gtsmodel.FollowRequest, error) { +func (c *converter) ASFollowToFollowRequest(ctx context.Context, followable ap.Followable) (*gtsmodel.FollowRequest, error) { idProp := followable.GetJSONLDId() if idProp == nil || !idProp.IsIRI() { @@ -330,7 +332,7 @@ func (c *converter) ASFollowToFollowRequest(followable ap.Followable) (*gtsmodel if err != nil { return nil, errors.New("error extracting actor property from follow") } - originAccount, err := c.db.GetAccountByURI(origin.String()) + originAccount, err := c.db.GetAccountByURI(ctx, origin.String()) if err != nil { return nil, fmt.Errorf("error extracting account with uri %s from the database: %s", origin.String(), err) } @@ -339,7 +341,7 @@ func (c *converter) ASFollowToFollowRequest(followable ap.Followable) (*gtsmodel if err != nil { return nil, errors.New("error extracting object property from follow") } - targetAccount, err := c.db.GetAccountByURI(target.String()) + targetAccount, err := c.db.GetAccountByURI(ctx, target.String()) if err != nil { return nil, fmt.Errorf("error extracting account with uri %s from the database: %s", origin.String(), err) } @@ -353,7 +355,7 @@ func (c *converter) ASFollowToFollowRequest(followable ap.Followable) (*gtsmodel return followRequest, nil } -func (c *converter) ASFollowToFollow(followable ap.Followable) (*gtsmodel.Follow, error) { +func (c *converter) ASFollowToFollow(ctx context.Context, followable ap.Followable) (*gtsmodel.Follow, error) { idProp := followable.GetJSONLDId() if idProp == nil || !idProp.IsIRI() { return nil, errors.New("no id property set on follow, or was not an iri") @@ -364,7 +366,7 @@ func (c *converter) ASFollowToFollow(followable ap.Followable) (*gtsmodel.Follow if err != nil { return nil, errors.New("error extracting actor property from follow") } - originAccount, err := c.db.GetAccountByURI(origin.String()) + originAccount, err := c.db.GetAccountByURI(ctx, origin.String()) if err != nil { return nil, fmt.Errorf("error extracting account with uri %s from the database: %s", origin.String(), err) } @@ -373,7 +375,7 @@ func (c *converter) ASFollowToFollow(followable ap.Followable) (*gtsmodel.Follow if err != nil { return nil, errors.New("error extracting object property from follow") } - targetAccount, err := c.db.GetAccountByURI(target.String()) + targetAccount, err := c.db.GetAccountByURI(ctx, target.String()) if err != nil { return nil, fmt.Errorf("error extracting account with uri %s from the database: %s", origin.String(), err) } @@ -387,7 +389,7 @@ func (c *converter) ASFollowToFollow(followable ap.Followable) (*gtsmodel.Follow return follow, nil } -func (c *converter) ASLikeToFave(likeable ap.Likeable) (*gtsmodel.StatusFave, error) { +func (c *converter) ASLikeToFave(ctx context.Context, likeable ap.Likeable) (*gtsmodel.StatusFave, error) { idProp := likeable.GetJSONLDId() if idProp == nil || !idProp.IsIRI() { return nil, errors.New("no id property set on like, or was not an iri") @@ -398,7 +400,7 @@ func (c *converter) ASLikeToFave(likeable ap.Likeable) (*gtsmodel.StatusFave, er if err != nil { return nil, errors.New("error extracting actor property from like") } - originAccount, err := c.db.GetAccountByURI(origin.String()) + originAccount, err := c.db.GetAccountByURI(ctx, origin.String()) if err != nil { return nil, fmt.Errorf("error extracting account with uri %s from the database: %s", origin.String(), err) } @@ -408,7 +410,7 @@ func (c *converter) ASLikeToFave(likeable ap.Likeable) (*gtsmodel.StatusFave, er return nil, errors.New("error extracting object property from like") } - targetStatus, err := c.db.GetStatusByURI(target.String()) + targetStatus, err := c.db.GetStatusByURI(ctx, target.String()) if err != nil { return nil, fmt.Errorf("error extracting status with uri %s from the database: %s", target.String(), err) } @@ -417,7 +419,7 @@ func (c *converter) ASLikeToFave(likeable ap.Likeable) (*gtsmodel.StatusFave, er if targetStatus.Account != nil { targetAccount = targetStatus.Account } else { - a, err := c.db.GetAccountByID(targetStatus.AccountID) + a, err := c.db.GetAccountByID(ctx, targetStatus.AccountID) if err != nil { return nil, fmt.Errorf("error extracting account with id %s from the database: %s", targetStatus.AccountID, err) } @@ -435,7 +437,7 @@ func (c *converter) ASLikeToFave(likeable ap.Likeable) (*gtsmodel.StatusFave, er }, nil } -func (c *converter) ASBlockToBlock(blockable ap.Blockable) (*gtsmodel.Block, error) { +func (c *converter) ASBlockToBlock(ctx context.Context, blockable ap.Blockable) (*gtsmodel.Block, error) { idProp := blockable.GetJSONLDId() if idProp == nil || !idProp.IsIRI() { return nil, errors.New("ASBlockToBlock: no id property set on block, or was not an iri") @@ -446,7 +448,7 @@ func (c *converter) ASBlockToBlock(blockable ap.Blockable) (*gtsmodel.Block, err if err != nil { return nil, errors.New("ASBlockToBlock: error extracting actor property from block") } - originAccount, err := c.db.GetAccountByURI(origin.String()) + originAccount, err := c.db.GetAccountByURI(ctx, origin.String()) if err != nil { return nil, fmt.Errorf("error extracting account with uri %s from the database: %s", origin.String(), err) } @@ -456,7 +458,7 @@ func (c *converter) ASBlockToBlock(blockable ap.Blockable) (*gtsmodel.Block, err return nil, errors.New("ASBlockToBlock: error extracting object property from block") } - targetAccount, err := c.db.GetAccountByURI(target.String()) + targetAccount, err := c.db.GetAccountByURI(ctx, target.String()) if err != nil { return nil, fmt.Errorf("error extracting account with uri %s from the database: %s", origin.String(), err) } @@ -470,7 +472,7 @@ func (c *converter) ASBlockToBlock(blockable ap.Blockable) (*gtsmodel.Block, err }, nil } -func (c *converter) ASAnnounceToStatus(announceable ap.Announceable) (*gtsmodel.Status, bool, error) { +func (c *converter) ASAnnounceToStatus(ctx context.Context, announceable ap.Announceable) (*gtsmodel.Status, bool, error) { status := >smodel.Status{} isNew := true @@ -481,7 +483,7 @@ func (c *converter) ASAnnounceToStatus(announceable ap.Announceable) (*gtsmodel. } uri := idProp.GetIRI().String() - if status, err := c.db.GetStatusByURI(uri); err == nil { + if status, err := c.db.GetStatusByURI(ctx, uri); err == nil { // we already have it, great, just return it as-is :) isNew = false return status, isNew, nil @@ -515,7 +517,7 @@ func (c *converter) ASAnnounceToStatus(announceable ap.Announceable) (*gtsmodel. // get the boosting account based on the URI // this should have been dereferenced already before we hit this point so we can confidently error out if we don't have it - boostingAccount, err := c.db.GetAccountByURI(actor.String()) + boostingAccount, err := c.db.GetAccountByURI(ctx, actor.String()) if err != nil { return nil, isNew, fmt.Errorf("ASAnnounceToStatus: error in db fetching account with uri %s: %s", actor.String(), err) } diff --git a/internal/typeutils/astointernal_test.go b/internal/typeutils/astointernal_test.go index 1d02dec5a..a01e79202 100644 --- a/internal/typeutils/astointernal_test.go +++ b/internal/typeutils/astointernal_test.go @@ -348,7 +348,7 @@ func (suite *ASToInternalTestSuite) SetupTest() { func (suite *ASToInternalTestSuite) TestParsePerson() { testPerson := suite.people["new_person_1"] - acct, err := suite.typeconverter.ASRepresentationToAccount(testPerson, false) + acct, err := suite.typeconverter.ASRepresentationToAccount(context.Background(), testPerson, false) assert.NoError(suite.T(), err) suite.Equal("https://unknown-instance.com/users/brand_new_person", acct.URI) @@ -379,7 +379,7 @@ func (suite *ASToInternalTestSuite) TestParseGargron() { rep, ok := t.(ap.Accountable) assert.True(suite.T(), ok) - acct, err := suite.typeconverter.ASRepresentationToAccount(rep, false) + acct, err := suite.typeconverter.ASRepresentationToAccount(context.Background(), rep, false) assert.NoError(suite.T(), err) fmt.Printf("%+v", acct) diff --git a/internal/typeutils/converter.go b/internal/typeutils/converter.go index e477a6135..4af9767bc 100644 --- a/internal/typeutils/converter.go +++ b/internal/typeutils/converter.go @@ -19,6 +19,7 @@ package typeutils import ( + "context" "net/url" "github.com/go-fed/activity/streams/vocab" @@ -47,45 +48,45 @@ type TypeConverter interface { // AccountToMastoSensitive takes a db model account as a param, and returns a populated mastotype account, or an error // if something goes wrong. The returned account should be ready to serialize on an API level, and may have sensitive fields, // so serve it only to an authorized user who should have permission to see it. - AccountToMastoSensitive(account *gtsmodel.Account) (*model.Account, error) + AccountToMastoSensitive(ctx context.Context, account *gtsmodel.Account) (*model.Account, error) // AccountToMastoPublic takes a db model account as a param, and returns a populated mastotype account, or an error // if something goes wrong. The returned account should be ready to serialize on an API level, and may NOT have sensitive fields. // In other words, this is the public record that the server has of an account. - AccountToMastoPublic(account *gtsmodel.Account) (*model.Account, error) + AccountToMastoPublic(ctx context.Context, account *gtsmodel.Account) (*model.Account, error) // AccountToMastoBlocked takes a db model account as a param, and returns a mastotype account, or an error if // something goes wrong. The returned account will be a bare minimum representation of the account. This function should be used // when someone wants to view an account they've blocked. - AccountToMastoBlocked(account *gtsmodel.Account) (*model.Account, error) + AccountToMastoBlocked(ctx context.Context, account *gtsmodel.Account) (*model.Account, error) // AppToMastoSensitive takes a db model application as a param, and returns a populated mastotype application, or an error // if something goes wrong. The returned application should be ready to serialize on an API level, and may have sensitive fields // (such as client id and client secret), so serve it only to an authorized user who should have permission to see it. - AppToMastoSensitive(application *gtsmodel.Application) (*model.Application, error) + AppToMastoSensitive(ctx context.Context, application *gtsmodel.Application) (*model.Application, error) // AppToMastoPublic takes a db model application as a param, and returns a populated mastotype application, or an error // if something goes wrong. The returned application should be ready to serialize on an API level, and has sensitive // fields sanitized so that it can be served to non-authorized accounts without revealing any private information. - AppToMastoPublic(application *gtsmodel.Application) (*model.Application, error) + AppToMastoPublic(ctx context.Context, application *gtsmodel.Application) (*model.Application, error) // AttachmentToMasto converts a gts model media attacahment into its mastodon representation for serialization on the API. - AttachmentToMasto(attachment *gtsmodel.MediaAttachment) (model.Attachment, error) + AttachmentToMasto(ctx context.Context, attachment *gtsmodel.MediaAttachment) (model.Attachment, error) // MentionToMasto converts a gts model mention into its mastodon (frontend) representation for serialization on the API. - MentionToMasto(m *gtsmodel.Mention) (model.Mention, error) + MentionToMasto(ctx context.Context, m *gtsmodel.Mention) (model.Mention, error) // EmojiToMasto converts a gts model emoji into its mastodon (frontend) representation for serialization on the API. - EmojiToMasto(e *gtsmodel.Emoji) (model.Emoji, error) + EmojiToMasto(ctx context.Context, e *gtsmodel.Emoji) (model.Emoji, error) // TagToMasto converts a gts model tag into its mastodon (frontend) representation for serialization on the API. - TagToMasto(t *gtsmodel.Tag) (model.Tag, error) + TagToMasto(ctx context.Context, t *gtsmodel.Tag) (model.Tag, error) // StatusToMasto converts a gts model status into its mastodon (frontend) representation for serialization on the API. // // Requesting account can be nil. - StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmodel.Account) (*model.Status, error) + StatusToMasto(ctx context.Context, s *gtsmodel.Status, requestingAccount *gtsmodel.Account) (*model.Status, error) // VisToMasto converts a gts visibility into its mastodon equivalent - VisToMasto(m gtsmodel.Visibility) model.Visibility + VisToMasto(ctx context.Context, m gtsmodel.Visibility) model.Visibility // InstanceToMasto converts a gts instance into its mastodon equivalent for serving at /api/v1/instance - InstanceToMasto(i *gtsmodel.Instance) (*model.Instance, error) + InstanceToMasto(ctx context.Context, i *gtsmodel.Instance) (*model.Instance, error) // RelationshipToMasto converts a gts relationship into its mastodon equivalent for serving in various places - RelationshipToMasto(r *gtsmodel.Relationship) (*model.Relationship, error) + RelationshipToMasto(ctx context.Context, r *gtsmodel.Relationship) (*model.Relationship, error) // NotificationToMasto converts a gts notification into a mastodon notification - NotificationToMasto(n *gtsmodel.Notification) (*model.Notification, error) + NotificationToMasto(ctx context.Context, n *gtsmodel.Notification) (*model.Notification, error) // DomainBlockTomasto converts a gts model domin block into a mastodon domain block, for serving at /api/v1/admin/domain_blocks - DomainBlockToMasto(b *gtsmodel.DomainBlock, export bool) (*model.DomainBlock, error) + DomainBlockToMasto(ctx context.Context, b *gtsmodel.DomainBlock, export bool) (*model.DomainBlock, error) /* FRONTEND (mastodon) MODEL TO INTERNAL (gts) MODEL @@ -103,17 +104,17 @@ type TypeConverter interface { // If update is false, and the account is already known in the database, then the existing account entry will be returned. // If update is true, then even if the account is already known, all fields in the accountable will be parsed and a new *gtsmodel.Account // will be generated. This is useful when one needs to force refresh of an account, eg., during an Update of a Profile. - ASRepresentationToAccount(accountable ap.Accountable, update bool) (*gtsmodel.Account, error) + ASRepresentationToAccount(ctx context.Context, accountable ap.Accountable, update bool) (*gtsmodel.Account, error) // ASStatus converts a remote activitystreams 'status' representation into a gts model status. - ASStatusToStatus(statusable ap.Statusable) (*gtsmodel.Status, error) + ASStatusToStatus(ctx context.Context, statusable ap.Statusable) (*gtsmodel.Status, error) // ASFollowToFollowRequest converts a remote activitystreams `follow` representation into gts model follow request. - ASFollowToFollowRequest(followable ap.Followable) (*gtsmodel.FollowRequest, error) + ASFollowToFollowRequest(ctx context.Context, followable ap.Followable) (*gtsmodel.FollowRequest, error) // ASFollowToFollowRequest converts a remote activitystreams `follow` representation into gts model follow. - ASFollowToFollow(followable ap.Followable) (*gtsmodel.Follow, error) + ASFollowToFollow(ctx context.Context, followable ap.Followable) (*gtsmodel.Follow, error) // ASLikeToFave converts a remote activitystreams 'like' representation into a gts model status fave. - ASLikeToFave(likeable ap.Likeable) (*gtsmodel.StatusFave, error) + ASLikeToFave(ctx context.Context, likeable ap.Likeable) (*gtsmodel.StatusFave, error) // ASBlockToBlock converts a remote activity streams 'block' representation into a gts model block. - ASBlockToBlock(blockable ap.Blockable) (*gtsmodel.Block, error) + ASBlockToBlock(ctx context.Context, blockable ap.Blockable) (*gtsmodel.Block, error) // ASAnnounceToStatus converts an activitystreams 'announce' into a status. // // The returned bool indicates whether this status is new (true) or not new (false). @@ -126,46 +127,46 @@ type TypeConverter interface { // This is useful when multiple users on an instance might receive the same boost, and we only want to process the boost once. // // NOTE -- this is different from one status being boosted multiple times! In this case, new boosts should indeed be created. - ASAnnounceToStatus(announceable ap.Announceable) (status *gtsmodel.Status, new bool, err error) + ASAnnounceToStatus(ctx context.Context, announceable ap.Announceable) (status *gtsmodel.Status, new bool, err error) /* INTERNAL (gts) MODEL TO ACTIVITYSTREAMS MODEL */ // AccountToAS converts a gts model account into an activity streams person, suitable for federation - AccountToAS(a *gtsmodel.Account) (vocab.ActivityStreamsPerson, error) + AccountToAS(ctx context.Context, a *gtsmodel.Account) (vocab.ActivityStreamsPerson, error) // AccountToASMinimal converts a gts model account into an activity streams person, suitable for federation. // // The returned account will just have the Type, Username, PublicKey, and ID properties set. This is // suitable for serving to requesters to whom we want to give as little information as possible because // we don't trust them (yet). - AccountToASMinimal(a *gtsmodel.Account) (vocab.ActivityStreamsPerson, error) + AccountToASMinimal(ctx context.Context, a *gtsmodel.Account) (vocab.ActivityStreamsPerson, error) // StatusToAS converts a gts model status into an activity streams note, suitable for federation - StatusToAS(s *gtsmodel.Status) (vocab.ActivityStreamsNote, error) + StatusToAS(ctx context.Context, s *gtsmodel.Status) (vocab.ActivityStreamsNote, error) // FollowToASFollow converts a gts model Follow into an activity streams Follow, suitable for federation - FollowToAS(f *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (vocab.ActivityStreamsFollow, error) + FollowToAS(ctx context.Context, f *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (vocab.ActivityStreamsFollow, error) // MentionToAS converts a gts model mention into an activity streams Mention, suitable for federation - MentionToAS(m *gtsmodel.Mention) (vocab.ActivityStreamsMention, error) + MentionToAS(ctx context.Context, m *gtsmodel.Mention) (vocab.ActivityStreamsMention, error) // AttachmentToAS converts a gts model media attachment into an activity streams Attachment, suitable for federation - AttachmentToAS(a *gtsmodel.MediaAttachment) (vocab.ActivityStreamsDocument, error) + AttachmentToAS(ctx context.Context, a *gtsmodel.MediaAttachment) (vocab.ActivityStreamsDocument, error) // FaveToAS converts a gts model status fave into an activityStreams LIKE, suitable for federation. - FaveToAS(f *gtsmodel.StatusFave) (vocab.ActivityStreamsLike, error) + FaveToAS(ctx context.Context, f *gtsmodel.StatusFave) (vocab.ActivityStreamsLike, error) // BoostToAS converts a gts model boost into an activityStreams ANNOUNCE, suitable for federation - BoostToAS(boostWrapperStatus *gtsmodel.Status, boostingAccount *gtsmodel.Account, boostedAccount *gtsmodel.Account) (vocab.ActivityStreamsAnnounce, error) + BoostToAS(ctx context.Context, boostWrapperStatus *gtsmodel.Status, boostingAccount *gtsmodel.Account, boostedAccount *gtsmodel.Account) (vocab.ActivityStreamsAnnounce, error) // BlockToAS converts a gts model block into an activityStreams BLOCK, suitable for federation. - BlockToAS(block *gtsmodel.Block) (vocab.ActivityStreamsBlock, error) + BlockToAS(ctx context.Context, block *gtsmodel.Block) (vocab.ActivityStreamsBlock, error) // StatusToASRepliesCollection converts a gts model status into an activityStreams REPLIES collection. - StatusToASRepliesCollection(status *gtsmodel.Status, onlyOtherAccounts bool) (vocab.ActivityStreamsCollection, error) + StatusToASRepliesCollection(ctx context.Context, status *gtsmodel.Status, onlyOtherAccounts bool) (vocab.ActivityStreamsCollection, error) // StatusURIsToASRepliesPage returns a collection page with appropriate next/part of pagination. - StatusURIsToASRepliesPage(status *gtsmodel.Status, onlyOtherAccounts bool, minID string, replies map[string]*url.URL) (vocab.ActivityStreamsCollectionPage, error) + StatusURIsToASRepliesPage(ctx context.Context, status *gtsmodel.Status, onlyOtherAccounts bool, minID string, replies map[string]*url.URL) (vocab.ActivityStreamsCollectionPage, error) /* INTERNAL (gts) MODEL TO INTERNAL MODEL */ // FollowRequestToFollow just converts a follow request into a follow, that's it! No bells and whistles. - FollowRequestToFollow(f *gtsmodel.FollowRequest) *gtsmodel.Follow + FollowRequestToFollow(ctx context.Context, f *gtsmodel.FollowRequest) *gtsmodel.Follow // StatusToBoost wraps the given status into a boosting status. - StatusToBoost(s *gtsmodel.Status, boostingAccount *gtsmodel.Account) (*gtsmodel.Status, error) + StatusToBoost(ctx context.Context, s *gtsmodel.Status, boostingAccount *gtsmodel.Account) (*gtsmodel.Status, error) /* WRAPPER CONVENIENCE FUNCTIONS diff --git a/internal/typeutils/internal.go b/internal/typeutils/internal.go index ad15ecbee..23839b9a8 100644 --- a/internal/typeutils/internal.go +++ b/internal/typeutils/internal.go @@ -1,6 +1,7 @@ package typeutils import ( + "context" "fmt" "time" @@ -9,7 +10,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/util" ) -func (c *converter) FollowRequestToFollow(f *gtsmodel.FollowRequest) *gtsmodel.Follow { +func (c *converter) FollowRequestToFollow(ctx context.Context, f *gtsmodel.FollowRequest) *gtsmodel.Follow { return >smodel.Follow{ ID: f.ID, CreatedAt: f.CreatedAt, @@ -22,7 +23,7 @@ func (c *converter) FollowRequestToFollow(f *gtsmodel.FollowRequest) *gtsmodel.F } } -func (c *converter) StatusToBoost(s *gtsmodel.Status, boostingAccount *gtsmodel.Account) (*gtsmodel.Status, error) { +func (c *converter) StatusToBoost(ctx context.Context, s *gtsmodel.Status, boostingAccount *gtsmodel.Account) (*gtsmodel.Status, error) { // the wrapper won't use the same ID as the boosted status so we generate some new UUIDs uris := util.GenerateURIsForAccount(boostingAccount.Username, c.config.Protocol, c.config.Host) boostWrapperStatusID, err := id.NewULID() diff --git a/internal/typeutils/internaltoas.go b/internal/typeutils/internaltoas.go index 178567dc6..14ed094c5 100644 --- a/internal/typeutils/internaltoas.go +++ b/internal/typeutils/internaltoas.go @@ -19,6 +19,7 @@ package typeutils import ( + "context" "crypto/x509" "encoding/pem" "fmt" @@ -33,7 +34,7 @@ import ( // Converts a gts model account into an Activity Streams person type, following // the spec laid out for mastodon here: https://docs.joinmastodon.org/spec/activitypub/ -func (c *converter) AccountToAS(a *gtsmodel.Account) (vocab.ActivityStreamsPerson, error) { +func (c *converter) AccountToAS(ctx context.Context, a *gtsmodel.Account) (vocab.ActivityStreamsPerson, error) { // first check if we have this person in our asCache already if personI, err := c.asCache.Fetch(a.ID); err == nil { if person, ok := personI.(vocab.ActivityStreamsPerson); ok { @@ -213,9 +214,12 @@ func (c *converter) AccountToAS(a *gtsmodel.Account) (vocab.ActivityStreamsPerso // icon // Used as profile avatar. if a.AvatarMediaAttachmentID != "" { - avatar := >smodel.MediaAttachment{} - if err := c.db.GetByID(a.AvatarMediaAttachmentID, avatar); err != nil { - return nil, err + if a.AvatarMediaAttachment == nil { + avatar := >smodel.MediaAttachment{} + if err := c.db.GetByID(ctx, a.AvatarMediaAttachmentID, avatar); err != nil { + return nil, err + } + a.AvatarMediaAttachment = avatar } iconProperty := streams.NewActivityStreamsIconProperty() @@ -223,11 +227,11 @@ func (c *converter) AccountToAS(a *gtsmodel.Account) (vocab.ActivityStreamsPerso iconImage := streams.NewActivityStreamsImage() mediaType := streams.NewActivityStreamsMediaTypeProperty() - mediaType.Set(avatar.File.ContentType) + mediaType.Set(a.AvatarMediaAttachment.File.ContentType) iconImage.SetActivityStreamsMediaType(mediaType) avatarURLProperty := streams.NewActivityStreamsUrlProperty() - avatarURL, err := url.Parse(avatar.URL) + avatarURL, err := url.Parse(a.AvatarMediaAttachment.URL) if err != nil { return nil, err } @@ -241,9 +245,12 @@ func (c *converter) AccountToAS(a *gtsmodel.Account) (vocab.ActivityStreamsPerso // image // Used as profile header. if a.HeaderMediaAttachmentID != "" { - header := >smodel.MediaAttachment{} - if err := c.db.GetByID(a.HeaderMediaAttachmentID, header); err != nil { - return nil, err + if a.HeaderMediaAttachment == nil { + header := >smodel.MediaAttachment{} + if err := c.db.GetByID(ctx, a.HeaderMediaAttachmentID, header); err != nil { + return nil, err + } + a.HeaderMediaAttachment = header } headerProperty := streams.NewActivityStreamsImageProperty() @@ -251,11 +258,11 @@ func (c *converter) AccountToAS(a *gtsmodel.Account) (vocab.ActivityStreamsPerso headerImage := streams.NewActivityStreamsImage() mediaType := streams.NewActivityStreamsMediaTypeProperty() - mediaType.Set(header.File.ContentType) + mediaType.Set(a.HeaderMediaAttachment.File.ContentType) headerImage.SetActivityStreamsMediaType(mediaType) headerURLProperty := streams.NewActivityStreamsUrlProperty() - headerURL, err := url.Parse(header.URL) + headerURL, err := url.Parse(a.HeaderMediaAttachment.URL) if err != nil { return nil, err } @@ -278,7 +285,7 @@ func (c *converter) AccountToAS(a *gtsmodel.Account) (vocab.ActivityStreamsPerso // the spec laid out for mastodon here: https://docs.joinmastodon.org/spec/activitypub/ // // The returned account will just have the Type, Username, PublicKey, and ID properties set. -func (c *converter) AccountToASMinimal(a *gtsmodel.Account) (vocab.ActivityStreamsPerson, error) { +func (c *converter) AccountToASMinimal(ctx context.Context, a *gtsmodel.Account) (vocab.ActivityStreamsPerson, error) { person := streams.NewActivityStreamsPerson() // id should be the activitypub URI of this user @@ -340,7 +347,7 @@ func (c *converter) AccountToASMinimal(a *gtsmodel.Account) (vocab.ActivityStrea return person, nil } -func (c *converter) StatusToAS(s *gtsmodel.Status) (vocab.ActivityStreamsNote, error) { +func (c *converter) StatusToAS(ctx context.Context, s *gtsmodel.Status) (vocab.ActivityStreamsNote, error) { // first check if we have this note in our asCache already if noteI, err := c.asCache.Fetch(s.ID); err == nil { if note, ok := noteI.(vocab.ActivityStreamsNote); ok { @@ -354,7 +361,7 @@ func (c *converter) StatusToAS(s *gtsmodel.Status) (vocab.ActivityStreamsNote, e // check if author account is already attached to status and attach it if not // if we can't retrieve this, bail here already because we can't attribute the status to anyone if s.Account == nil { - a, err := c.db.GetAccountByID(s.AccountID) + a, err := c.db.GetAccountByID(ctx, s.AccountID) if err != nil { return nil, fmt.Errorf("StatusToAS: error retrieving author account from db: %s", err) } @@ -386,7 +393,7 @@ func (c *converter) StatusToAS(s *gtsmodel.Status) (vocab.ActivityStreamsNote, e // fetch the replied status if we don't have it on hand already if s.InReplyTo == nil { rs := >smodel.Status{} - if err := c.db.GetByID(s.InReplyToID, rs); err != nil { + if err := c.db.GetByID(ctx, s.InReplyToID, rs); err != nil { return nil, fmt.Errorf("StatusToAS: error retrieving replied-to status from db: %s", err) } s.InReplyTo = rs @@ -432,7 +439,7 @@ func (c *converter) StatusToAS(s *gtsmodel.Status) (vocab.ActivityStreamsNote, e // tag -- mentions for _, m := range s.Mentions { - asMention, err := c.MentionToAS(m) + asMention, err := c.MentionToAS(ctx, m) if err != nil { return nil, fmt.Errorf("StatusToAS: error converting mention to AS mention: %s", err) } @@ -520,7 +527,7 @@ func (c *converter) StatusToAS(s *gtsmodel.Status) (vocab.ActivityStreamsNote, e // attachment attachmentProp := streams.NewActivityStreamsAttachmentProperty() for _, a := range s.Attachments { - doc, err := c.AttachmentToAS(a) + doc, err := c.AttachmentToAS(ctx, a) if err != nil { return nil, fmt.Errorf("StatusToAS: error converting attachment: %s", err) } @@ -529,7 +536,7 @@ func (c *converter) StatusToAS(s *gtsmodel.Status) (vocab.ActivityStreamsNote, e status.SetActivityStreamsAttachment(attachmentProp) // replies - repliesCollection, err := c.StatusToASRepliesCollection(s, false) + repliesCollection, err := c.StatusToASRepliesCollection(ctx, s, false) if err != nil { return nil, fmt.Errorf("error creating repliesCollection: %s", err) } @@ -546,7 +553,7 @@ func (c *converter) StatusToAS(s *gtsmodel.Status) (vocab.ActivityStreamsNote, e return status, nil } -func (c *converter) FollowToAS(f *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (vocab.ActivityStreamsFollow, error) { +func (c *converter) FollowToAS(ctx context.Context, f *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (vocab.ActivityStreamsFollow, error) { // parse out the various URIs we need for this // origin account (who's doing the follow) originAccountURI, err := url.Parse(originAccount.URI) @@ -592,10 +599,10 @@ func (c *converter) FollowToAS(f *gtsmodel.Follow, originAccount *gtsmodel.Accou return follow, nil } -func (c *converter) MentionToAS(m *gtsmodel.Mention) (vocab.ActivityStreamsMention, error) { +func (c *converter) MentionToAS(ctx context.Context, m *gtsmodel.Mention) (vocab.ActivityStreamsMention, error) { if m.OriginAccount == nil { a := >smodel.Account{} - if err := c.db.GetWhere([]db.Where{{Key: "target_account_id", Value: m.TargetAccountID}}, a); err != nil { + if err := c.db.GetWhere(ctx, []db.Where{{Key: "target_account_id", Value: m.TargetAccountID}}, a); err != nil { return nil, fmt.Errorf("MentionToAS: error getting target account from db: %s", err) } m.OriginAccount = a @@ -629,7 +636,7 @@ func (c *converter) MentionToAS(m *gtsmodel.Mention) (vocab.ActivityStreamsMenti return mention, nil } -func (c *converter) AttachmentToAS(a *gtsmodel.MediaAttachment) (vocab.ActivityStreamsDocument, error) { +func (c *converter) AttachmentToAS(ctx context.Context, a *gtsmodel.MediaAttachment) (vocab.ActivityStreamsDocument, error) { // type -- Document doc := streams.NewActivityStreamsDocument() @@ -674,11 +681,11 @@ func (c *converter) AttachmentToAS(a *gtsmodel.MediaAttachment) (vocab.ActivityS "type": "Like" } */ -func (c *converter) FaveToAS(f *gtsmodel.StatusFave) (vocab.ActivityStreamsLike, error) { +func (c *converter) FaveToAS(ctx context.Context, f *gtsmodel.StatusFave) (vocab.ActivityStreamsLike, error) { // check if targetStatus is already pinned to this fave, and fetch it if not if f.Status == nil { s := >smodel.Status{} - if err := c.db.GetByID(f.StatusID, s); err != nil { + if err := c.db.GetByID(ctx, f.StatusID, s); err != nil { return nil, fmt.Errorf("FaveToAS: error fetching target status from database: %s", err) } f.Status = s @@ -687,7 +694,7 @@ func (c *converter) FaveToAS(f *gtsmodel.StatusFave) (vocab.ActivityStreamsLike, // check if the targetAccount is already pinned to this fave, and fetch it if not if f.TargetAccount == nil { a := >smodel.Account{} - if err := c.db.GetByID(f.TargetAccountID, a); err != nil { + if err := c.db.GetByID(ctx, f.TargetAccountID, a); err != nil { return nil, fmt.Errorf("FaveToAS: error fetching target account from database: %s", err) } f.TargetAccount = a @@ -696,7 +703,7 @@ func (c *converter) FaveToAS(f *gtsmodel.StatusFave) (vocab.ActivityStreamsLike, // check if the faving account is already pinned to this fave, and fetch it if not if f.Account == nil { a := >smodel.Account{} - if err := c.db.GetByID(f.AccountID, a); err != nil { + if err := c.db.GetByID(ctx, f.AccountID, a); err != nil { return nil, fmt.Errorf("FaveToAS: error fetching faving account from database: %s", err) } f.Account = a @@ -744,11 +751,11 @@ func (c *converter) FaveToAS(f *gtsmodel.StatusFave) (vocab.ActivityStreamsLike, return like, nil } -func (c *converter) BoostToAS(boostWrapperStatus *gtsmodel.Status, boostingAccount *gtsmodel.Account, boostedAccount *gtsmodel.Account) (vocab.ActivityStreamsAnnounce, error) { +func (c *converter) BoostToAS(ctx context.Context, boostWrapperStatus *gtsmodel.Status, boostingAccount *gtsmodel.Account, boostedAccount *gtsmodel.Account) (vocab.ActivityStreamsAnnounce, error) { // the boosted status is probably pinned to the boostWrapperStatus but double check to make sure if boostWrapperStatus.BoostOf == nil { b := >smodel.Status{} - if err := c.db.GetByID(boostWrapperStatus.BoostOfID, b); err != nil { + if err := c.db.GetByID(ctx, boostWrapperStatus.BoostOfID, b); err != nil { return nil, fmt.Errorf("BoostToAS: error getting status with ID %s from the db: %s", boostWrapperStatus.BoostOfID, err) } boostWrapperStatus.BoostOf = b @@ -828,10 +835,10 @@ func (c *converter) BoostToAS(boostWrapperStatus *gtsmodel.Status, boostingAccou "type":"Block" } */ -func (c *converter) BlockToAS(b *gtsmodel.Block) (vocab.ActivityStreamsBlock, error) { +func (c *converter) BlockToAS(ctx context.Context, b *gtsmodel.Block) (vocab.ActivityStreamsBlock, error) { if b.Account == nil { a := >smodel.Account{} - if err := c.db.GetByID(b.AccountID, a); err != nil { + if err := c.db.GetByID(ctx, b.AccountID, a); err != nil { return nil, fmt.Errorf("BlockToAS: error getting block account from database: %s", err) } b.Account = a @@ -839,7 +846,7 @@ func (c *converter) BlockToAS(b *gtsmodel.Block) (vocab.ActivityStreamsBlock, er if b.TargetAccount == nil { a := >smodel.Account{} - if err := c.db.GetByID(b.TargetAccountID, a); err != nil { + if err := c.db.GetByID(ctx, b.TargetAccountID, a); err != nil { return nil, fmt.Errorf("BlockToAS: error getting block target account from database: %s", err) } b.TargetAccount = a @@ -903,7 +910,7 @@ func (c *converter) BlockToAS(b *gtsmodel.Block) (vocab.ActivityStreamsBlock, er } } */ -func (c *converter) StatusToASRepliesCollection(status *gtsmodel.Status, onlyOtherAccounts bool) (vocab.ActivityStreamsCollection, error) { +func (c *converter) StatusToASRepliesCollection(ctx context.Context, status *gtsmodel.Status, onlyOtherAccounts bool) (vocab.ActivityStreamsCollection, error) { collectionID := fmt.Sprintf("%s/replies", status.URI) collectionIDURI, err := url.Parse(collectionID) if err != nil { @@ -966,7 +973,7 @@ func (c *converter) StatusToASRepliesCollection(status *gtsmodel.Status, onlyOth ] } */ -func (c *converter) StatusURIsToASRepliesPage(status *gtsmodel.Status, onlyOtherAccounts bool, minID string, replies map[string]*url.URL) (vocab.ActivityStreamsCollectionPage, error) { +func (c *converter) StatusURIsToASRepliesPage(ctx context.Context, status *gtsmodel.Status, onlyOtherAccounts bool, minID string, replies map[string]*url.URL) (vocab.ActivityStreamsCollectionPage, error) { collectionID := fmt.Sprintf("%s/replies", status.URI) page := streams.NewActivityStreamsCollectionPage() diff --git a/internal/typeutils/internaltoas_test.go b/internal/typeutils/internaltoas_test.go index caa56ce0d..46f04df2f 100644 --- a/internal/typeutils/internaltoas_test.go +++ b/internal/typeutils/internaltoas_test.go @@ -19,6 +19,7 @@ package typeutils_test import ( + "context" "encoding/json" "fmt" "testing" @@ -58,7 +59,7 @@ func (suite *InternalToASTestSuite) TearDownTest() { func (suite *InternalToASTestSuite) TestAccountToAS() { testAccount := suite.accounts["local_account_1"] // take zork for this test - asPerson, err := suite.typeconverter.AccountToAS(testAccount) + asPerson, err := suite.typeconverter.AccountToAS(context.Background(), testAccount) assert.NoError(suite.T(), err) ser, err := streams.Serialize(asPerson) diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go index caa14e211..89da9eb01 100644 --- a/internal/typeutils/internaltofrontend.go +++ b/internal/typeutils/internaltofrontend.go @@ -19,6 +19,7 @@ package typeutils import ( + "context" "fmt" "strings" "time" @@ -28,9 +29,9 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (c *converter) AccountToMastoSensitive(a *gtsmodel.Account) (*model.Account, error) { +func (c *converter) AccountToMastoSensitive(ctx context.Context, a *gtsmodel.Account) (*model.Account, error) { // we can build this sensitive account easily by first getting the public account.... - mastoAccount, err := c.AccountToMastoPublic(a) + mastoAccount, err := c.AccountToMastoPublic(ctx, a) if err != nil { return nil, err } @@ -38,7 +39,7 @@ func (c *converter) AccountToMastoSensitive(a *gtsmodel.Account) (*model.Account // then adding the Source object to it... // check pending follow requests aimed at this account - frs, err := c.db.GetAccountFollowRequests(a.ID) + frs, err := c.db.GetAccountFollowRequests(ctx, a.ID) if err != nil { if err != db.ErrNoEntries { return nil, fmt.Errorf("error getting follow requests: %s", err) @@ -50,7 +51,7 @@ func (c *converter) AccountToMastoSensitive(a *gtsmodel.Account) (*model.Account } mastoAccount.Source = &model.Source{ - Privacy: c.VisToMasto(a.Privacy), + Privacy: c.VisToMasto(ctx, a.Privacy), Sensitive: a.Sensitive, Language: a.Language, Note: a.Note, @@ -61,7 +62,11 @@ func (c *converter) AccountToMastoSensitive(a *gtsmodel.Account) (*model.Account return mastoAccount, nil } -func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, error) { +func (c *converter) AccountToMastoPublic(ctx context.Context, a *gtsmodel.Account) (*model.Account, error) { + if a == nil { + return nil, fmt.Errorf("given account was nil") + } + // first check if we have this account in our frontEnd cache if accountI, err := c.frontendCache.Fetch(a.ID); err == nil { if account, ok := accountI.(*model.Account); ok { @@ -71,26 +76,26 @@ func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, e } // count followers - followersCount, err := c.db.CountAccountFollowedBy(a.ID, false) + followersCount, err := c.db.CountAccountFollowedBy(ctx, a.ID, false) if err != nil { return nil, fmt.Errorf("error counting followers: %s", err) } // count following - followingCount, err := c.db.CountAccountFollows(a.ID, false) + followingCount, err := c.db.CountAccountFollows(ctx, a.ID, false) if err != nil { return nil, fmt.Errorf("error counting following: %s", err) } // count statuses - statusesCount, err := c.db.CountAccountStatuses(a.ID) + statusesCount, err := c.db.CountAccountStatuses(ctx, a.ID) if err != nil { return nil, fmt.Errorf("error counting statuses: %s", err) } // check when the last status was var lastStatusAt string - lastPosted, err := c.db.GetAccountLastPosted(a.ID) + lastPosted, err := c.db.GetAccountLastPosted(ctx, a.ID) if err == nil && !lastPosted.IsZero() { lastStatusAt = lastPosted.Format(time.RFC3339) } @@ -101,7 +106,7 @@ func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, e if a.AvatarMediaAttachmentID != "" { // make sure avi is pinned to this account if a.AvatarMediaAttachment == nil { - avi, err := c.db.GetAttachmentByID(a.AvatarMediaAttachmentID) + avi, err := c.db.GetAttachmentByID(ctx, a.AvatarMediaAttachmentID) if err != nil { return nil, fmt.Errorf("error retrieving avatar: %s", err) } @@ -116,7 +121,7 @@ func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, e if a.HeaderMediaAttachmentID != "" { // make sure header is pinned to this account if a.HeaderMediaAttachment == nil { - avi, err := c.db.GetAttachmentByID(a.HeaderMediaAttachmentID) + avi, err := c.db.GetAttachmentByID(ctx, a.HeaderMediaAttachmentID) if err != nil { return nil, fmt.Errorf("error retrieving avatar: %s", err) } @@ -187,7 +192,7 @@ func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, e return accountFrontend, nil } -func (c *converter) AccountToMastoBlocked(a *gtsmodel.Account) (*model.Account, error) { +func (c *converter) AccountToMastoBlocked(ctx context.Context, a *gtsmodel.Account) (*model.Account, error) { var acct string if a.Domain != "" { // this is a remote user @@ -214,7 +219,7 @@ func (c *converter) AccountToMastoBlocked(a *gtsmodel.Account) (*model.Account, }, nil } -func (c *converter) AppToMastoSensitive(a *gtsmodel.Application) (*model.Application, error) { +func (c *converter) AppToMastoSensitive(ctx context.Context, a *gtsmodel.Application) (*model.Application, error) { return &model.Application{ ID: a.ID, Name: a.Name, @@ -226,14 +231,14 @@ func (c *converter) AppToMastoSensitive(a *gtsmodel.Application) (*model.Applica }, nil } -func (c *converter) AppToMastoPublic(a *gtsmodel.Application) (*model.Application, error) { +func (c *converter) AppToMastoPublic(ctx context.Context, a *gtsmodel.Application) (*model.Application, error) { return &model.Application{ Name: a.Name, Website: a.Website, }, nil } -func (c *converter) AttachmentToMasto(a *gtsmodel.MediaAttachment) (model.Attachment, error) { +func (c *converter) AttachmentToMasto(ctx context.Context, a *gtsmodel.MediaAttachment) (model.Attachment, error) { return model.Attachment{ ID: a.ID, Type: strings.ToLower(string(a.Type)), @@ -264,33 +269,36 @@ func (c *converter) AttachmentToMasto(a *gtsmodel.MediaAttachment) (model.Attach }, nil } -func (c *converter) MentionToMasto(m *gtsmodel.Mention) (model.Mention, error) { - target := >smodel.Account{} - if err := c.db.GetByID(m.TargetAccountID, target); err != nil { - return model.Mention{}, err +func (c *converter) MentionToMasto(ctx context.Context, m *gtsmodel.Mention) (model.Mention, error) { + if m.TargetAccount == nil { + targetAccount, err := c.db.GetAccountByID(ctx, m.TargetAccountID) + if err != nil { + return model.Mention{}, err + } + m.TargetAccount = targetAccount } var local bool - if target.Domain == "" { + if m.TargetAccount.Domain == "" { local = true } var acct string if local { - acct = target.Username + acct = m.TargetAccount.Username } else { - acct = fmt.Sprintf("%s@%s", target.Username, target.Domain) + acct = fmt.Sprintf("%s@%s", m.TargetAccount.Username, m.TargetAccount.Domain) } return model.Mention{ - ID: target.ID, - Username: target.Username, - URL: target.URL, + ID: m.TargetAccount.ID, + Username: m.TargetAccount.Username, + URL: m.TargetAccount.URL, Acct: acct, }, nil } -func (c *converter) EmojiToMasto(e *gtsmodel.Emoji) (model.Emoji, error) { +func (c *converter) EmojiToMasto(ctx context.Context, e *gtsmodel.Emoji) (model.Emoji, error) { return model.Emoji{ Shortcode: e.Shortcode, URL: e.ImageURL, @@ -300,27 +308,25 @@ func (c *converter) EmojiToMasto(e *gtsmodel.Emoji) (model.Emoji, error) { }, nil } -func (c *converter) TagToMasto(t *gtsmodel.Tag) (model.Tag, error) { - tagURL := fmt.Sprintf("%s://%s/tags/%s", c.config.Protocol, c.config.Host, t.Name) - +func (c *converter) TagToMasto(ctx context.Context, t *gtsmodel.Tag) (model.Tag, error) { return model.Tag{ Name: t.Name, - URL: tagURL, // we don't serve URLs with collections of tagged statuses (FOR NOW) so this is purely for mastodon compatibility ¯\_(ツ)_/¯ + URL: t.URL, }, nil } -func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmodel.Account) (*model.Status, error) { - repliesCount, err := c.db.CountStatusReplies(s) +func (c *converter) StatusToMasto(ctx context.Context, s *gtsmodel.Status, requestingAccount *gtsmodel.Account) (*model.Status, error) { + repliesCount, err := c.db.CountStatusReplies(ctx, s) if err != nil { return nil, fmt.Errorf("error counting replies: %s", err) } - reblogsCount, err := c.db.CountStatusReblogs(s) + reblogsCount, err := c.db.CountStatusReblogs(ctx, s) if err != nil { return nil, fmt.Errorf("error counting reblogs: %s", err) } - favesCount, err := c.db.CountStatusFaves(s) + favesCount, err := c.db.CountStatusFaves(ctx, s) if err != nil { return nil, fmt.Errorf("error counting faves: %s", err) } @@ -330,8 +336,8 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode // the boosted status might have been set on this struct already so check first before doing db calls if s.BoostOf == nil { // it's not set so fetch it from the db - bs := >smodel.Status{} - if err := c.db.GetByID(s.BoostOfID, bs); err != nil { + bs, err := c.db.GetStatusByID(ctx, s.BoostOfID) + if err != nil { return nil, fmt.Errorf("error getting boosted status with id %s: %s", s.BoostOfID, err) } s.BoostOf = bs @@ -340,15 +346,15 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode // the boosted account might have been set on this struct already or passed as a param so check first before doing db calls if s.BoostOfAccount == nil { // it's not set so fetch it from the db - ba := >smodel.Account{} - if err := c.db.GetByID(s.BoostOf.AccountID, ba); err != nil { + ba, err := c.db.GetAccountByID(ctx, s.BoostOf.AccountID) + if err != nil { return nil, fmt.Errorf("error getting boosted account %s from status with id %s: %s", s.BoostOf.AccountID, s.BoostOfID, err) } s.BoostOfAccount = ba s.BoostOf.Account = ba } - mastoRebloggedStatus, err = c.StatusToMasto(s.BoostOf, requestingAccount) + mastoRebloggedStatus, err = c.StatusToMasto(ctx, s.BoostOf, requestingAccount) if err != nil { return nil, fmt.Errorf("error converting boosted status to mastotype: %s", err) } @@ -357,24 +363,24 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode var mastoApplication *model.Application if s.CreatedWithApplicationID != "" { gtsApplication := >smodel.Application{} - if err := c.db.GetByID(s.CreatedWithApplicationID, gtsApplication); err != nil { + if err := c.db.GetByID(ctx, s.CreatedWithApplicationID, gtsApplication); err != nil { return nil, fmt.Errorf("error fetching application used to create status: %s", err) } - mastoApplication, err = c.AppToMastoPublic(gtsApplication) + mastoApplication, err = c.AppToMastoPublic(ctx, gtsApplication) if err != nil { return nil, fmt.Errorf("error parsing application used to create status: %s", err) } } if s.Account == nil { - a := >smodel.Account{} - if err := c.db.GetByID(s.AccountID, a); err != nil { + a, err := c.db.GetAccountByID(ctx, s.AccountID) + if err != nil { return nil, fmt.Errorf("error getting status author: %s", err) } s.Account = a } - mastoAuthorAccount, err := c.AccountToMastoPublic(s.Account) + mastoAuthorAccount, err := c.AccountToMastoPublic(ctx, s.Account) if err != nil { return nil, fmt.Errorf("error parsing account of status author: %s", err) } @@ -384,7 +390,7 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode // if so, we can directly convert the gts attachments into masto ones if s.Attachments != nil { for _, gtsAttachment := range s.Attachments { - mastoAttachment, err := c.AttachmentToMasto(gtsAttachment) + mastoAttachment, err := c.AttachmentToMasto(ctx, gtsAttachment) if err != nil { return nil, fmt.Errorf("error converting attachment with id %s: %s", gtsAttachment.ID, err) } @@ -393,14 +399,14 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode // the status doesn't have gts attachments on it, but it does have attachment IDs // in this case, we need to pull the gts attachments from the db to convert them into masto ones } else { - for _, a := range s.AttachmentIDs { - gtsAttachment := >smodel.MediaAttachment{} - if err := c.db.GetByID(a, gtsAttachment); err != nil { - return nil, fmt.Errorf("error getting attachment with id %s: %s", a, err) + for _, aID := range s.AttachmentIDs { + gtsAttachment, err := c.db.GetAttachmentByID(ctx, aID) + if err != nil { + return nil, fmt.Errorf("error getting attachment with id %s: %s", aID, err) } - mastoAttachment, err := c.AttachmentToMasto(gtsAttachment) + mastoAttachment, err := c.AttachmentToMasto(ctx, gtsAttachment) if err != nil { - return nil, fmt.Errorf("error converting attachment with id %s: %s", a, err) + return nil, fmt.Errorf("error converting attachment with id %s: %s", aID, err) } mastoAttachments = append(mastoAttachments, mastoAttachment) } @@ -411,7 +417,7 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode // if so, we can directly convert the gts mentions into masto ones if s.Mentions != nil { for _, gtsMention := range s.Mentions { - mastoMention, err := c.MentionToMasto(gtsMention) + mastoMention, err := c.MentionToMasto(ctx, gtsMention) if err != nil { return nil, fmt.Errorf("error converting mention with id %s: %s", gtsMention.ID, err) } @@ -420,12 +426,12 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode // the status doesn't have gts mentions on it, but it does have mention IDs // in this case, we need to pull the gts mentions from the db to convert them into masto ones } else { - for _, m := range s.MentionIDs { - gtsMention := >smodel.Mention{} - if err := c.db.GetByID(m, gtsMention); err != nil { - return nil, fmt.Errorf("error getting mention with id %s: %s", m, err) + for _, mID := range s.MentionIDs { + gtsMention, err := c.db.GetMention(ctx, mID) + if err != nil { + return nil, fmt.Errorf("error getting mention with id %s: %s", mID, err) } - mastoMention, err := c.MentionToMasto(gtsMention) + mastoMention, err := c.MentionToMasto(ctx, gtsMention) if err != nil { return nil, fmt.Errorf("error converting mention with id %s: %s", gtsMention.ID, err) } @@ -438,7 +444,7 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode // if so, we can directly convert the gts tags into masto ones if s.Tags != nil { for _, gtsTag := range s.Tags { - mastoTag, err := c.TagToMasto(gtsTag) + mastoTag, err := c.TagToMasto(ctx, gtsTag) if err != nil { return nil, fmt.Errorf("error converting tag with id %s: %s", gtsTag.ID, err) } @@ -449,10 +455,10 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode } else { for _, t := range s.TagIDs { gtsTag := >smodel.Tag{} - if err := c.db.GetByID(t, gtsTag); err != nil { + if err := c.db.GetByID(ctx, t, gtsTag); err != nil { return nil, fmt.Errorf("error getting tag with id %s: %s", t, err) } - mastoTag, err := c.TagToMasto(gtsTag) + mastoTag, err := c.TagToMasto(ctx, gtsTag) if err != nil { return nil, fmt.Errorf("error converting tag with id %s: %s", gtsTag.ID, err) } @@ -465,7 +471,7 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode // if so, we can directly convert the gts emojis into masto ones if s.Emojis != nil { for _, gtsEmoji := range s.Emojis { - mastoEmoji, err := c.EmojiToMasto(gtsEmoji) + mastoEmoji, err := c.EmojiToMasto(ctx, gtsEmoji) if err != nil { return nil, fmt.Errorf("error converting emoji with id %s: %s", gtsEmoji.ID, err) } @@ -476,10 +482,10 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode } else { for _, e := range s.EmojiIDs { gtsEmoji := >smodel.Emoji{} - if err := c.db.GetByID(e, gtsEmoji); err != nil { + if err := c.db.GetByID(ctx, e, gtsEmoji); err != nil { return nil, fmt.Errorf("error getting emoji with id %s: %s", e, err) } - mastoEmoji, err := c.EmojiToMasto(gtsEmoji) + mastoEmoji, err := c.EmojiToMasto(ctx, gtsEmoji) if err != nil { return nil, fmt.Errorf("error converting emoji with id %s: %s", gtsEmoji.ID, err) } @@ -491,7 +497,7 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode var mastoPoll *model.Poll statusInteractions := &statusInteractions{} - si, err := c.interactionsWithStatusForAccount(s, requestingAccount) + si, err := c.interactionsWithStatusForAccount(ctx, s, requestingAccount) if err == nil { statusInteractions = si } @@ -503,7 +509,7 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode InReplyToAccountID: s.InReplyToAccountID, Sensitive: s.Sensitive, SpoilerText: s.ContentWarning, - Visibility: c.VisToMasto(s.Visibility), + Visibility: c.VisToMasto(ctx, s.Visibility), Language: s.Language, URI: s.URI, URL: s.URL, @@ -535,7 +541,7 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode } // VisToMasto converts a gts visibility into its mastodon equivalent -func (c *converter) VisToMasto(m gtsmodel.Visibility) model.Visibility { +func (c *converter) VisToMasto(ctx context.Context, m gtsmodel.Visibility) model.Visibility { switch m { case gtsmodel.VisibilityPublic: return model.VisibilityPublic @@ -549,7 +555,7 @@ func (c *converter) VisToMasto(m gtsmodel.Visibility) model.Visibility { return "" } -func (c *converter) InstanceToMasto(i *gtsmodel.Instance) (*model.Instance, error) { +func (c *converter) InstanceToMasto(ctx context.Context, i *gtsmodel.Instance) (*model.Instance, error) { mi := &model.Instance{ URI: i.URI, Title: i.Title, @@ -567,17 +573,17 @@ func (c *converter) InstanceToMasto(i *gtsmodel.Instance) (*model.Instance, erro statusCountKey := "status_count" domainCountKey := "domain_count" - userCount, err := c.db.CountInstanceUsers(c.config.Host) + userCount, err := c.db.CountInstanceUsers(ctx, c.config.Host) if err == nil { mi.Stats[userCountKey] = userCount } - statusCount, err := c.db.CountInstanceStatuses(c.config.Host) + statusCount, err := c.db.CountInstanceStatuses(ctx, c.config.Host) if err == nil { mi.Stats[statusCountKey] = statusCount } - domainCount, err := c.db.CountInstanceDomains(c.config.Host) + domainCount, err := c.db.CountInstanceDomains(ctx, c.config.Host) if err == nil { mi.Stats[domainCountKey] = domainCount } @@ -593,7 +599,7 @@ func (c *converter) InstanceToMasto(i *gtsmodel.Instance) (*model.Instance, erro } // get the instance account if it exists and just skip if it doesn't - ia, err := c.db.GetInstanceAccount("") + ia, err := c.db.GetInstanceAccount(ctx, "") if err == nil { if ia.HeaderMediaAttachment != nil { mi.Thumbnail = ia.HeaderMediaAttachment.URL @@ -602,19 +608,22 @@ func (c *converter) InstanceToMasto(i *gtsmodel.Instance) (*model.Instance, erro // contact account is optional but let's try to get it if i.ContactAccountID != "" { - ia := >smodel.Account{} - if err := c.db.GetByID(i.ContactAccountID, ia); err == nil { - ma, err := c.AccountToMastoPublic(ia) + if i.ContactAccount == nil { + contactAccount, err := c.db.GetAccountByID(ctx, i.ContactAccountID) if err == nil { - mi.ContactAccount = ma + i.ContactAccount = contactAccount } } + ma, err := c.AccountToMastoPublic(ctx, i.ContactAccount) + if err == nil { + mi.ContactAccount = ma + } } return mi, nil } -func (c *converter) RelationshipToMasto(r *gtsmodel.Relationship) (*model.Relationship, error) { +func (c *converter) RelationshipToMasto(ctx context.Context, r *gtsmodel.Relationship) (*model.Relationship, error) { return &model.Relationship{ ID: r.ID, Following: r.Following, @@ -632,9 +641,9 @@ func (c *converter) RelationshipToMasto(r *gtsmodel.Relationship) (*model.Relati }, nil } -func (c *converter) NotificationToMasto(n *gtsmodel.Notification) (*model.Notification, error) { +func (c *converter) NotificationToMasto(ctx context.Context, n *gtsmodel.Notification) (*model.Notification, error) { if n.TargetAccount == nil { - tAccount, err := c.db.GetAccountByID(n.TargetAccountID) + tAccount, err := c.db.GetAccountByID(ctx, n.TargetAccountID) if err != nil { return nil, fmt.Errorf("NotificationToMasto: error getting target account with id %s from the db: %s", n.TargetAccountID, err) } @@ -642,14 +651,14 @@ func (c *converter) NotificationToMasto(n *gtsmodel.Notification) (*model.Notifi } if n.OriginAccount == nil { - ogAccount, err := c.db.GetAccountByID(n.OriginAccountID) + ogAccount, err := c.db.GetAccountByID(ctx, n.OriginAccountID) if err != nil { return nil, fmt.Errorf("NotificationToMasto: error getting origin account with id %s from the db: %s", n.OriginAccountID, err) } n.OriginAccount = ogAccount } - mastoAccount, err := c.AccountToMastoPublic(n.OriginAccount) + mastoAccount, err := c.AccountToMastoPublic(ctx, n.OriginAccount) if err != nil { return nil, fmt.Errorf("NotificationToMasto: error converting account to masto: %s", err) } @@ -657,7 +666,7 @@ func (c *converter) NotificationToMasto(n *gtsmodel.Notification) (*model.Notifi var mastoStatus *model.Status if n.StatusID != "" { if n.Status == nil { - status, err := c.db.GetStatusByID(n.StatusID) + status, err := c.db.GetStatusByID(ctx, n.StatusID) if err != nil { return nil, fmt.Errorf("NotificationToMasto: error getting status with id %s from the db: %s", n.StatusID, err) } @@ -673,7 +682,7 @@ func (c *converter) NotificationToMasto(n *gtsmodel.Notification) (*model.Notifi } var err error - mastoStatus, err = c.StatusToMasto(n.Status, nil) + mastoStatus, err = c.StatusToMasto(ctx, n.Status, nil) if err != nil { return nil, fmt.Errorf("NotificationToMasto: error converting status to masto: %s", err) } @@ -688,7 +697,7 @@ func (c *converter) NotificationToMasto(n *gtsmodel.Notification) (*model.Notifi }, nil } -func (c *converter) DomainBlockToMasto(b *gtsmodel.DomainBlock, export bool) (*model.DomainBlock, error) { +func (c *converter) DomainBlockToMasto(ctx context.Context, b *gtsmodel.DomainBlock, export bool) (*model.DomainBlock, error) { domainBlock := &model.DomainBlock{ Domain: b.Domain, diff --git a/internal/typeutils/util.go b/internal/typeutils/util.go index 5751fbc84..1d1903afc 100644 --- a/internal/typeutils/util.go +++ b/internal/typeutils/util.go @@ -1,34 +1,35 @@ package typeutils import ( + "context" "fmt" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (c *converter) interactionsWithStatusForAccount(s *gtsmodel.Status, requestingAccount *gtsmodel.Account) (*statusInteractions, error) { +func (c *converter) interactionsWithStatusForAccount(ctx context.Context, s *gtsmodel.Status, requestingAccount *gtsmodel.Account) (*statusInteractions, error) { si := &statusInteractions{} if requestingAccount != nil { - faved, err := c.db.IsStatusFavedBy(s, requestingAccount.ID) + faved, err := c.db.IsStatusFavedBy(ctx, s, requestingAccount.ID) if err != nil { return nil, fmt.Errorf("error checking if requesting account has faved status: %s", err) } si.Faved = faved - reblogged, err := c.db.IsStatusRebloggedBy(s, requestingAccount.ID) + reblogged, err := c.db.IsStatusRebloggedBy(ctx, s, requestingAccount.ID) if err != nil { return nil, fmt.Errorf("error checking if requesting account has reblogged status: %s", err) } si.Reblogged = reblogged - muted, err := c.db.IsStatusMutedBy(s, requestingAccount.ID) + muted, err := c.db.IsStatusMutedBy(ctx, s, requestingAccount.ID) if err != nil { return nil, fmt.Errorf("error checking if requesting account has muted status: %s", err) } si.Muted = muted - bookmarked, err := c.db.IsStatusBookmarkedBy(s, requestingAccount.ID) + bookmarked, err := c.db.IsStatusBookmarkedBy(ctx, s, requestingAccount.ID) if err != nil { return nil, fmt.Errorf("error checking if requesting account has bookmarked status: %s", err) } diff --git a/internal/visibility/filter.go b/internal/visibility/filter.go index 2c43fa4ee..644e85b35 100644 --- a/internal/visibility/filter.go +++ b/internal/visibility/filter.go @@ -19,6 +19,8 @@ package visibility import ( + "context" + "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -29,17 +31,17 @@ type Filter interface { // StatusVisible returns true if targetStatus is visible to requestingAccount, based on the // privacy settings of the status, and any blocks/mutes that might exist between the two accounts // or account domains, and other relevant accounts mentioned in or replied to by the status. - StatusVisible(targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error) + StatusVisible(ctx context.Context, targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error) // StatusHometimelineable returns true if targetStatus should be in the home timeline of the requesting account. // // This function will call StatusVisible internally, so it's not necessary to call it beforehand. - StatusHometimelineable(targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error) + StatusHometimelineable(ctx context.Context, targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error) // StatusPublictimelineable returns true if targetStatus should be in the public timeline of the requesting account. // // This function will call StatusVisible internally, so it's not necessary to call it beforehand. - StatusPublictimelineable(targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error) + StatusPublictimelineable(ctx context.Context, targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error) } type filter struct { diff --git a/internal/visibility/relevantaccounts.go b/internal/visibility/relevantaccounts.go index 5957d3111..d19d26ff4 100644 --- a/internal/visibility/relevantaccounts.go +++ b/internal/visibility/relevantaccounts.go @@ -19,6 +19,7 @@ package visibility import ( + "context" "errors" "fmt" @@ -41,7 +42,7 @@ type relevantAccounts struct { BoostedMentionedAccounts []*gtsmodel.Account } -func (f *filter) relevantAccounts(status *gtsmodel.Status, getBoosted bool) (*relevantAccounts, error) { +func (f *filter) relevantAccounts(ctx context.Context, status *gtsmodel.Status, getBoosted bool) (*relevantAccounts, error) { relAccts := &relevantAccounts{ MentionedAccounts: []*gtsmodel.Account{}, BoostedMentionedAccounts: []*gtsmodel.Account{}, @@ -77,7 +78,7 @@ func (f *filter) relevantAccounts(status *gtsmodel.Status, getBoosted bool) (*re relAccts.Account = status.Account } else { // it wasn't set, so get it from the db - account, err := f.db.GetAccountByID(status.AccountID) + account, err := f.db.GetAccountByID(ctx, status.AccountID) if err != nil { return nil, fmt.Errorf("relevantAccounts: error getting account with id %s: %s", status.AccountID, err) } @@ -96,7 +97,7 @@ func (f *filter) relevantAccounts(status *gtsmodel.Status, getBoosted bool) (*re relAccts.InReplyToAccount = status.InReplyToAccount } else { // it wasn't set, so get it from the db - inReplyToAccount, err := f.db.GetAccountByID(status.InReplyToAccountID) + inReplyToAccount, err := f.db.GetAccountByID(ctx, status.InReplyToAccountID) if err != nil { return nil, fmt.Errorf("relevantAccounts: error getting inReplyToAccount with id %s: %s", status.InReplyToAccountID, err) } @@ -115,7 +116,7 @@ func (f *filter) relevantAccounts(status *gtsmodel.Status, getBoosted bool) (*re } if !idIn(mID, status.Mentions) { // mention with ID isn't in status.Mentions - mention, err := f.db.GetMention(mID) + mention, err := f.db.GetMention(ctx, mID) if err != nil { return nil, fmt.Errorf("relevantAccounts: error getting mention with id %s: %s", mID, err) } @@ -146,7 +147,7 @@ func (f *filter) relevantAccounts(status *gtsmodel.Status, getBoosted bool) (*re // 4, 5, 6. Boosted status items // get the boosted status if it's not set on the status already if status.BoostOfID != "" && status.BoostOf == nil { - boostedStatus, err := f.db.GetStatusByID(status.BoostOfID) + boostedStatus, err := f.db.GetStatusByID(ctx, status.BoostOfID) if err != nil { return nil, fmt.Errorf("relevantAccounts: error getting boosted status with id %s: %s", status.BoostOfID, err) } @@ -155,7 +156,7 @@ func (f *filter) relevantAccounts(status *gtsmodel.Status, getBoosted bool) (*re if status.BoostOf != nil { // return relevant accounts for the boosted status - boostedRelAccts, err := f.relevantAccounts(status.BoostOf, false) // false because we don't want to recurse + boostedRelAccts, err := f.relevantAccounts(ctx, status.BoostOf, false) // false because we don't want to recurse if err != nil { return nil, fmt.Errorf("relevantAccounts: error getting relevant accounts of boosted status %s: %s", status.BoostOf.ID, err) } @@ -170,7 +171,7 @@ func (f *filter) relevantAccounts(status *gtsmodel.Status, getBoosted bool) (*re // domainBlockedRelevant checks through all relevant accounts attached to a status // to make sure none of them are domain blocked by this instance. -func (f *filter) domainBlockedRelevant(r *relevantAccounts) (bool, error) { +func (f *filter) domainBlockedRelevant(ctx context.Context, r *relevantAccounts) (bool, error) { domains := []string{} if r.Account != nil { @@ -201,7 +202,7 @@ func (f *filter) domainBlockedRelevant(r *relevantAccounts) (bool, error) { } } - return f.db.AreDomainsBlocked(domains) + return f.db.AreDomainsBlocked(ctx, domains) } func idIn(id string, mentions []*gtsmodel.Mention) bool { diff --git a/internal/visibility/statushometimelineable.go b/internal/visibility/statushometimelineable.go index a3ca62fb3..dd0ca079b 100644 --- a/internal/visibility/statushometimelineable.go +++ b/internal/visibility/statushometimelineable.go @@ -19,13 +19,14 @@ package visibility import ( + "context" "fmt" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (f *filter) StatusHometimelineable(targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error) { +func (f *filter) StatusHometimelineable(ctx context.Context, targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error) { l := f.log.WithFields(logrus.Fields{ "func": "StatusHometimelineable", "statusID": targetStatus.ID, @@ -36,7 +37,7 @@ func (f *filter) StatusHometimelineable(targetStatus *gtsmodel.Status, timelineO return true, nil } - v, err := f.StatusVisible(targetStatus, timelineOwnerAccount) + v, err := f.StatusVisible(ctx, targetStatus, timelineOwnerAccount) if err != nil { return false, fmt.Errorf("StatusHometimelineable: error checking visibility of status with id %s: %s", targetStatus.ID, err) } @@ -63,7 +64,7 @@ func (f *filter) StatusHometimelineable(targetStatus *gtsmodel.Status, timelineO if targetStatus.InReplyToID != "" { // pin the reply to status on to this status if it hasn't been done already if targetStatus.InReplyTo == nil { - rs, err := f.db.GetStatusByID(targetStatus.InReplyToID) + rs, err := f.db.GetStatusByID(ctx, targetStatus.InReplyToID) if err != nil { return false, fmt.Errorf("StatusHometimelineable: error getting replied to status with id %s: %s", targetStatus.InReplyToID, err) } @@ -72,7 +73,7 @@ func (f *filter) StatusHometimelineable(targetStatus *gtsmodel.Status, timelineO // pin the reply to account on to this status if it hasn't been done already if targetStatus.InReplyToAccount == nil { - ra, err := f.db.GetAccountByID(targetStatus.InReplyToAccountID) + ra, err := f.db.GetAccountByID(ctx, targetStatus.InReplyToAccountID) if err != nil { return false, fmt.Errorf("StatusHometimelineable: error getting replied to account with id %s: %s", targetStatus.InReplyToAccountID, err) } @@ -85,7 +86,7 @@ func (f *filter) StatusHometimelineable(targetStatus *gtsmodel.Status, timelineO } // the replied-to account != timelineOwnerAccount, so make sure the timelineOwnerAccount follows the replied-to account - follows, err := f.db.IsFollowing(timelineOwnerAccount, targetStatus.InReplyToAccount) + follows, err := f.db.IsFollowing(ctx, timelineOwnerAccount, targetStatus.InReplyToAccount) if err != nil { return false, fmt.Errorf("StatusHometimelineable: error checking follow from account %s to account %s: %s", timelineOwnerAccount.ID, targetStatus.InReplyToAccountID, err) } diff --git a/internal/visibility/statuspublictimelineable.go b/internal/visibility/statuspublictimelineable.go index f07e06aae..8d0a7aa28 100644 --- a/internal/visibility/statuspublictimelineable.go +++ b/internal/visibility/statuspublictimelineable.go @@ -19,13 +19,14 @@ package visibility import ( + "context" "fmt" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (f *filter) StatusPublictimelineable(targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error) { +func (f *filter) StatusPublictimelineable(ctx context.Context, targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error) { l := f.log.WithFields(logrus.Fields{ "func": "StatusPublictimelineable", "statusID": targetStatus.ID, @@ -41,7 +42,7 @@ func (f *filter) StatusPublictimelineable(targetStatus *gtsmodel.Status, timelin return true, nil } - v, err := f.StatusVisible(targetStatus, timelineOwnerAccount) + v, err := f.StatusVisible(ctx, targetStatus, timelineOwnerAccount) if err != nil { return false, fmt.Errorf("StatusPublictimelineable: error checking visibility of status with id %s: %s", targetStatus.ID, err) } diff --git a/internal/visibility/statusvisible.go b/internal/visibility/statusvisible.go index 15e545881..5b6fe0c1e 100644 --- a/internal/visibility/statusvisible.go +++ b/internal/visibility/statusvisible.go @@ -19,6 +19,7 @@ package visibility import ( + "context" "errors" "fmt" @@ -28,20 +29,20 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error) { +func (f *filter) StatusVisible(ctx context.Context, targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error) { l := f.log.WithFields(logrus.Fields{ "func": "StatusVisible", "statusID": targetStatus.ID, }) getBoosted := true - relevantAccounts, err := f.relevantAccounts(targetStatus, getBoosted) + relevantAccounts, err := f.relevantAccounts(ctx, targetStatus, getBoosted) if err != nil { l.Debugf("error pulling relevant accounts for status %s: %s", targetStatus.ID, err) return false, fmt.Errorf("StatusVisible: error pulling relevant accounts for status %s: %s", targetStatus.ID, err) } - domainBlocked, err := f.domainBlockedRelevant(relevantAccounts) + domainBlocked, err := f.domainBlockedRelevant(ctx, relevantAccounts) if err != nil { l.Debugf("error checking domain block: %s", err) return false, fmt.Errorf("error checking domain block: %s", err) @@ -67,7 +68,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount // note: we only do this for local users if targetAccount.Domain == "" { targetUser := >smodel.User{} - if err := f.db.GetWhere([]db.Where{{Key: "account_id", Value: targetAccount.ID}}, targetUser); err != nil { + if err := f.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: targetAccount.ID}}, targetUser); err != nil { l.Debug("target user could not be selected") if err == db.ErrNoEntries { return false, nil @@ -97,7 +98,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount // note: we only do this for local users if requestingAccount.Domain == "" { requestingUser := >smodel.User{} - if err := f.db.GetWhere([]db.Where{{Key: "account_id", Value: requestingAccount.ID}}, requestingUser); err != nil { + if err := f.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: requestingAccount.ID}}, requestingUser); err != nil { // if the requesting account is local but doesn't have a corresponding user in the db this is a problem l.Debug("requesting user could not be selected") if err == db.ErrNoEntries { @@ -126,7 +127,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount // At this point we have a populated targetAccount, targetStatus, and requestingAccount, so we can check for blocks and whathaveyou // First check if a block exists directly between the target account (which authored the status) and the requesting account. - if blocked, err := f.db.IsBlocked(targetAccount.ID, requestingAccount.ID, true); err != nil { + if blocked, err := f.db.IsBlocked(ctx, targetAccount.ID, requestingAccount.ID, true); err != nil { l.Debugf("something went wrong figuring out if the accounts have a block: %s", err) return false, err } else if blocked { @@ -137,7 +138,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount // status replies to account id if relevantAccounts.InReplyToAccount != nil && relevantAccounts.InReplyToAccount.ID != requestingAccount.ID { - if blocked, err := f.db.IsBlocked(relevantAccounts.InReplyToAccount.ID, requestingAccount.ID, true); err != nil { + if blocked, err := f.db.IsBlocked(ctx, relevantAccounts.InReplyToAccount.ID, requestingAccount.ID, true); err != nil { return false, err } else if blocked { l.Trace("a block exists between requesting account and reply to account") @@ -146,7 +147,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount // check reply to ID if targetStatus.InReplyToID != "" && (targetStatus.Visibility == gtsmodel.VisibilityFollowersOnly || targetStatus.Visibility == gtsmodel.VisibilityDirect) { - followsRepliedAccount, err := f.db.IsFollowing(requestingAccount, relevantAccounts.InReplyToAccount) + followsRepliedAccount, err := f.db.IsFollowing(ctx, requestingAccount, relevantAccounts.InReplyToAccount) if err != nil { return false, err } @@ -159,7 +160,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount // status boosts accounts id if relevantAccounts.BoostedAccount != nil { - if blocked, err := f.db.IsBlocked(relevantAccounts.BoostedAccount.ID, requestingAccount.ID, true); err != nil { + if blocked, err := f.db.IsBlocked(ctx, relevantAccounts.BoostedAccount.ID, requestingAccount.ID, true); err != nil { return false, err } else if blocked { l.Trace("a block exists between requesting account and boosted account") @@ -169,7 +170,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount // status boosts a reply to account id if relevantAccounts.BoostedInReplyToAccount != nil { - if blocked, err := f.db.IsBlocked(relevantAccounts.BoostedInReplyToAccount.ID, requestingAccount.ID, true); err != nil { + if blocked, err := f.db.IsBlocked(ctx, relevantAccounts.BoostedInReplyToAccount.ID, requestingAccount.ID, true); err != nil { return false, err } else if blocked { l.Trace("a block exists between requesting account and boosted reply to account") @@ -182,7 +183,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount if a == nil { continue } - if blocked, err := f.db.IsBlocked(a.ID, requestingAccount.ID, true); err != nil { + if blocked, err := f.db.IsBlocked(ctx, a.ID, requestingAccount.ID, true); err != nil { return false, err } else if blocked { l.Trace("a block exists between requesting account and a mentioned account") @@ -195,7 +196,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount if a == nil { continue } - if blocked, err := f.db.IsBlocked(a.ID, requestingAccount.ID, true); err != nil { + if blocked, err := f.db.IsBlocked(ctx, a.ID, requestingAccount.ID, true); err != nil { return false, err } else if blocked { l.Trace("a block exists between requesting account and a boosted mentioned account") @@ -221,7 +222,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount return true, nil case gtsmodel.VisibilityFollowersOnly: // check one-way follow - follows, err := f.db.IsFollowing(requestingAccount, targetAccount) + follows, err := f.db.IsFollowing(ctx, requestingAccount, targetAccount) if err != nil { return false, err } @@ -232,7 +233,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount return true, nil case gtsmodel.VisibilityMutualsOnly: // check mutual follow - mutuals, err := f.db.IsMutualFollowing(requestingAccount, targetAccount) + mutuals, err := f.db.IsMutualFollowing(ctx, requestingAccount, targetAccount) if err != nil { return false, err } diff --git a/internal/web/base.go b/internal/web/base.go index c0b85b613..eabde676c 100644 --- a/internal/web/base.go +++ b/internal/web/base.go @@ -50,7 +50,7 @@ func (m *Module) baseHandler(c *gin.Context) { l := m.log.WithField("func", "BaseGETHandler") l.Trace("serving index html") - instance, err := m.processor.InstanceGet(m.config.Host) + instance, err := m.processor.InstanceGet(c.Request.Context(), m.config.Host) if err != nil { l.Debugf("error getting instance from processor: %s", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"}) @@ -71,7 +71,7 @@ func (m *Module) NotFoundHandler(c *gin.Context) { l := m.log.WithField("func", "404") l.Trace("serving 404 html") - instance, err := m.processor.InstanceGet(m.config.Host) + instance, err := m.processor.InstanceGet(c.Request.Context(), m.config.Host) if err != nil { l.Debugf("error getting instance from processor: %s", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"}) |