summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2021-08-25 15:34:33 +0200
committerLibravatar GitHub <noreply@github.com>2021-08-25 15:34:33 +0200
commit2dc9fc1626507bb54417fc4a1920b847cafb27a2 (patch)
tree4ddeac479b923db38090aac8bd9209f3646851c1 /internal
parentManually approves followers (#146) (diff)
downloadgotosocial-2dc9fc1626507bb54417fc4a1920b847cafb27a2.tar.xz
Pg to bun (#148)
* start moving to bun * changing more stuff * more * and yet more * tests passing * seems stable now * more big changes * small fix * little fixes
Diffstat (limited to 'internal')
-rw-r--r--internal/api/client/account/accountcreate.go2
-rw-r--r--internal/api/client/account/accountget.go2
-rw-r--r--internal/api/client/account/accountupdate.go2
-rw-r--r--internal/api/client/account/accountupdate_test.go2
-rw-r--r--internal/api/client/account/accountverify.go2
-rw-r--r--internal/api/client/account/block.go2
-rw-r--r--internal/api/client/account/follow.go2
-rw-r--r--internal/api/client/account/followers.go2
-rw-r--r--internal/api/client/account/following.go2
-rw-r--r--internal/api/client/account/relationships.go2
-rw-r--r--internal/api/client/account/statuses.go2
-rw-r--r--internal/api/client/account/unblock.go2
-rw-r--r--internal/api/client/account/unfollow.go2
-rw-r--r--internal/api/client/admin/domainblockcreate.go4
-rw-r--r--internal/api/client/admin/domainblockdelete.go2
-rw-r--r--internal/api/client/admin/domainblockget.go2
-rw-r--r--internal/api/client/admin/domainblocksget.go2
-rw-r--r--internal/api/client/admin/emojicreate.go2
-rw-r--r--internal/api/client/app/appcreate.go2
-rw-r--r--internal/api/client/auth/auth_test.go16
-rw-r--r--internal/api/client/auth/authorize.go19
-rw-r--r--internal/api/client/auth/callback.go30
-rw-r--r--internal/api/client/auth/middleware.go8
-rw-r--r--internal/api/client/auth/signin.go7
-rw-r--r--internal/api/client/blocks/blocksget.go2
-rw-r--r--internal/api/client/favourites/favouritesget.go2
-rw-r--r--internal/api/client/fileserver/fileserver.go16
-rw-r--r--internal/api/client/fileserver/servefile.go2
-rw-r--r--internal/api/client/followrequest/accept.go2
-rw-r--r--internal/api/client/followrequest/get.go2
-rw-r--r--internal/api/client/instance/instanceget.go2
-rw-r--r--internal/api/client/instance/instancepatch.go2
-rw-r--r--internal/api/client/media/media.go17
-rw-r--r--internal/api/client/media/mediacreate.go2
-rw-r--r--internal/api/client/media/mediacreate_test.go2
-rw-r--r--internal/api/client/media/mediaget.go2
-rw-r--r--internal/api/client/media/mediaupdate.go2
-rw-r--r--internal/api/client/notification/notificationsget.go2
-rw-r--r--internal/api/client/search/searchget.go2
-rw-r--r--internal/api/client/status/statusboost.go2
-rw-r--r--internal/api/client/status/statusboost_test.go6
-rw-r--r--internal/api/client/status/statusboostedby.go2
-rw-r--r--internal/api/client/status/statuscontext.go2
-rw-r--r--internal/api/client/status/statuscreate.go2
-rw-r--r--internal/api/client/status/statuscreate_test.go20
-rw-r--r--internal/api/client/status/statusdelete.go2
-rw-r--r--internal/api/client/status/statusfave.go2
-rw-r--r--internal/api/client/status/statusfave_test.go4
-rw-r--r--internal/api/client/status/statusfavedby.go2
-rw-r--r--internal/api/client/status/statusfavedby_test.go2
-rw-r--r--internal/api/client/status/statusget.go2
-rw-r--r--internal/api/client/status/statusunboost.go2
-rw-r--r--internal/api/client/status/statusunfave.go2
-rw-r--r--internal/api/client/status/statusunfave_test.go4
-rw-r--r--internal/api/client/streaming/stream.go4
-rw-r--r--internal/api/client/timeline/home.go2
-rw-r--r--internal/api/client/timeline/public.go2
-rw-r--r--internal/api/s2s/nodeinfo/nodeinfoget.go2
-rw-r--r--internal/api/s2s/nodeinfo/wellknownget.go2
-rw-r--r--internal/api/s2s/user/userget_test.go2
-rw-r--r--internal/api/security/signaturecheck.go2
-rw-r--r--internal/cache/cache.go2
-rw-r--r--internal/cliactions/admin/account/account.go46
-rw-r--r--internal/cliactions/server/server.go14
-rw-r--r--internal/cliactions/testrig/testrig.go2
-rw-r--r--internal/db/account.go26
-rw-r--r--internal/db/admin.go11
-rw-r--r--internal/db/basic.go31
-rw-r--r--internal/db/bundb/account.go (renamed from internal/db/pg/account.go)155
-rw-r--r--internal/db/bundb/account_test.go (renamed from internal/db/pg/account_test.go)22
-rw-r--r--internal/db/bundb/admin.go (renamed from internal/db/pg/admin.go)149
-rw-r--r--internal/db/bundb/basic.go179
-rw-r--r--internal/db/bundb/basic_test.go68
-rw-r--r--internal/db/bundb/bundb.go (renamed from internal/db/pg/pg.go)150
-rw-r--r--internal/db/bundb/bundb_test.go (renamed from internal/db/pg/pg_test.go)4
-rw-r--r--internal/db/bundb/domain.go (renamed from internal/db/pg/domain.go)30
-rw-r--r--internal/db/bundb/instance.go (renamed from internal/db/pg/instance.go)66
-rw-r--r--internal/db/bundb/media.go (renamed from internal/db/pg/media.go)18
-rw-r--r--internal/db/bundb/mention.go (renamed from internal/db/pg/mention.go)22
-rw-r--r--internal/db/bundb/notification.go (renamed from internal/db/pg/notification.go)25
-rw-r--r--internal/db/bundb/relationship.go (renamed from internal/db/pg/relationship.go)172
-rw-r--r--internal/db/bundb/session.go85
-rw-r--r--internal/db/bundb/status.go375
-rw-r--r--internal/db/bundb/status_test.go (renamed from internal/db/pg/status_test.go)22
-rw-r--r--internal/db/bundb/timeline.go (renamed from internal/db/pg/timeline.go)89
-rw-r--r--internal/db/bundb/util.go78
-rw-r--r--internal/db/db.go9
-rw-r--r--internal/db/domain.go13
-rw-r--r--internal/db/instance.go14
-rw-r--r--internal/db/media.go8
-rw-r--r--internal/db/mention.go10
-rw-r--r--internal/db/notification.go10
-rw-r--r--internal/db/pg/basic.go205
-rw-r--r--internal/db/pg/status.go318
-rw-r--r--internal/db/pg/util.go25
-rw-r--r--internal/db/relationship.go30
-rw-r--r--internal/db/session.go (renamed from internal/federation/dereferencing/blocked.go)24
-rw-r--r--internal/db/status.go36
-rw-r--r--internal/db/timeline.go12
-rw-r--r--internal/federation/authenticate.go6
-rw-r--r--internal/federation/dereference.go29
-rw-r--r--internal/federation/dereferencing/account.go67
-rw-r--r--internal/federation/dereferencing/announce.go9
-rw-r--r--internal/federation/dereferencing/collectionpage.go6
-rw-r--r--internal/federation/dereferencing/dereferencer.go17
-rw-r--r--internal/federation/dereferencing/handshake.go7
-rw-r--r--internal/federation/dereferencing/instance.go6
-rw-r--r--internal/federation/dereferencing/status.go52
-rw-r--r--internal/federation/dereferencing/thread.go34
-rw-r--r--internal/federation/federatingdb/accept.go8
-rw-r--r--internal/federation/federatingdb/announce.go2
-rw-r--r--internal/federation/federatingdb/create.go16
-rw-r--r--internal/federation/federatingdb/delete.go10
-rw-r--r--internal/federation/federatingdb/followers.go22
-rw-r--r--internal/federation/federatingdb/following.go22
-rw-r--r--internal/federation/federatingdb/get.go14
-rw-r--r--internal/federation/federatingdb/outbox.go8
-rw-r--r--internal/federation/federatingdb/owns.go18
-rw-r--r--internal/federation/federatingdb/undo.go10
-rw-r--r--internal/federation/federatingdb/update.go5
-rw-r--r--internal/federation/federatingdb/util.go12
-rw-r--r--internal/federation/federatingprotocol.go16
-rw-r--r--internal/federation/federator.go18
-rw-r--r--internal/federation/finger.go6
-rw-r--r--internal/federation/handshake.go9
-rw-r--r--internal/federation/transport.go2
-rw-r--r--internal/gtsmodel/account.go60
-rw-r--r--internal/gtsmodel/application.go4
-rw-r--r--internal/gtsmodel/block.go16
-rw-r--r--internal/gtsmodel/domainblock.go14
-rw-r--r--internal/gtsmodel/emaildomainblock.go12
-rw-r--r--internal/gtsmodel/emoji.go33
-rw-r--r--internal/gtsmodel/follow.go18
-rw-r--r--internal/gtsmodel/followrequest.go18
-rw-r--r--internal/gtsmodel/instance.go22
-rw-r--r--internal/gtsmodel/mediaattachment.go20
-rw-r--r--internal/gtsmodel/mention.go26
-rw-r--r--internal/gtsmodel/messages.go32
-rw-r--r--internal/gtsmodel/notification.go18
-rw-r--r--internal/gtsmodel/routersession.go6
-rw-r--r--internal/gtsmodel/status.go72
-rw-r--r--internal/gtsmodel/statusbookmark.go14
-rw-r--r--internal/gtsmodel/statusfave.go18
-rw-r--r--internal/gtsmodel/statusmute.go16
-rw-r--r--internal/gtsmodel/tag.go16
-rw-r--r--internal/gtsmodel/user.go40
-rw-r--r--internal/media/handler.go28
-rw-r--r--internal/oauth/clientstore.go12
-rw-r--r--internal/oauth/tokenstore.go136
-rw-r--r--internal/processing/account.go46
-rw-r--r--internal/processing/account/account.go29
-rw-r--r--internal/processing/account/create.go17
-rw-r--r--internal/processing/account/createblock.go25
-rw-r--r--internal/processing/account/createfollow.go23
-rw-r--r--internal/processing/account/delete.go70
-rw-r--r--internal/processing/account/get.go18
-rw-r--r--internal/processing/account/getfollowers.go13
-rw-r--r--internal/processing/account/getfollowing.go13
-rw-r--r--internal/processing/account/getrelationship.go7
-rw-r--r--internal/processing/account/getstatuses.go11
-rw-r--r--internal/processing/account/removeblock.go11
-rw-r--r--internal/processing/account/removefollow.go19
-rw-r--r--internal/processing/account/update.go37
-rw-r--r--internal/processing/admin.go26
-rw-r--r--internal/processing/admin/admin.go13
-rw-r--r--internal/processing/admin/createdomainblock.go21
-rw-r--r--internal/processing/admin/deletedomainblock.go17
-rw-r--r--internal/processing/admin/emoji.go9
-rw-r--r--internal/processing/admin/getdomainblock.go7
-rw-r--r--internal/processing/admin/getdomainblocks.go8
-rw-r--r--internal/processing/admin/importdomainblocks.go5
-rw-r--r--internal/processing/app.go10
-rw-r--r--internal/processing/blocks.go7
-rw-r--r--internal/processing/federation.go64
-rw-r--r--internal/processing/followrequest.go50
-rw-r--r--internal/processing/fromclientapi.go168
-rw-r--r--internal/processing/fromcommon.go69
-rw-r--r--internal/processing/fromfederator.go37
-rw-r--r--internal/processing/instance.go25
-rw-r--r--internal/processing/media.go18
-rw-r--r--internal/processing/media/create.go9
-rw-r--r--internal/processing/media/delete.go22
-rw-r--r--internal/processing/media/getfile.go15
-rw-r--r--internal/processing/media/getmedia.go9
-rw-r--r--internal/processing/media/media.go12
-rw-r--r--internal/processing/media/update.go13
-rw-r--r--internal/processing/notification.go8
-rw-r--r--internal/processing/processor.go100
-rw-r--r--internal/processing/search.go45
-rw-r--r--internal/processing/status.go42
-rw-r--r--internal/processing/status/boost.go31
-rw-r--r--internal/processing/status/boostedby.go31
-rw-r--r--internal/processing/status/context.go37
-rw-r--r--internal/processing/status/create.go41
-rw-r--r--internal/processing/status/delete.go27
-rw-r--r--internal/processing/status/fave.go31
-rw-r--r--internal/processing/status/favedby.go31
-rw-r--r--internal/processing/status/get.go27
-rw-r--r--internal/processing/status/status.go56
-rw-r--r--internal/processing/status/unboost.go31
-rw-r--r--internal/processing/status/unfave.go31
-rw-r--r--internal/processing/status/util.go57
-rw-r--r--internal/processing/status/util_test.go43
-rw-r--r--internal/processing/streaming.go10
-rw-r--r--internal/processing/streaming/authorize.go8
-rw-r--r--internal/processing/streaming/openstream.go3
-rw-r--r--internal/processing/streaming/streaming.go5
-rw-r--r--internal/processing/timeline.go45
-rw-r--r--internal/router/router.go4
-rw-r--r--internal/router/session.go60
-rw-r--r--internal/text/common.go9
-rw-r--r--internal/text/common_test.go7
-rw-r--r--internal/text/formatter.go12
-rw-r--r--internal/text/link.go3
-rw-r--r--internal/text/link_test.go13
-rw-r--r--internal/text/markdown.go8
-rw-r--r--internal/text/markdown_test.go7
-rw-r--r--internal/text/plain.go9
-rw-r--r--internal/text/plain_test.go7
-rw-r--r--internal/timeline/get.go41
-rw-r--r--internal/timeline/get_test.go37
-rw-r--r--internal/timeline/index.go31
-rw-r--r--internal/timeline/index_test.go47
-rw-r--r--internal/timeline/manager.go79
-rw-r--r--internal/timeline/manager_test.go37
-rw-r--r--internal/timeline/prepare.go35
-rw-r--r--internal/timeline/remove.go5
-rw-r--r--internal/timeline/timeline.go41
-rw-r--r--internal/transport/controller.go7
-rw-r--r--internal/transport/deliver.go26
-rw-r--r--internal/transport/dereference.go22
-rw-r--r--internal/transport/derefinstance.go41
-rw-r--r--internal/transport/derefmedia.go23
-rw-r--r--internal/transport/finger.go24
-rw-r--r--internal/transport/transport.go24
-rw-r--r--internal/typeutils/astointernal.go46
-rw-r--r--internal/typeutils/astointernal_test.go4
-rw-r--r--internal/typeutils/converter.go71
-rw-r--r--internal/typeutils/internal.go5
-rw-r--r--internal/typeutils/internaltoas.go73
-rw-r--r--internal/typeutils/internaltoas_test.go3
-rw-r--r--internal/typeutils/internaltofrontend.go175
-rw-r--r--internal/typeutils/util.go11
-rw-r--r--internal/visibility/filter.go8
-rw-r--r--internal/visibility/relevantaccounts.go17
-rw-r--r--internal/visibility/statushometimelineable.go11
-rw-r--r--internal/visibility/statuspublictimelineable.go5
-rw-r--r--internal/visibility/statusvisible.go29
-rw-r--r--internal/web/base.go4
249 files changed, 3778 insertions, 2944 deletions
diff --git a/internal/api/client/account/accountcreate.go b/internal/api/client/account/accountcreate.go
index 50e72655e..a9d672f80 100644
--- a/internal/api/client/account/accountcreate.go
+++ b/internal/api/client/account/accountcreate.go
@@ -101,7 +101,7 @@ func (m *Module) AccountCreatePOSTHandler(c *gin.Context) {
form.IP = signUpIP
- ti, err := m.processor.AccountCreate(authed, form)
+ ti, err := m.processor.AccountCreate(c.Request.Context(), authed, form)
if err != nil {
l.Errorf("internal server error while creating new account: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
diff --git a/internal/api/client/account/accountget.go b/internal/api/client/account/accountget.go
index a7f9d8c70..8bac1360b 100644
--- a/internal/api/client/account/accountget.go
+++ b/internal/api/client/account/accountget.go
@@ -70,7 +70,7 @@ func (m *Module) AccountGETHandler(c *gin.Context) {
return
}
- acctInfo, err := m.processor.AccountGet(authed, targetAcctID)
+ acctInfo, err := m.processor.AccountGet(c.Request.Context(), authed, targetAcctID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
return
diff --git a/internal/api/client/account/accountupdate.go b/internal/api/client/account/accountupdate.go
index f55f45f59..282d172ed 100644
--- a/internal/api/client/account/accountupdate.go
+++ b/internal/api/client/account/accountupdate.go
@@ -122,7 +122,7 @@ func (m *Module) AccountUpdateCredentialsPATCHHandler(c *gin.Context) {
return
}
- acctSensitive, err := m.processor.AccountUpdate(authed, form)
+ acctSensitive, err := m.processor.AccountUpdate(c.Request.Context(), authed, form)
if err != nil {
l.Debugf("could not update account: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
diff --git a/internal/api/client/account/accountupdate_test.go b/internal/api/client/account/accountupdate_test.go
index 349429625..8fc31171b 100644
--- a/internal/api/client/account/accountupdate_test.go
+++ b/internal/api/client/account/accountupdate_test.go
@@ -79,7 +79,7 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandler()
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"])
- ctx.Set(oauth.SessionAuthorizedToken, oauth.TokenToOauthToken(suite.testTokens["local_account_1"]))
+ ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"]))
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"])
ctx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:8080/%s", account.UpdateCredentialsPath), bytes.NewReader(requestBody.Bytes())) // the endpoint we're hitting
diff --git a/internal/api/client/account/accountverify.go b/internal/api/client/account/accountverify.go
index 4c77f3fa6..c5c40a03d 100644
--- a/internal/api/client/account/accountverify.go
+++ b/internal/api/client/account/accountverify.go
@@ -59,7 +59,7 @@ func (m *Module) AccountVerifyGETHandler(c *gin.Context) {
return
}
- acctSensitive, err := m.processor.AccountGet(authed, authed.Account.ID)
+ acctSensitive, err := m.processor.AccountGet(c.Request.Context(), authed, authed.Account.ID)
if err != nil {
l.Debugf("error getting account from processor: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"})
diff --git a/internal/api/client/account/block.go b/internal/api/client/account/block.go
index 0d9d6c51b..243f90c5e 100644
--- a/internal/api/client/account/block.go
+++ b/internal/api/client/account/block.go
@@ -72,7 +72,7 @@ func (m *Module) AccountBlockPOSTHandler(c *gin.Context) {
return
}
- relationship, errWithCode := m.processor.AccountBlockCreate(authed, targetAcctID)
+ relationship, errWithCode := m.processor.AccountBlockCreate(c.Request.Context(), authed, targetAcctID)
if errWithCode != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return
diff --git a/internal/api/client/account/follow.go b/internal/api/client/account/follow.go
index 985a5f821..8f3f7ad0a 100644
--- a/internal/api/client/account/follow.go
+++ b/internal/api/client/account/follow.go
@@ -99,7 +99,7 @@ func (m *Module) AccountFollowPOSTHandler(c *gin.Context) {
}
form.ID = targetAcctID
- relationship, errWithCode := m.processor.AccountFollowCreate(authed, form)
+ relationship, errWithCode := m.processor.AccountFollowCreate(c.Request.Context(), authed, form)
if errWithCode != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return
diff --git a/internal/api/client/account/followers.go b/internal/api/client/account/followers.go
index 7e93544b8..4f30e9939 100644
--- a/internal/api/client/account/followers.go
+++ b/internal/api/client/account/followers.go
@@ -74,7 +74,7 @@ func (m *Module) AccountFollowersGETHandler(c *gin.Context) {
return
}
- followers, errWithCode := m.processor.AccountFollowersGet(authed, targetAcctID)
+ followers, errWithCode := m.processor.AccountFollowersGet(c.Request.Context(), authed, targetAcctID)
if errWithCode != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return
diff --git a/internal/api/client/account/following.go b/internal/api/client/account/following.go
index e70265eb5..baac2c9d3 100644
--- a/internal/api/client/account/following.go
+++ b/internal/api/client/account/following.go
@@ -74,7 +74,7 @@ func (m *Module) AccountFollowingGETHandler(c *gin.Context) {
return
}
- following, errWithCode := m.processor.AccountFollowingGet(authed, targetAcctID)
+ following, errWithCode := m.processor.AccountFollowingGet(c.Request.Context(), authed, targetAcctID)
if errWithCode != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return
diff --git a/internal/api/client/account/relationships.go b/internal/api/client/account/relationships.go
index 9dbc8c4bb..d350209af 100644
--- a/internal/api/client/account/relationships.go
+++ b/internal/api/client/account/relationships.go
@@ -71,7 +71,7 @@ func (m *Module) AccountRelationshipsGETHandler(c *gin.Context) {
relationships := []model.Relationship{}
for _, targetAccountID := range targetAccountIDs {
- r, errWithCode := m.processor.AccountRelationshipGet(authed, targetAccountID)
+ r, errWithCode := m.processor.AccountRelationshipGet(c.Request.Context(), authed, targetAccountID)
if err != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return
diff --git a/internal/api/client/account/statuses.go b/internal/api/client/account/statuses.go
index 097ccc3cc..4841d86df 100644
--- a/internal/api/client/account/statuses.go
+++ b/internal/api/client/account/statuses.go
@@ -166,7 +166,7 @@ func (m *Module) AccountStatusesGETHandler(c *gin.Context) {
mediaOnly = i
}
- statuses, errWithCode := m.processor.AccountStatusesGet(authed, targetAcctID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
+ statuses, errWithCode := m.processor.AccountStatusesGet(c.Request.Context(), authed, targetAcctID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
if errWithCode != nil {
l.Debugf("error from processor account statuses get: %s", errWithCode)
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
diff --git a/internal/api/client/account/unblock.go b/internal/api/client/account/unblock.go
index d9a2f2881..7b16ac887 100644
--- a/internal/api/client/account/unblock.go
+++ b/internal/api/client/account/unblock.go
@@ -72,7 +72,7 @@ func (m *Module) AccountUnblockPOSTHandler(c *gin.Context) {
return
}
- relationship, errWithCode := m.processor.AccountBlockRemove(authed, targetAcctID)
+ relationship, errWithCode := m.processor.AccountBlockRemove(c.Request.Context(), authed, targetAcctID)
if errWithCode != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return
diff --git a/internal/api/client/account/unfollow.go b/internal/api/client/account/unfollow.go
index 84a558c65..7ac9697d5 100644
--- a/internal/api/client/account/unfollow.go
+++ b/internal/api/client/account/unfollow.go
@@ -75,7 +75,7 @@ func (m *Module) AccountUnfollowPOSTHandler(c *gin.Context) {
return
}
- relationship, errWithCode := m.processor.AccountFollowRemove(authed, targetAcctID)
+ relationship, errWithCode := m.processor.AccountFollowRemove(c.Request.Context(), authed, targetAcctID)
if errWithCode != nil {
l.Debug(errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
diff --git a/internal/api/client/admin/domainblockcreate.go b/internal/api/client/admin/domainblockcreate.go
index d48c70408..9ef4c6f92 100644
--- a/internal/api/client/admin/domainblockcreate.go
+++ b/internal/api/client/admin/domainblockcreate.go
@@ -141,7 +141,7 @@ func (m *Module) DomainBlocksPOSTHandler(c *gin.Context) {
if imp {
// we're importing multiple blocks
- domainBlocks, err := m.processor.AdminDomainBlocksImport(authed, form)
+ domainBlocks, err := m.processor.AdminDomainBlocksImport(c.Request.Context(), authed, form)
if err != nil {
l.Debugf("error importing domain blocks: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -150,7 +150,7 @@ func (m *Module) DomainBlocksPOSTHandler(c *gin.Context) {
c.JSON(http.StatusOK, domainBlocks)
} else {
// we're just creating one block
- domainBlock, err := m.processor.AdminDomainBlockCreate(authed, form)
+ domainBlock, err := m.processor.AdminDomainBlockCreate(c.Request.Context(), authed, form)
if err != nil {
l.Debugf("error creating domain block: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
diff --git a/internal/api/client/admin/domainblockdelete.go b/internal/api/client/admin/domainblockdelete.go
index 9cd2fd711..64e4ef6de 100644
--- a/internal/api/client/admin/domainblockdelete.go
+++ b/internal/api/client/admin/domainblockdelete.go
@@ -68,7 +68,7 @@ func (m *Module) DomainBlockDELETEHandler(c *gin.Context) {
return
}
- domainBlock, errWithCode := m.processor.AdminDomainBlockDelete(authed, domainBlockID)
+ domainBlock, errWithCode := m.processor.AdminDomainBlockDelete(c.Request.Context(), authed, domainBlockID)
if errWithCode != nil {
l.Debugf("error deleting domain block: %s", errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
diff --git a/internal/api/client/admin/domainblockget.go b/internal/api/client/admin/domainblockget.go
index 86923d705..d23f99a8c 100644
--- a/internal/api/client/admin/domainblockget.go
+++ b/internal/api/client/admin/domainblockget.go
@@ -81,7 +81,7 @@ func (m *Module) DomainBlockGETHandler(c *gin.Context) {
export = i
}
- domainBlock, err := m.processor.AdminDomainBlockGet(authed, domainBlockID, export)
+ domainBlock, err := m.processor.AdminDomainBlockGet(c.Request.Context(), authed, domainBlockID, export)
if err != nil {
l.Debugf("error getting domain block: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
diff --git a/internal/api/client/admin/domainblocksget.go b/internal/api/client/admin/domainblocksget.go
index 77a287387..dad8250e0 100644
--- a/internal/api/client/admin/domainblocksget.go
+++ b/internal/api/client/admin/domainblocksget.go
@@ -81,7 +81,7 @@ func (m *Module) DomainBlocksGETHandler(c *gin.Context) {
export = i
}
- domainBlocks, err := m.processor.AdminDomainBlocksGet(authed, export)
+ domainBlocks, err := m.processor.AdminDomainBlocksGet(c.Request.Context(), authed, export)
if err != nil {
l.Debugf("error getting domain blocks: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
diff --git a/internal/api/client/admin/emojicreate.go b/internal/api/client/admin/emojicreate.go
index bfdf28249..859933b16 100644
--- a/internal/api/client/admin/emojicreate.go
+++ b/internal/api/client/admin/emojicreate.go
@@ -111,7 +111,7 @@ func (m *Module) emojiCreatePOSTHandler(c *gin.Context) {
return
}
- mastoEmoji, err := m.processor.AdminEmojiCreate(authed, form)
+ mastoEmoji, err := m.processor.AdminEmojiCreate(c.Request.Context(), authed, form)
if err != nil {
l.Debugf("error creating emoji: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
diff --git a/internal/api/client/app/appcreate.go b/internal/api/client/app/appcreate.go
index 31072f9c8..d233841b0 100644
--- a/internal/api/client/app/appcreate.go
+++ b/internal/api/client/app/appcreate.go
@@ -101,7 +101,7 @@ func (m *Module) AppsPOSTHandler(c *gin.Context) {
return
}
- mastoApp, err := m.processor.AppCreate(authed, form)
+ mastoApp, err := m.processor.AppCreate(c.Request.Context(), authed, form)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
diff --git a/internal/api/client/auth/auth_test.go b/internal/api/client/auth/auth_test.go
index 48d2a2508..3d5170f31 100644
--- a/internal/api/client/auth/auth_test.go
+++ b/internal/api/client/auth/auth_test.go
@@ -28,7 +28,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/db/pg"
+ "github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"golang.org/x/crypto/bcrypt"
@@ -104,7 +104,7 @@ func (suite *AuthTestSuite) SetupTest() {
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
- db, err := pg.NewPostgresService(context.Background(), suite.config, log)
+ db, err := bundb.NewBunDBService(context.Background(), suite.config, log)
if err != nil {
logrus.Panicf("error creating database connection: %s", err)
}
@@ -120,23 +120,23 @@ func (suite *AuthTestSuite) SetupTest() {
}
for _, m := range models {
- if err := suite.db.CreateTable(m); err != nil {
+ if err := suite.db.CreateTable(context.Background(), m); err != nil {
logrus.Panicf("db connection error: %s", err)
}
}
suite.oauthServer = oauth.New(suite.db, log)
- if err := suite.db.Put(suite.testAccount); err != nil {
+ if err := suite.db.Put(context.Background(), suite.testAccount); err != nil {
logrus.Panicf("could not insert test account into db: %s", err)
}
- if err := suite.db.Put(suite.testUser); err != nil {
+ if err := suite.db.Put(context.Background(), suite.testUser); err != nil {
logrus.Panicf("could not insert test user into db: %s", err)
}
- if err := suite.db.Put(suite.testClient); err != nil {
+ if err := suite.db.Put(context.Background(), suite.testClient); err != nil {
logrus.Panicf("could not insert test client into db: %s", err)
}
- if err := suite.db.Put(suite.testApplication); err != nil {
+ if err := suite.db.Put(context.Background(), suite.testApplication); err != nil {
logrus.Panicf("could not insert test application into db: %s", err)
}
@@ -152,7 +152,7 @@ func (suite *AuthTestSuite) TearDownTest() {
&gtsmodel.Application{},
}
for _, m := range models {
- if err := suite.db.DropTable(m); err != nil {
+ if err := suite.db.DropTable(context.Background(), m); err != nil {
logrus.Panicf("error dropping table: %s", err)
}
}
diff --git a/internal/api/client/auth/authorize.go b/internal/api/client/auth/authorize.go
index a10408723..0328f3b21 100644
--- a/internal/api/client/auth/authorize.go
+++ b/internal/api/client/auth/authorize.go
@@ -70,30 +70,23 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"})
return
}
- app := &gtsmodel.Application{
- ClientID: clientID,
- }
- if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
+ app := &gtsmodel.Application{}
+ if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)})
return
}
// we can also use the userid of the user to fetch their username from the db to greet them nicely <3
- user := &gtsmodel.User{
- ID: userID,
- }
- if err := m.db.GetByID(user.ID, user); err != nil {
+ user := &gtsmodel.User{}
+ if err := m.db.GetByID(c.Request.Context(), user.ID, user); err != nil {
m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
- acct := &gtsmodel.Account{
- ID: user.AccountID,
- }
-
- if err := m.db.GetByID(acct.ID, acct); err != nil {
+ acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
+ if err != nil {
m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go
index a26838aa3..cbb429352 100644
--- a/internal/api/client/auth/callback.go
+++ b/internal/api/client/auth/callback.go
@@ -19,6 +19,7 @@
package auth
import (
+ "context"
"errors"
"fmt"
"net"
@@ -80,13 +81,13 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
app := &gtsmodel.Application{
ClientID: clientID,
}
- if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
+ if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)})
return
}
- user, err := m.parseUserFromClaims(claims, net.IP(c.ClientIP()), app.ID)
+ user, err := m.parseUserFromClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID)
if err != nil {
m.clearSession(s)
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
@@ -103,14 +104,14 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
c.Redirect(http.StatusFound, OauthAuthorizePath)
}
-func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) {
+func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) {
if claims.Email == "" {
return nil, errors.New("no email returned in claims")
}
// see if we already have a user for this email address
user := &gtsmodel.User{}
- err := m.db.GetWhere([]db.Where{{Key: "email", Value: claims.Email}}, user)
+ err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user)
if err == nil {
// we do! so we can just return it
return user, nil
@@ -122,7 +123,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
}
// maybe we have an unconfirmed user
- err = m.db.GetWhere([]db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user)
+ err = m.db.GetWhere(ctx, []db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user)
if err == nil {
// user is unconfirmed so return an error
return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email)
@@ -137,9 +138,13 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
// however, because we trust the OIDC provider, we should now create a user + account with the provided claims
// check if the email address is available for use; if it's not there's nothing we can so
- if err := m.db.IsEmailAvailable(claims.Email); err != nil {
+ emailAvailable, err := m.db.IsEmailAvailable(ctx, claims.Email)
+ if err != nil {
return nil, fmt.Errorf("email %s not available: %s", claims.Email, err)
}
+ if !emailAvailable {
+ return nil, fmt.Errorf("email %s in use", claims.Email)
+ }
// now we need a username
var username string
@@ -180,12 +185,11 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
// note that for the first iteration, iString is still "" when the check is made, so our first choice
// is still the raw username with no integer stuck on the end
for i := 1; !found; i = i + 1 {
- if err := m.db.IsUsernameAvailable(username + iString); err != nil {
- if strings.Contains(err.Error(), "db error") {
- // if there's an actual db error we should return
- return nil, fmt.Errorf("error checking username availability: %s", err)
- }
- } else {
+ usernameAvailable, err := m.db.IsUsernameAvailable(ctx, username+iString)
+ if err != nil {
+ return nil, err
+ }
+ if usernameAvailable {
// no error so we've found a username that works
found = true
username = username + iString
@@ -209,7 +213,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
password := uuid.NewString() + uuid.NewString()
// create the user! this will also create an account and store it in the database so we don't need to do that here
- user, err = m.db.NewSignup(username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin)
+ user, err = m.db.NewSignup(ctx, username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin)
if err != nil {
return nil, fmt.Errorf("error creating user: %s", err)
}
diff --git a/internal/api/client/auth/middleware.go b/internal/api/client/auth/middleware.go
index a734b2ceb..3599c7048 100644
--- a/internal/api/client/auth/middleware.go
+++ b/internal/api/client/auth/middleware.go
@@ -49,15 +49,15 @@ func (m *Module) OauthTokenMiddleware(c *gin.Context) {
// fetch user's and account for this user id
user := &gtsmodel.User{}
- if err := m.db.GetByID(uid, user); err != nil || user == nil {
+ if err := m.db.GetByID(c.Request.Context(), uid, user); err != nil || user == nil {
l.Warnf("no user found for validated uid %s", uid)
return
}
c.Set(oauth.SessionAuthorizedUser, user)
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user)
- acct := &gtsmodel.Account{}
- if err := m.db.GetByID(user.AccountID, acct); err != nil || acct == nil {
+ acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
+ if err != nil || acct == nil {
l.Warnf("no account found for validated user %s", uid)
return
}
@@ -69,7 +69,7 @@ func (m *Module) OauthTokenMiddleware(c *gin.Context) {
if cid := ti.GetClientID(); cid != "" {
l.Tracef("authenticated client %s with bearer token, scope is %s", cid, ti.GetScope())
app := &gtsmodel.Application{}
- if err := m.db.GetWhere([]db.Where{{Key: "client_id", Value: cid}}, app); err != nil {
+ if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: "client_id", Value: cid}}, app); err != nil {
l.Tracef("no app found for client %s", cid)
}
c.Set(oauth.SessionAuthorizedApplication, app)
diff --git a/internal/api/client/auth/signin.go b/internal/api/client/auth/signin.go
index 543505cbd..6b8bb93db 100644
--- a/internal/api/client/auth/signin.go
+++ b/internal/api/client/auth/signin.go
@@ -19,6 +19,7 @@
package auth
import (
+ "context"
"errors"
"net/http"
@@ -74,7 +75,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) {
}
l.Tracef("parsed form: %+v", form)
- userid, err := m.ValidatePassword(form.Email, form.Password)
+ userid, err := m.ValidatePassword(c.Request.Context(), form.Email, form.Password)
if err != nil {
c.String(http.StatusForbidden, err.Error())
m.clearSession(s)
@@ -96,7 +97,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) {
// The goal is to authenticate the password against the one for that email
// address stored in the database. If OK, we return the userid (a ulid) for that user,
// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
-func (m *Module) ValidatePassword(email string, password string) (userid string, err error) {
+func (m *Module) ValidatePassword(ctx context.Context, email string, password string) (userid string, err error) {
l := m.log.WithField("func", "ValidatePassword")
// make sure an email/password was provided and bail if not
@@ -108,7 +109,7 @@ func (m *Module) ValidatePassword(email string, password string) (userid string,
// first we select the user from the database based on email address, bail if no user found for that email
gtsUser := &gtsmodel.User{}
- if err := m.db.GetWhere([]db.Where{{Key: "email", Value: email}}, gtsUser); err != nil {
+ if err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: email}}, gtsUser); err != nil {
l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
return incorrectPassword()
}
diff --git a/internal/api/client/blocks/blocksget.go b/internal/api/client/blocks/blocksget.go
index 65c11ea1a..b6c9c39e1 100644
--- a/internal/api/client/blocks/blocksget.go
+++ b/internal/api/client/blocks/blocksget.go
@@ -117,7 +117,7 @@ func (m *Module) BlocksGETHandler(c *gin.Context) {
limit = int(i)
}
- resp, errWithCode := m.processor.BlocksGet(authed, maxID, sinceID, limit)
+ resp, errWithCode := m.processor.BlocksGet(c.Request.Context(), authed, maxID, sinceID, limit)
if errWithCode != nil {
l.Debugf("error from processor BlocksGet: %s", errWithCode)
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
diff --git a/internal/api/client/favourites/favouritesget.go b/internal/api/client/favourites/favouritesget.go
index 76eb921e0..6984ea754 100644
--- a/internal/api/client/favourites/favouritesget.go
+++ b/internal/api/client/favourites/favouritesget.go
@@ -43,7 +43,7 @@ func (m *Module) FavouritesGETHandler(c *gin.Context) {
limit = int(i)
}
- resp, errWithCode := m.processor.FavedTimelineGet(authed, maxID, minID, limit)
+ resp, errWithCode := m.processor.FavedTimelineGet(c.Request.Context(), authed, maxID, minID, limit)
if errWithCode != nil {
l.Debugf("error from processor FavedTimelineGet: %s", errWithCode)
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
diff --git a/internal/api/client/fileserver/fileserver.go b/internal/api/client/fileserver/fileserver.go
index 08e6abb62..61286c17a 100644
--- a/internal/api/client/fileserver/fileserver.go
+++ b/internal/api/client/fileserver/fileserver.go
@@ -25,8 +25,6 @@ import (
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
)
@@ -66,17 +64,3 @@ func (m *FileServer) Route(s router.Router) error {
s.AttachHandler(http.MethodGet, fmt.Sprintf("%s/:%s/:%s/:%s/:%s", m.storageBase, AccountIDKey, MediaTypeKey, MediaSizeKey, FileNameKey), m.ServeFile)
return nil
}
-
-// CreateTables populates necessary tables in the given DB
-func (m *FileServer) CreateTables(db db.DB) error {
- models := []interface{}{
- &gtsmodel.MediaAttachment{},
- }
-
- for _, m := range models {
- if err := db.CreateTable(m); err != nil {
- return fmt.Errorf("error creating table: %s", err)
- }
- }
- return nil
-}
diff --git a/internal/api/client/fileserver/servefile.go b/internal/api/client/fileserver/servefile.go
index 1339fbac3..130a16c4f 100644
--- a/internal/api/client/fileserver/servefile.go
+++ b/internal/api/client/fileserver/servefile.go
@@ -78,7 +78,7 @@ func (m *FileServer) ServeFile(c *gin.Context) {
return
}
- content, err := m.processor.FileGet(authed, &model.GetContentRequestForm{
+ content, err := m.processor.FileGet(c.Request.Context(), authed, &model.GetContentRequestForm{
AccountID: accountID,
MediaType: mediaType,
MediaSize: mediaSize,
diff --git a/internal/api/client/followrequest/accept.go b/internal/api/client/followrequest/accept.go
index bb2910c8f..3dba7673f 100644
--- a/internal/api/client/followrequest/accept.go
+++ b/internal/api/client/followrequest/accept.go
@@ -48,7 +48,7 @@ func (m *Module) FollowRequestAcceptPOSTHandler(c *gin.Context) {
return
}
- r, errWithCode := m.processor.FollowRequestAccept(authed, originAccountID)
+ r, errWithCode := m.processor.FollowRequestAccept(c.Request.Context(), authed, originAccountID)
if errWithCode != nil {
l.Debug(errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
diff --git a/internal/api/client/followrequest/get.go b/internal/api/client/followrequest/get.go
index 3f02ee02a..47e1d74ba 100644
--- a/internal/api/client/followrequest/get.go
+++ b/internal/api/client/followrequest/get.go
@@ -41,7 +41,7 @@ func (m *Module) FollowRequestGETHandler(c *gin.Context) {
return
}
- accts, errWithCode := m.processor.FollowRequestsGet(authed)
+ accts, errWithCode := m.processor.FollowRequestsGet(c.Request.Context(), authed)
if errWithCode != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return
diff --git a/internal/api/client/instance/instanceget.go b/internal/api/client/instance/instanceget.go
index 0d53edadb..0a6d17153 100644
--- a/internal/api/client/instance/instanceget.go
+++ b/internal/api/client/instance/instanceget.go
@@ -31,7 +31,7 @@ import (
func (m *Module) InstanceInformationGETHandler(c *gin.Context) {
l := m.log.WithField("func", "InstanceInformationGETHandler")
- instance, err := m.processor.InstanceGet(m.config.Host)
+ instance, err := m.processor.InstanceGet(c.Request.Context(), m.config.Host)
if err != nil {
l.Debugf("error getting instance from processor: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"})
diff --git a/internal/api/client/instance/instancepatch.go b/internal/api/client/instance/instancepatch.go
index 3620f6044..fa37ccd8e 100644
--- a/internal/api/client/instance/instancepatch.go
+++ b/internal/api/client/instance/instancepatch.go
@@ -116,7 +116,7 @@ func (m *Module) InstanceUpdatePATCHHandler(c *gin.Context) {
return
}
- i, errWithCode := m.processor.InstancePatch(form)
+ i, errWithCode := m.processor.InstancePatch(c.Request.Context(), form)
if errWithCode != nil {
l.Debugf("error with instance patch request: %s", errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
diff --git a/internal/api/client/media/media.go b/internal/api/client/media/media.go
index 05058e2e2..1e9e8fdaa 100644
--- a/internal/api/client/media/media.go
+++ b/internal/api/client/media/media.go
@@ -19,14 +19,11 @@
package media
import (
- "fmt"
"net/http"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
)
@@ -63,17 +60,3 @@ func (m *Module) Route(s router.Router) error {
s.AttachHandler(http.MethodPut, BasePathWithID, m.MediaPUTHandler)
return nil
}
-
-// CreateTables populates necessary tables in the given DB
-func (m *Module) CreateTables(db db.DB) error {
- models := []interface{}{
- &gtsmodel.MediaAttachment{},
- }
-
- for _, m := range models {
- if err := db.CreateTable(m); err != nil {
- return fmt.Errorf("error creating table: %s", err)
- }
- }
- return nil
-}
diff --git a/internal/api/client/media/mediacreate.go b/internal/api/client/media/mediacreate.go
index f41d4568f..58d076ea6 100644
--- a/internal/api/client/media/mediacreate.go
+++ b/internal/api/client/media/mediacreate.go
@@ -108,7 +108,7 @@ func (m *Module) MediaCreatePOSTHandler(c *gin.Context) {
}
l.Debug("calling processor media create func")
- mastoAttachment, err := m.processor.MediaCreate(authed, form)
+ mastoAttachment, err := m.processor.MediaCreate(c.Request.Context(), authed, form)
if err != nil {
l.Debugf("error creating attachment: %s", err)
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
diff --git a/internal/api/client/media/mediacreate_test.go b/internal/api/client/media/mediacreate_test.go
index 5c48a4381..8433786e4 100644
--- a/internal/api/client/media/mediacreate_test.go
+++ b/internal/api/client/media/mediacreate_test.go
@@ -121,7 +121,7 @@ func (suite *MediaCreateTestSuite) TestStatusCreatePOSTImageHandlerSuccessful()
// set up the context for the request
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
diff --git a/internal/api/client/media/mediaget.go b/internal/api/client/media/mediaget.go
index 17c5a090b..5fd7856e9 100644
--- a/internal/api/client/media/mediaget.go
+++ b/internal/api/client/media/mediaget.go
@@ -75,7 +75,7 @@ func (m *Module) MediaGETHandler(c *gin.Context) {
return
}
- attachment, errWithCode := m.processor.MediaGet(authed, attachmentID)
+ attachment, errWithCode := m.processor.MediaGet(c.Request.Context(), authed, attachmentID)
if errWithCode != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return
diff --git a/internal/api/client/media/mediaupdate.go b/internal/api/client/media/mediaupdate.go
index 0ceb01f82..3af19297f 100644
--- a/internal/api/client/media/mediaupdate.go
+++ b/internal/api/client/media/mediaupdate.go
@@ -122,7 +122,7 @@ func (m *Module) MediaPUTHandler(c *gin.Context) {
return
}
- attachment, errWithCode := m.processor.MediaUpdate(authed, attachmentID, &form)
+ attachment, errWithCode := m.processor.MediaUpdate(c.Request.Context(), authed, attachmentID, &form)
if errWithCode != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return
diff --git a/internal/api/client/notification/notificationsget.go b/internal/api/client/notification/notificationsget.go
index a30674750..81e8a6890 100644
--- a/internal/api/client/notification/notificationsget.go
+++ b/internal/api/client/notification/notificationsget.go
@@ -68,7 +68,7 @@ func (m *Module) NotificationsGETHandler(c *gin.Context) {
sinceID = sinceIDString
}
- notifs, errWithCode := m.processor.NotificationsGet(authed, limit, maxID, sinceID)
+ notifs, errWithCode := m.processor.NotificationsGet(c.Request.Context(), authed, limit, maxID, sinceID)
if errWithCode != nil {
l.Debugf("error processing notifications get: %s", errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
diff --git a/internal/api/client/search/searchget.go b/internal/api/client/search/searchget.go
index faa227719..848915274 100644
--- a/internal/api/client/search/searchget.go
+++ b/internal/api/client/search/searchget.go
@@ -164,7 +164,7 @@ func (m *Module) SearchGETHandler(c *gin.Context) {
Following: following,
}
- results, errWithCode := m.processor.SearchGet(authed, searchQuery)
+ results, errWithCode := m.processor.SearchGet(c.Request.Context(), authed, searchQuery)
if errWithCode != nil {
l.Debugf("error searching: %s", errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
diff --git a/internal/api/client/status/statusboost.go b/internal/api/client/status/statusboost.go
index 5aa7989bc..094e56ac0 100644
--- a/internal/api/client/status/statusboost.go
+++ b/internal/api/client/status/statusboost.go
@@ -87,7 +87,7 @@ func (m *Module) StatusBoostPOSTHandler(c *gin.Context) {
return
}
- mastoStatus, errWithCode := m.processor.StatusBoost(authed, targetStatusID)
+ mastoStatus, errWithCode := m.processor.StatusBoost(c.Request.Context(), authed, targetStatusID)
if errWithCode != nil {
l.Debugf("error processing status boost: %s", errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
diff --git a/internal/api/client/status/statusboost_test.go b/internal/api/client/status/statusboost_test.go
index fbe267fac..4157bde38 100644
--- a/internal/api/client/status/statusboost_test.go
+++ b/internal/api/client/status/statusboost_test.go
@@ -67,7 +67,7 @@ func (suite *StatusBoostTestSuite) TearDownTest() {
func (suite *StatusBoostTestSuite) TestPostBoost() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["admin_account_status_1"]
@@ -133,7 +133,7 @@ func (suite *StatusBoostTestSuite) TestPostBoost() {
func (suite *StatusBoostTestSuite) TestPostUnboostable() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["local_account_2_status_4"]
@@ -171,7 +171,7 @@ func (suite *StatusBoostTestSuite) TestPostUnboostable() {
func (suite *StatusBoostTestSuite) TestPostNotVisible() {
t := suite.testTokens["local_account_2"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["local_account_1_status_3"] // this is a mutual only status and these accounts aren't mutuals
diff --git a/internal/api/client/status/statusboostedby.go b/internal/api/client/status/statusboostedby.go
index 260e21642..908c3ff10 100644
--- a/internal/api/client/status/statusboostedby.go
+++ b/internal/api/client/status/statusboostedby.go
@@ -84,7 +84,7 @@ func (m *Module) StatusBoostedByGETHandler(c *gin.Context) {
return
}
- mastoAccounts, err := m.processor.StatusBoostedBy(authed, targetStatusID)
+ mastoAccounts, err := m.processor.StatusBoostedBy(c.Request.Context(), authed, targetStatusID)
if err != nil {
l.Debugf("error processing status boosted by request: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"})
diff --git a/internal/api/client/status/statuscontext.go b/internal/api/client/status/statuscontext.go
index 6e28b004e..90fcb9608 100644
--- a/internal/api/client/status/statuscontext.go
+++ b/internal/api/client/status/statuscontext.go
@@ -86,7 +86,7 @@ func (m *Module) StatusContextGETHandler(c *gin.Context) {
return
}
- statusContext, errWithCode := m.processor.StatusGetContext(authed, targetStatusID)
+ statusContext, errWithCode := m.processor.StatusGetContext(c.Request.Context(), authed, targetStatusID)
if errWithCode != nil {
l.Debugf("error getting status context: %s", errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
diff --git a/internal/api/client/status/statuscreate.go b/internal/api/client/status/statuscreate.go
index 2007ba80f..09fc47b5b 100644
--- a/internal/api/client/status/statuscreate.go
+++ b/internal/api/client/status/statuscreate.go
@@ -101,7 +101,7 @@ func (m *Module) StatusCreatePOSTHandler(c *gin.Context) {
return
}
- mastoStatus, err := m.processor.StatusCreate(authed, form)
+ mastoStatus, err := m.processor.StatusCreate(c.Request.Context(), authed, form)
if err != nil {
l.Debugf("error processing status create: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"})
diff --git a/internal/api/client/status/statuscreate_test.go b/internal/api/client/status/statuscreate_test.go
index 33912397e..060d96dad 100644
--- a/internal/api/client/status/statuscreate_test.go
+++ b/internal/api/client/status/statuscreate_test.go
@@ -19,6 +19,7 @@
package status_test
import (
+ "context"
"encoding/json"
"fmt"
"io/ioutil"
@@ -82,7 +83,7 @@ https://docs.gotosocial.org/en/latest/user_guide/posts/#links
func (suite *StatusCreateTestSuite) TestPostNewStatus() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@@ -128,7 +129,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatus() {
}, statusReply.Tags[0])
gtsTag := &gtsmodel.Tag{}
- err = suite.db.GetWhere([]db.Where{{Key: "name", Value: "helloworld"}}, gtsTag)
+ err = suite.db.GetWhere(context.Background(), []db.Where{{Key: "name", Value: "helloworld"}}, gtsTag)
assert.NoError(suite.T(), err)
assert.Equal(suite.T(), statusReply.Account.ID, gtsTag.FirstSeenFromAccountID)
}
@@ -136,7 +137,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatus() {
func (suite *StatusCreateTestSuite) TestPostAnotherNewStatus() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@@ -171,7 +172,7 @@ func (suite *StatusCreateTestSuite) TestPostAnotherNewStatus() {
func (suite *StatusCreateTestSuite) TestPostNewStatusWithEmoji() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@@ -212,7 +213,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatusWithEmoji() {
// Try to reply to a status that doesn't exist
func (suite *StatusCreateTestSuite) TestReplyToNonexistentStatus() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@@ -243,7 +244,7 @@ func (suite *StatusCreateTestSuite) TestReplyToNonexistentStatus() {
// Post a reply to the status of a local user that allows replies.
func (suite *StatusCreateTestSuite) TestReplyToLocalStatus() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@@ -283,7 +284,7 @@ func (suite *StatusCreateTestSuite) TestReplyToLocalStatus() {
// Take a media file which is currently not associated with a status, and attach it to a new status.
func (suite *StatusCreateTestSuite) TestAttachNewMediaSuccess() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
attachment := suite.testAttachments["local_account_1_unattached_1"]
@@ -322,12 +323,11 @@ func (suite *StatusCreateTestSuite) TestAttachNewMediaSuccess() {
assert.Len(suite.T(), statusResponse.MediaAttachments, 1)
// get the updated media attachment from the database
- gtsAttachment := &gtsmodel.MediaAttachment{}
- err = suite.db.GetByID(statusResponse.MediaAttachments[0].ID, gtsAttachment)
+ gtsAttachment, err := suite.db.GetAttachmentByID(context.Background(), statusResponse.MediaAttachments[0].ID)
assert.NoError(suite.T(), err)
// convert it to a masto attachment
- gtsAttachmentAsMasto, err := suite.tc.AttachmentToMasto(gtsAttachment)
+ gtsAttachmentAsMasto, err := suite.tc.AttachmentToMasto(context.Background(), gtsAttachment)
assert.NoError(suite.T(), err)
// compare it with what we have now
diff --git a/internal/api/client/status/statusdelete.go b/internal/api/client/status/statusdelete.go
index 257280ce0..9a67c45ba 100644
--- a/internal/api/client/status/statusdelete.go
+++ b/internal/api/client/status/statusdelete.go
@@ -86,7 +86,7 @@ func (m *Module) StatusDELETEHandler(c *gin.Context) {
return
}
- mastoStatus, err := m.processor.StatusDelete(authed, targetStatusID)
+ mastoStatus, err := m.processor.StatusDelete(c.Request.Context(), authed, targetStatusID)
if err != nil {
l.Debugf("error processing status delete: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"})
diff --git a/internal/api/client/status/statusfave.go b/internal/api/client/status/statusfave.go
index a76acf3d9..94a8f9380 100644
--- a/internal/api/client/status/statusfave.go
+++ b/internal/api/client/status/statusfave.go
@@ -83,7 +83,7 @@ func (m *Module) StatusFavePOSTHandler(c *gin.Context) {
return
}
- mastoStatus, err := m.processor.StatusFave(authed, targetStatusID)
+ mastoStatus, err := m.processor.StatusFave(c.Request.Context(), authed, targetStatusID)
if err != nil {
l.Debugf("error processing status fave: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"})
diff --git a/internal/api/client/status/statusfave_test.go b/internal/api/client/status/statusfave_test.go
index 0f44b5e90..2f7a2c596 100644
--- a/internal/api/client/status/statusfave_test.go
+++ b/internal/api/client/status/statusfave_test.go
@@ -71,7 +71,7 @@ func (suite *StatusFaveTestSuite) TearDownTest() {
func (suite *StatusFaveTestSuite) TestPostFave() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["admin_account_status_2"]
@@ -119,7 +119,7 @@ func (suite *StatusFaveTestSuite) TestPostFave() {
func (suite *StatusFaveTestSuite) TestPostUnfaveable() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["local_account_2_status_3"] // this one is unlikeable and unreplyable
diff --git a/internal/api/client/status/statusfavedby.go b/internal/api/client/status/statusfavedby.go
index a5d6e7c58..7b8e19e20 100644
--- a/internal/api/client/status/statusfavedby.go
+++ b/internal/api/client/status/statusfavedby.go
@@ -84,7 +84,7 @@ func (m *Module) StatusFavedByGETHandler(c *gin.Context) {
return
}
- mastoAccounts, err := m.processor.StatusFavedBy(authed, targetStatusID)
+ mastoAccounts, err := m.processor.StatusFavedBy(c.Request.Context(), authed, targetStatusID)
if err != nil {
l.Debugf("error processing status faved by request: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"})
diff --git a/internal/api/client/status/statusfavedby_test.go b/internal/api/client/status/statusfavedby_test.go
index 22a549b30..7475f1e69 100644
--- a/internal/api/client/status/statusfavedby_test.go
+++ b/internal/api/client/status/statusfavedby_test.go
@@ -69,7 +69,7 @@ func (suite *StatusFavedByTestSuite) TearDownTest() {
func (suite *StatusFavedByTestSuite) TestGetFavedBy() {
t := suite.testTokens["local_account_2"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["admin_account_status_1"] // this status is faved by local_account_1
diff --git a/internal/api/client/status/statusget.go b/internal/api/client/status/statusget.go
index bcca010f5..39668288f 100644
--- a/internal/api/client/status/statusget.go
+++ b/internal/api/client/status/statusget.go
@@ -83,7 +83,7 @@ func (m *Module) StatusGETHandler(c *gin.Context) {
return
}
- mastoStatus, err := m.processor.StatusGet(authed, targetStatusID)
+ mastoStatus, err := m.processor.StatusGet(c.Request.Context(), authed, targetStatusID)
if err != nil {
l.Debugf("error processing status get: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"})
diff --git a/internal/api/client/status/statusunboost.go b/internal/api/client/status/statusunboost.go
index dc42e3b62..8bb28fa50 100644
--- a/internal/api/client/status/statusunboost.go
+++ b/internal/api/client/status/statusunboost.go
@@ -84,7 +84,7 @@ func (m *Module) StatusUnboostPOSTHandler(c *gin.Context) {
return
}
- mastoStatus, errWithCode := m.processor.StatusUnboost(authed, targetStatusID)
+ mastoStatus, errWithCode := m.processor.StatusUnboost(c.Request.Context(), authed, targetStatusID)
if errWithCode != nil {
l.Debugf("error processing status unboost: %s", errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
diff --git a/internal/api/client/status/statusunfave.go b/internal/api/client/status/statusunfave.go
index 80eb87acf..5a1074ca2 100644
--- a/internal/api/client/status/statusunfave.go
+++ b/internal/api/client/status/statusunfave.go
@@ -83,7 +83,7 @@ func (m *Module) StatusUnfavePOSTHandler(c *gin.Context) {
return
}
- mastoStatus, err := m.processor.StatusUnfave(authed, targetStatusID)
+ mastoStatus, err := m.processor.StatusUnfave(c.Request.Context(), authed, targetStatusID)
if err != nil {
l.Debugf("error processing status unfave: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"})
diff --git a/internal/api/client/status/statusunfave_test.go b/internal/api/client/status/statusunfave_test.go
index a5f267f4c..9e7ea8f82 100644
--- a/internal/api/client/status/statusunfave_test.go
+++ b/internal/api/client/status/statusunfave_test.go
@@ -71,7 +71,7 @@ func (suite *StatusUnfaveTestSuite) TearDownTest() {
func (suite *StatusUnfaveTestSuite) TestPostUnfave() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
// this is the status we wanna unfave: in the testrig it's already faved by this account
targetStatus := suite.testStatuses["admin_account_status_1"]
@@ -120,7 +120,7 @@ func (suite *StatusUnfaveTestSuite) TestPostUnfave() {
func (suite *StatusUnfaveTestSuite) TestPostAlreadyNotFaved() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
// this is the status we wanna unfave: in the testrig it's not faved by this account
targetStatus := suite.testStatuses["admin_account_status_2"]
diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go
index 626f1ff41..fa210e8d8 100644
--- a/internal/api/client/streaming/stream.go
+++ b/internal/api/client/streaming/stream.go
@@ -122,7 +122,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
}
// make sure a valid token has been provided and obtain the associated account
- account, err := m.processor.AuthorizeStreamingRequest(accessToken)
+ account, err := m.processor.AuthorizeStreamingRequest(c.Request.Context(), accessToken)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "could not authorize with given token"})
return
@@ -147,7 +147,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
defer conn.Close() // whatever happens, when we leave this function we want to close the websocket connection
// inform the processor that we have a new connection and want a stream for it
- stream, errWithCode := m.processor.OpenStreamForAccount(account, streamType)
+ stream, errWithCode := m.processor.OpenStreamForAccount(c.Request.Context(), account, streamType)
if errWithCode != nil {
c.JSON(errWithCode.Code(), errWithCode.Safe())
return
diff --git a/internal/api/client/timeline/home.go b/internal/api/client/timeline/home.go
index a6e64f384..6df4b29d0 100644
--- a/internal/api/client/timeline/home.go
+++ b/internal/api/client/timeline/home.go
@@ -153,7 +153,7 @@ func (m *Module) HomeTimelineGETHandler(c *gin.Context) {
local = i
}
- resp, errWithCode := m.processor.HomeTimelineGet(authed, maxID, sinceID, minID, limit, local)
+ resp, errWithCode := m.processor.HomeTimelineGet(c.Request.Context(), authed, maxID, sinceID, minID, limit, local)
if errWithCode != nil {
l.Debugf("error from processor HomeTimelineGet: %s", errWithCode)
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
diff --git a/internal/api/client/timeline/public.go b/internal/api/client/timeline/public.go
index 178fcd7f1..8c8c9f120 100644
--- a/internal/api/client/timeline/public.go
+++ b/internal/api/client/timeline/public.go
@@ -153,7 +153,7 @@ func (m *Module) PublicTimelineGETHandler(c *gin.Context) {
local = i
}
- resp, errWithCode := m.processor.PublicTimelineGet(authed, maxID, sinceID, minID, limit, local)
+ resp, errWithCode := m.processor.PublicTimelineGet(c.Request.Context(), authed, maxID, sinceID, minID, limit, local)
if errWithCode != nil {
l.Debugf("error from processor PublicTimelineGet: %s", errWithCode)
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
diff --git a/internal/api/s2s/nodeinfo/nodeinfoget.go b/internal/api/s2s/nodeinfo/nodeinfoget.go
index a54c8b190..c362e1d2e 100644
--- a/internal/api/s2s/nodeinfo/nodeinfoget.go
+++ b/internal/api/s2s/nodeinfo/nodeinfoget.go
@@ -33,7 +33,7 @@ func (m *Module) NodeInfoGETHandler(c *gin.Context) {
"user-agent": c.Request.UserAgent(),
})
- ni, err := m.processor.GetNodeInfo(c.Request)
+ ni, err := m.processor.GetNodeInfo(c.Request.Context(), c.Request)
if err != nil {
l.Debugf("error with get node info request: %s", err)
c.JSON(err.Code(), err.Safe())
diff --git a/internal/api/s2s/nodeinfo/wellknownget.go b/internal/api/s2s/nodeinfo/wellknownget.go
index 614d2a9c6..fd2c84408 100644
--- a/internal/api/s2s/nodeinfo/wellknownget.go
+++ b/internal/api/s2s/nodeinfo/wellknownget.go
@@ -33,7 +33,7 @@ func (m *Module) NodeInfoWellKnownGETHandler(c *gin.Context) {
"user-agent": c.Request.UserAgent(),
})
- niRel, err := m.processor.GetNodeInfoRel(c.Request)
+ niRel, err := m.processor.GetNodeInfoRel(c.Request.Context(), c.Request)
if err != nil {
l.Debugf("error with get node info rel request: %s", err)
c.JSON(err.Code(), err.Safe())
diff --git a/internal/api/s2s/user/userget_test.go b/internal/api/s2s/user/userget_test.go
index ab0015c57..29cc0e0d8 100644
--- a/internal/api/s2s/user/userget_test.go
+++ b/internal/api/s2s/user/userget_test.go
@@ -105,7 +105,7 @@ func (suite *UserGetTestSuite) TestGetUser() {
// convert person to account
// since this account is already known, we should get a pretty full model of it from the conversion
- a, err := suite.tc.ASRepresentationToAccount(person, false)
+ a, err := suite.tc.ASRepresentationToAccount(context.Background(), person, false)
assert.NoError(suite.T(), err)
assert.EqualValues(suite.T(), targetAccount.Username, a.Username)
}
diff --git a/internal/api/security/signaturecheck.go b/internal/api/security/signaturecheck.go
index 88b0b4dff..71e539e96 100644
--- a/internal/api/security/signaturecheck.go
+++ b/internal/api/security/signaturecheck.go
@@ -31,7 +31,7 @@ func (m *Module) SignatureCheck(c *gin.Context) {
// we managed to parse the url!
// if the domain is blocked we want to bail as early as possible
- blocked, err := m.db.IsURIBlocked(requestingPublicKeyID)
+ blocked, err := m.db.IsURIBlocked(c.Request.Context(), requestingPublicKeyID)
if err != nil {
l.Errorf("could not tell if domain %s was blocked or not: %s", requestingPublicKeyID.Host, err)
c.AbortWithStatus(http.StatusInternalServerError)
diff --git a/internal/cache/cache.go b/internal/cache/cache.go
index eb3744cfe..ce4aad04d 100644
--- a/internal/cache/cache.go
+++ b/internal/cache/cache.go
@@ -37,7 +37,7 @@ type cache struct {
// New returns a new in-memory cache.
func New() Cache {
c := ttlcache.NewCache()
- c.SetTTL(30 * time.Second)
+ c.SetTTL(5 * time.Minute)
cache := &cache{
c: c,
}
diff --git a/internal/cliactions/admin/account/account.go b/internal/cliactions/admin/account/account.go
index 0ae7f32de..46998ec6a 100644
--- a/internal/cliactions/admin/account/account.go
+++ b/internal/cliactions/admin/account/account.go
@@ -28,7 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/cliactions"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/db/pg"
+ "github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
"golang.org/x/crypto/bcrypt"
@@ -36,7 +36,7 @@ import (
// Create creates a new account in the database using the provided flags.
var Create cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
- dbConn, err := pg.NewPostgresService(ctx, c, log)
+ dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
@@ -65,7 +65,7 @@ var Create cliactions.GTSAction = func(ctx context.Context, c *config.Config, lo
return err
}
- _, err = dbConn.NewSignup(username, "", false, email, password, nil, "", "", false, false)
+ _, err = dbConn.NewSignup(ctx, username, "", false, email, password, nil, "", "", false, false)
if err != nil {
return err
}
@@ -75,7 +75,7 @@ var Create cliactions.GTSAction = func(ctx context.Context, c *config.Config, lo
// Confirm sets a user to Approved, sets Email to the current UnconfirmedEmail value, and sets ConfirmedAt to now.
var Confirm cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
- dbConn, err := pg.NewPostgresService(ctx, c, log)
+ dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
@@ -88,20 +88,20 @@ var Confirm cliactions.GTSAction = func(ctx context.Context, c *config.Config, l
return err
}
- a, err := dbConn.GetLocalAccountByUsername(username)
+ a, err := dbConn.GetLocalAccountByUsername(ctx, username)
if err != nil {
return err
}
u := &gtsmodel.User{}
- if err := dbConn.GetWhere([]db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
+ if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
return err
}
u.Approved = true
u.Email = u.UnconfirmedEmail
u.ConfirmedAt = time.Now()
- if err := dbConn.UpdateByID(u.ID, u); err != nil {
+ if err := dbConn.UpdateByID(ctx, u.ID, u); err != nil {
return err
}
@@ -110,7 +110,7 @@ var Confirm cliactions.GTSAction = func(ctx context.Context, c *config.Config, l
// Promote sets a user to admin.
var Promote cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
- dbConn, err := pg.NewPostgresService(ctx, c, log)
+ dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
@@ -123,17 +123,17 @@ var Promote cliactions.GTSAction = func(ctx context.Context, c *config.Config, l
return err
}
- a, err := dbConn.GetLocalAccountByUsername(username)
+ a, err := dbConn.GetLocalAccountByUsername(ctx, username)
if err != nil {
return err
}
u := &gtsmodel.User{}
- if err := dbConn.GetWhere([]db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
+ if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
return err
}
u.Admin = true
- if err := dbConn.UpdateByID(u.ID, u); err != nil {
+ if err := dbConn.UpdateByID(ctx, u.ID, u); err != nil {
return err
}
@@ -142,7 +142,7 @@ var Promote cliactions.GTSAction = func(ctx context.Context, c *config.Config, l
// Demote sets admin on a user to false.
var Demote cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
- dbConn, err := pg.NewPostgresService(ctx, c, log)
+ dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
@@ -155,17 +155,17 @@ var Demote cliactions.GTSAction = func(ctx context.Context, c *config.Config, lo
return err
}
- a, err := dbConn.GetLocalAccountByUsername(username)
+ a, err := dbConn.GetLocalAccountByUsername(ctx, username)
if err != nil {
return err
}
u := &gtsmodel.User{}
- if err := dbConn.GetWhere([]db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
+ if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
return err
}
u.Admin = false
- if err := dbConn.UpdateByID(u.ID, u); err != nil {
+ if err := dbConn.UpdateByID(ctx, u.ID, u); err != nil {
return err
}
@@ -174,7 +174,7 @@ var Demote cliactions.GTSAction = func(ctx context.Context, c *config.Config, lo
// Disable sets Disabled to true on a user.
var Disable cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
- dbConn, err := pg.NewPostgresService(ctx, c, log)
+ dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
@@ -187,17 +187,17 @@ var Disable cliactions.GTSAction = func(ctx context.Context, c *config.Config, l
return err
}
- a, err := dbConn.GetLocalAccountByUsername(username)
+ a, err := dbConn.GetLocalAccountByUsername(ctx, username)
if err != nil {
return err
}
u := &gtsmodel.User{}
- if err := dbConn.GetWhere([]db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
+ if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
return err
}
u.Disabled = true
- if err := dbConn.UpdateByID(u.ID, u); err != nil {
+ if err := dbConn.UpdateByID(ctx, u.ID, u); err != nil {
return err
}
@@ -212,7 +212,7 @@ var Suspend cliactions.GTSAction = func(ctx context.Context, c *config.Config, l
// Password sets the password of target account.
var Password cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
- dbConn, err := pg.NewPostgresService(ctx, c, log)
+ dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
@@ -233,13 +233,13 @@ var Password cliactions.GTSAction = func(ctx context.Context, c *config.Config,
return err
}
- a, err := dbConn.GetLocalAccountByUsername(username)
+ a, err := dbConn.GetLocalAccountByUsername(ctx, username)
if err != nil {
return err
}
u := &gtsmodel.User{}
- if err := dbConn.GetWhere([]db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
+ if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
return err
}
@@ -250,7 +250,7 @@ var Password cliactions.GTSAction = func(ctx context.Context, c *config.Config,
u.EncryptedPassword = string(pw)
- if err := dbConn.UpdateByID(u.ID, u); err != nil {
+ if err := dbConn.UpdateByID(ctx, u.ID, u); err != nil {
return err
}
diff --git a/internal/cliactions/server/server.go b/internal/cliactions/server/server.go
index 72c6cfadf..877a9d397 100644
--- a/internal/cliactions/server/server.go
+++ b/internal/cliactions/server/server.go
@@ -35,7 +35,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/blob"
"github.com/superseriousbusiness/gotosocial/internal/cliactions"
"github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db/pg"
+ "github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
"github.com/superseriousbusiness/gotosocial/internal/gotosocial"
@@ -79,28 +79,28 @@ var models []interface{} = []interface{}{
// Start creates and starts a gotosocial server
var Start cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
- dbService, err := pg.NewPostgresService(ctx, c, log)
+ dbService, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
for _, m := range models {
- if err := dbService.CreateTable(m); err != nil {
+ if err := dbService.CreateTable(ctx, m); err != nil {
return fmt.Errorf("table creation error: %s", err)
}
}
- if err := dbService.CreateInstanceAccount(); err != nil {
+ if err := dbService.CreateInstanceAccount(ctx); err != nil {
return fmt.Errorf("error creating instance account: %s", err)
}
- if err := dbService.CreateInstanceInstance(); err != nil {
+ if err := dbService.CreateInstanceInstance(ctx); err != nil {
return fmt.Errorf("error creating instance instance: %s", err)
}
federatingDB := federatingdb.New(dbService, c, log)
- router, err := router.New(c, dbService, log)
+ router, err := router.New(ctx, c, dbService, log)
if err != nil {
return fmt.Errorf("error creating router: %s", err)
}
@@ -120,7 +120,7 @@ var Start cliactions.GTSAction = func(ctx context.Context, c *config.Config, log
transportController := transport.NewController(c, dbService, &federation.Clock{}, http.DefaultClient, log)
federator := federation.NewFederator(dbService, federatingDB, transportController, c, log, typeConverter, mediaHandler)
processor := processing.NewProcessor(c, typeConverter, federator, oauthServer, mediaHandler, storageBackend, timelineManager, dbService, log)
- if err := processor.Start(); err != nil {
+ if err := processor.Start(ctx); err != nil {
return fmt.Errorf("error starting processor: %s", err)
}
diff --git a/internal/cliactions/testrig/testrig.go b/internal/cliactions/testrig/testrig.go
index a7032825c..7badca556 100644
--- a/internal/cliactions/testrig/testrig.go
+++ b/internal/cliactions/testrig/testrig.go
@@ -63,7 +63,7 @@ var Start cliactions.GTSAction = func(ctx context.Context, _ *config.Config, log
federator := testrig.NewTestFederator(dbService, transportController, storageBackend)
processor := testrig.NewTestProcessor(dbService, storageBackend, federator)
- if err := processor.Start(); err != nil {
+ if err := processor.Start(ctx); err != nil {
return fmt.Errorf("error starting processor: %s", err)
}
diff --git a/internal/db/account.go b/internal/db/account.go
index 0e1575f9b..058a89859 100644
--- a/internal/db/account.go
+++ b/internal/db/account.go
@@ -19,6 +19,7 @@
package db
import (
+ "context"
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -27,40 +28,43 @@ import (
// Account contains functions related to account getting/setting/creation.
type Account interface {
// GetAccountByID returns one account with the given ID, or an error if something goes wrong.
- GetAccountByID(id string) (*gtsmodel.Account, Error)
+ GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, Error)
// GetAccountByURI returns one account with the given URI, or an error if something goes wrong.
- GetAccountByURI(uri string) (*gtsmodel.Account, Error)
+ GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
- GetAccountByURL(uri string) (*gtsmodel.Account, Error)
+ GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error)
+
+ // UpdateAccount updates one account by ID.
+ UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
// GetLocalAccountByUsername returns an account on this instance by its username.
- GetLocalAccountByUsername(username string) (*gtsmodel.Account, Error)
+ GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, Error)
// GetAccountFaves fetches faves/likes created by the target accountID.
- GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, Error)
+ GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, Error)
// GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID.
- CountAccountStatuses(accountID string) (int, Error)
+ CountAccountStatuses(ctx context.Context, accountID string) (int, Error)
// GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
// be very memory intensive so you probably shouldn't do this!
// In case of no entries, a 'no entries' error will be returned
- GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error)
+ GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error)
- GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error)
+ GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error)
// GetAccountLastPosted simply gets the timestamp of the most recent post by the account.
//
// The returned time will be zero if account has never posted anything.
- GetAccountLastPosted(accountID string) (time.Time, Error)
+ GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, Error)
// SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment.
- SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error
+ SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error
// GetInstanceAccount returns the instance account for the given domain.
// If domain is empty, this instance account will be returned.
- GetInstanceAccount(domain string) (*gtsmodel.Account, Error)
+ GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, Error)
}
diff --git a/internal/db/admin.go b/internal/db/admin.go
index aa2b22f47..24d628e84 100644
--- a/internal/db/admin.go
+++ b/internal/db/admin.go
@@ -19,6 +19,7 @@
package db
import (
+ "context"
"net"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -28,26 +29,26 @@ import (
type Admin interface {
// IsUsernameAvailable checks whether a given username is available on our domain.
// Returns an error if the username is already taken, or something went wrong in the db.
- IsUsernameAvailable(username string) Error
+ IsUsernameAvailable(ctx context.Context, username string) (bool, Error)
// IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
// Return an error if:
// A) the email is already associated with an account
// B) we block signups from this email domain
// C) something went wrong in the db
- IsEmailAvailable(email string) Error
+ IsEmailAvailable(ctx context.Context, email string) (bool, Error)
// NewSignup creates a new user in the database with the given parameters.
// By the time this function is called, it should be assumed that all the parameters have passed validation!
- NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error)
+ NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error)
// CreateInstanceAccount creates an account in the database with the same username as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
// This is needed for things like serving files that belong to the instance and not an individual user/account.
- CreateInstanceAccount() Error
+ CreateInstanceAccount(ctx context.Context) Error
// CreateInstanceInstance creates an instance in the database with the same domain as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'.
// This is needed for things like serving instance information through /api/v1/instance
- CreateInstanceInstance() Error
+ CreateInstanceInstance(ctx context.Context) Error
}
diff --git a/internal/db/basic.go b/internal/db/basic.go
index 729920bba..cf65ddc09 100644
--- a/internal/db/basic.go
+++ b/internal/db/basic.go
@@ -24,15 +24,11 @@ import "context"
type Basic interface {
// CreateTable creates a table for the given interface.
// For implementations that don't use tables, this can just return nil.
- CreateTable(i interface{}) Error
+ CreateTable(ctx context.Context, i interface{}) Error
// DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil.
- DropTable(i interface{}) Error
-
- // RegisterTable registers a table for use in many2many relations.
- // For implementations that don't use tables, or many2many relations, this can just return nil.
- RegisterTable(i interface{}) Error
+ DropTable(ctx context.Context, i interface{}) Error
// Stop should stop and close the database connection cleanly, returning an error if this is not possible.
// If the database implementation doesn't need to be stopped, this can just return nil.
@@ -45,43 +41,38 @@ type Basic interface {
// for other implementations (for example, in-memory) it might just be the key of a map.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
- GetByID(id string, i interface{}) Error
+ GetByID(ctx context.Context, id string, i interface{}) Error
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
// name of the key to select from.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
- GetWhere(where []Where, i interface{}) Error
+ GetWhere(ctx context.Context, where []Where, i interface{}) Error
// GetAll will try to get all entries of type i.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
- GetAll(i interface{}) Error
+ GetAll(ctx context.Context, i interface{}) Error
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- Put(i interface{}) Error
-
- // Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/
- // It is up to the implementation to figure out how to store it, and using what key.
- // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- Upsert(i interface{}, conflictColumn string) Error
+ Put(ctx context.Context, i interface{}) Error
// UpdateByID updates i with id id.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- UpdateByID(id string, i interface{}) Error
+ UpdateByID(ctx context.Context, id string, i interface{}) Error
// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
- UpdateOneByID(id string, key string, value interface{}, i interface{}) Error
+ UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) Error
// UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
- UpdateWhere(where []Where, key string, value interface{}, i interface{}) Error
+ UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error
// DeleteByID removes i with id id.
// If i didn't exist anyway, then no error should be returned.
- DeleteByID(id string, i interface{}) Error
+ DeleteByID(ctx context.Context, id string, i interface{}) Error
// DeleteWhere deletes i where key = value
// If i didn't exist anyway, then no error should be returned.
- DeleteWhere(where []Where, i interface{}) Error
+ DeleteWhere(ctx context.Context, where []Where, i interface{}) Error
}
diff --git a/internal/db/pg/account.go b/internal/db/bundb/account.go
index 3889c6601..7ebb79a15 100644
--- a/internal/db/pg/account.go
+++ b/internal/db/bundb/account.go
@@ -16,70 +16,90 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
"errors"
"fmt"
+ "strings"
"time"
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type accountDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
-func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query {
- return a.conn.Model(account).
+func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
+ return a.conn.
+ NewSelect().
+ Model(account).
Relation("AvatarMediaAttachment").
Relation("HeaderMediaAttachment")
}
-func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.Error) {
- account := &gtsmodel.Account{}
+func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("account.id = ?", id)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return account, err
}
-func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.Error) {
- account := &gtsmodel.Account{}
+func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("account.uri = ?", uri)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return account, err
}
-func (a *accountDB) GetAccountByURL(uri string) (*gtsmodel.Account, db.Error) {
- account := &gtsmodel.Account{}
+func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("account.url = ?", uri)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return account, err
}
-func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Error) {
- account := &gtsmodel.Account{}
+func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
+ if strings.TrimSpace(account.ID) == "" {
+ return nil, errors.New("account had no ID")
+ }
+
+ account.UpdatedAt = time.Now()
+
+ q := a.conn.
+ NewUpdate().
+ Model(account).
+ WherePK()
+
+ _, err := q.Exec(ctx)
+
+ err = processErrorResponse(err)
+
+ return account, err
+}
+
+func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
q := a.newAccountQ(account)
@@ -90,29 +110,31 @@ func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Err
} else {
q = q.
Where("account.username = ?", domain).
- Where("? IS NULL", pg.Ident("domain"))
+ Where("? IS NULL", bun.Ident("domain"))
}
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return account, err
}
-func (a *accountDB) GetAccountLastPosted(accountID string) (time.Time, db.Error) {
- status := &gtsmodel.Status{}
+func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, db.Error) {
+ status := new(gtsmodel.Status)
- q := a.conn.Model(status).
+ q := a.conn.
+ NewSelect().
+ Model(status).
Order("id DESC").
Limit(1).
Where("account_id = ?", accountID).
Column("created_at")
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return status.CreatedAt, err
}
-func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
+func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
if mediaAttachment.Avatar && mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar")
}
@@ -127,51 +149,66 @@ func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAtta
}
// TODO: there are probably more side effects here that need to be handled
- if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
+ if _, err := a.conn.
+ NewInsert().
+ Model(mediaAttachment).
+ Exec(ctx); err != nil {
return err
}
- if _, err := a.conn.Model(&gtsmodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
+ if _, err := a.conn.
+ NewUpdate().
+ Model(&gtsmodel.Account{}).
+ Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).
+ Where("id = ?", accountID).
+ Exec(ctx); err != nil {
return err
}
return nil
}
-func (a *accountDB) GetLocalAccountByUsername(username string) (*gtsmodel.Account, db.Error) {
- account := &gtsmodel.Account{}
+func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("username = ?", username).
- Where("? IS NULL", pg.Ident("domain"))
+ Where("? IS NULL", bun.Ident("domain"))
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return account, err
}
-func (a *accountDB) GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, db.Error) {
- faves := []*gtsmodel.StatusFave{}
+func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) {
+ faves := new([]*gtsmodel.StatusFave)
- if err := a.conn.Model(&faves).
+ if err := a.conn.
+ NewSelect().
+ Model(faves).
Where("account_id = ?", accountID).
- Select(); err != nil {
- if err == pg.ErrNoRows {
- return faves, nil
- }
+ Scan(ctx); err != nil {
return nil, err
}
- return faves, nil
+ return *faves, nil
}
-func (a *accountDB) CountAccountStatuses(accountID string) (int, db.Error) {
- return a.conn.Model(&gtsmodel.Status{}).Where("account_id = ?", accountID).Count()
+func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) {
+ return a.conn.
+ NewSelect().
+ Model(&gtsmodel.Status{}).
+ Where("account_id = ?", accountID).
+ Count(ctx)
}
-func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
+func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
a.log.Debugf("getting statuses for account %s", accountID)
statuses := []*gtsmodel.Status{}
- q := a.conn.Model(&statuses).Order("id DESC")
+ q := a.conn.
+ NewSelect().
+ Model(&statuses).
+ Order("id DESC")
+
if accountID != "" {
q = q.Where("account_id = ?", accountID)
}
@@ -181,27 +218,26 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli
}
if excludeReplies {
- q = q.Where("? IS NULL", pg.Ident("in_reply_to_id"))
+ q = q.Where("? IS NULL", bun.Ident("in_reply_to_id"))
}
if pinnedOnly {
q = q.Where("pinned = ?", true)
}
- if mediaOnly {
- q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) {
- return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil
- })
- }
-
if maxID != "" {
q = q.Where("id < ?", maxID)
}
- if err := q.Select(); err != nil {
- if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries
- }
+ if mediaOnly {
+ q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {
+ return q.
+ WhereOr("? IS NOT NULL", bun.Ident("attachments")).
+ WhereOr("attachments != '{}'")
+ })
+ }
+
+ if err := q.Scan(ctx); err != nil {
return nil, err
}
@@ -213,10 +249,12 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli
return statuses, nil
}
-func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
+func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
blocks := []*gtsmodel.Block{}
- fq := a.conn.Model(&blocks).
+ fq := a.conn.
+ NewSelect().
+ Model(&blocks).
Where("block.account_id = ?", accountID).
Relation("TargetAccount").
Order("block.id DESC")
@@ -233,11 +271,8 @@ func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID str
fq = fq.Limit(limit)
}
- err := fq.Select()
+ err := fq.Scan(ctx)
if err != nil {
- if err == pg.ErrNoRows {
- return nil, "", "", db.ErrNoEntries
- }
return nil, "", "", err
}
diff --git a/internal/db/pg/account_test.go b/internal/db/bundb/account_test.go
index 7ea5ff39a..7174b781d 100644
--- a/internal/db/pg/account_test.go
+++ b/internal/db/bundb/account_test.go
@@ -16,17 +16,19 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg_test
+package bundb_test
import (
+ "context"
"testing"
+ "time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type AccountTestSuite struct {
- PGStandardTestSuite
+ BunDBStandardTestSuite
}
func (suite *AccountTestSuite) SetupSuite() {
@@ -54,7 +56,7 @@ func (suite *AccountTestSuite) TearDownTest() {
}
func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
- account, err := suite.db.GetAccountByID(suite.testAccounts["local_account_1"].ID)
+ account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
@@ -65,6 +67,20 @@ func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
suite.NotEmpty(account.HeaderMediaAttachment.URL)
}
+func (suite *AccountTestSuite) TestUpdateAccount() {
+ testAccount := suite.testAccounts["local_account_1"]
+
+ testAccount.DisplayName = "new display name!"
+
+ _, err := suite.db.UpdateAccount(context.Background(), testAccount)
+ suite.NoError(err)
+
+ updated, err := suite.db.GetAccountByID(context.Background(), testAccount.ID)
+ suite.NoError(err)
+ suite.Equal("new display name!", updated.DisplayName)
+ suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second)
+}
+
func TestAccountTestSuite(t *testing.T) {
suite.Run(t, new(AccountTestSuite))
}
diff --git a/internal/db/pg/admin.go b/internal/db/bundb/admin.go
index 854f56ef0..67a1e8a0d 100644
--- a/internal/db/pg/admin.go
+++ b/internal/db/bundb/admin.go
@@ -16,76 +16,76 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
"crypto/rand"
"crypto/rsa"
+ "database/sql"
"fmt"
"net"
"net/mail"
"strings"
"time"
- "github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
+ "github.com/uptrace/bun"
"golang.org/x/crypto/bcrypt"
)
type adminDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
-func (a *adminDB) IsUsernameAvailable(username string) db.Error {
- // if no error we fail because it means we found something
- // if error but it's not pg.ErrNoRows then we fail
- // if err is pg.ErrNoRows we're good, we found nothing so continue
- if err := a.conn.Model(&gtsmodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
- return fmt.Errorf("username %s already in use", username)
- } else if err != pg.ErrNoRows {
- return fmt.Errorf("db error: %s", err)
- }
- return nil
+func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
+ q := a.conn.
+ NewSelect().
+ Model(&gtsmodel.Account{}).
+ Where("username = ?", username).
+ Where("domain = ?", nil)
+
+ return notExists(ctx, q)
}
-func (a *adminDB) IsEmailAvailable(email string) db.Error {
+func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) {
// parse the domain from the email
m, err := mail.ParseAddress(email)
if err != nil {
- return fmt.Errorf("error parsing email address %s: %s", email, err)
+ return false, fmt.Errorf("error parsing email address %s: %s", email, err)
}
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked
- if err := a.conn.Model(&gtsmodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
+ if err := a.conn.
+ NewSelect().
+ Model(&gtsmodel.EmailDomainBlock{}).
+ Where("domain = ?", domain).
+ Scan(ctx); err == nil {
// fail because we found something
- return fmt.Errorf("email domain %s is blocked", domain)
- } else if err != pg.ErrNoRows {
- // fail because we got an unexpected error
- return fmt.Errorf("db error: %s", err)
+ return false, fmt.Errorf("email domain %s is blocked", domain)
+ } else if err != sql.ErrNoRows {
+ return false, processErrorResponse(err)
}
// check if this email is associated with a user already
- if err := a.conn.Model(&gtsmodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
- // fail because we found something
- return fmt.Errorf("email %s already in use", email)
- } else if err != pg.ErrNoRows {
- // fail because we got an unexpected error
- return fmt.Errorf("db error: %s", err)
- }
- return nil
+ q := a.conn.
+ NewSelect().
+ Model(&gtsmodel.User{}).
+ Where("email = ?", email).
+ WhereOr("unconfirmed_email = ?", email)
+
+ return notExists(ctx, q)
}
-func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
+func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
a.log.Errorf("error creating new rsa key: %s", err)
@@ -94,13 +94,12 @@ func (a *adminDB) NewSignup(username string, reason string, requireApproval bool
// if something went wrong while creating a user, we might already have an account, so check here first...
acct := &gtsmodel.Account{}
- err = a.conn.Model(acct).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
+ err = a.conn.NewSelect().
+ Model(acct).
+ Where("username = ?", username).
+ Where("? IS NULL", bun.Ident("domain")).
+ Scan(ctx)
if err != nil {
- // there's been an actual error
- if err != pg.ErrNoRows {
- return nil, fmt.Errorf("db error checking existence of account: %s", err)
- }
-
// we just don't have an account yet create one
newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
newAccountID, err := id.NewRandomULID()
@@ -125,7 +124,10 @@ func (a *adminDB) NewSignup(username string, reason string, requireApproval bool
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
- if _, err = a.conn.Model(acct).Insert(); err != nil {
+ if _, err = a.conn.
+ NewInsert().
+ Model(acct).
+ Exec(ctx); err != nil {
return nil, err
}
}
@@ -161,15 +163,33 @@ func (a *adminDB) NewSignup(username string, reason string, requireApproval bool
u.Moderator = true
}
- if _, err = a.conn.Model(u).Insert(); err != nil {
+ if _, err = a.conn.
+ NewInsert().
+ Model(u).
+ Exec(ctx); err != nil {
return nil, err
}
return u, nil
}
-func (a *adminDB) CreateInstanceAccount() db.Error {
+func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
username := a.config.Host
+
+ // check if instance account already exists
+ existsQ := a.conn.
+ NewSelect().
+ Model(&gtsmodel.Account{}).
+ Where("username = ?", username).
+ Where("? IS NULL", bun.Ident("domain"))
+ count, err := existsQ.Count(ctx)
+ if err != nil && count == 1 {
+ a.log.Infof("instance account %s already exists", username)
+ return nil
+ } else if err != sql.ErrNoRows {
+ return processErrorResponse(err)
+ }
+
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
a.log.Errorf("error creating new rsa key: %s", err)
@@ -198,19 +218,36 @@ func (a *adminDB) CreateInstanceAccount() db.Error {
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
- inserted, err := a.conn.Model(acct).Where("username = ?", username).SelectOrInsert()
- if err != nil {
+
+ insertQ := a.conn.
+ NewInsert().
+ Model(acct)
+
+ if _, err := insertQ.Exec(ctx); err != nil {
return err
}
- if inserted {
- a.log.Infof("created instance account %s with id %s", username, acct.ID)
- } else {
- a.log.Infof("instance account %s already exists with id %s", username, acct.ID)
- }
+
+ a.log.Infof("instance account %s CREATED with id %s", username, acct.ID)
return nil
}
-func (a *adminDB) CreateInstanceInstance() db.Error {
+func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
+ domain := a.config.Host
+
+ // check if instance entry already exists
+ existsQ := a.conn.
+ NewSelect().
+ Model(&gtsmodel.Instance{}).
+ Where("domain = ?", domain)
+
+ count, err := existsQ.Count(ctx)
+ if err != nil && count == 1 {
+ a.log.Infof("instance instance %s already exists", domain)
+ return nil
+ } else if err != sql.ErrNoRows {
+ return processErrorResponse(err)
+ }
+
iID, err := id.NewRandomULID()
if err != nil {
return err
@@ -218,18 +255,18 @@ func (a *adminDB) CreateInstanceInstance() db.Error {
i := &gtsmodel.Instance{
ID: iID,
- Domain: a.config.Host,
- Title: a.config.Host,
+ Domain: domain,
+ Title: domain,
URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host),
}
- inserted, err := a.conn.Model(i).Where("domain = ?", a.config.Host).SelectOrInsert()
- if err != nil {
+
+ insertQ := a.conn.
+ NewInsert().
+ Model(i)
+
+ if _, err := insertQ.Exec(ctx); err != nil {
return err
}
- if inserted {
- a.log.Infof("created instance instance %s with id %s", a.config.Host, i.ID)
- } else {
- a.log.Infof("instance instance %s already exists with id %s", a.config.Host, i.ID)
- }
+ a.log.Infof("created instance instance %s with id %s", domain, i.ID)
return nil
}
diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go
new file mode 100644
index 000000000..983b6b810
--- /dev/null
+++ b/internal/db/bundb/basic.go
@@ -0,0 +1,179 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+ "errors"
+ "strings"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/uptrace/bun"
+)
+
+type basicDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error {
+ _, err := b.conn.NewInsert().Model(i).Exec(ctx)
+ if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
+ return db.ErrAlreadyExists
+ }
+ return err
+}
+
+func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Error {
+ q := b.conn.
+ NewSelect().
+ Model(i).
+ Where("id = ?", id)
+
+ return processErrorResponse(q.Scan(ctx))
+}
+
+func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
+ if len(where) == 0 {
+ return errors.New("no queries provided")
+ }
+
+ q := b.conn.NewSelect().Model(i)
+ for _, w := range where {
+
+ if w.Value == nil {
+ q = q.Where("? IS NULL", bun.Ident(w.Key))
+ } else {
+ if w.CaseInsensitive {
+ q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value)
+ } else {
+ q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
+ }
+ }
+ }
+
+ return processErrorResponse(q.Scan(ctx))
+}
+
+func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
+ q := b.conn.
+ NewSelect().
+ Model(i)
+
+ return processErrorResponse(q.Scan(ctx))
+}
+
+func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error {
+ q := b.conn.
+ NewDelete().
+ Model(i).
+ Where("id = ?", id)
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
+ if len(where) == 0 {
+ return errors.New("no queries provided")
+ }
+
+ q := b.conn.
+ NewDelete().
+ Model(i)
+
+ for _, w := range where {
+ q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
+ }
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error {
+ q := b.conn.
+ NewUpdate().
+ Model(i).
+ WherePK()
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error {
+ q := b.conn.NewUpdate().
+ Model(i).
+ Set("? = ?", bun.Safe(key), value).
+ WherePK()
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error {
+ q := b.conn.NewUpdate().Model(i)
+
+ for _, w := range where {
+ if w.Value == nil {
+ q = q.Where("? IS NULL", bun.Ident(w.Key))
+ } else {
+ if w.CaseInsensitive {
+ q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value)
+ } else {
+ q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
+ }
+ }
+ }
+
+ q = q.Set("? = ?", bun.Safe(key), value)
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
+ _, err := b.conn.NewCreateTable().Model(i).IfNotExists().Exec(ctx)
+ return err
+}
+
+func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {
+ _, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx)
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
+ return b.conn.Ping()
+}
+
+func (b *basicDB) Stop(ctx context.Context) db.Error {
+ b.log.Info("closing db connection")
+ if err := b.conn.Close(); err != nil {
+ // only cancel if there's a problem closing the db
+ return err
+ }
+ return nil
+}
diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go
new file mode 100644
index 000000000..9189618c9
--- /dev/null
+++ b/internal/db/bundb/basic_test.go
@@ -0,0 +1,68 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type BasicTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *BasicTestSuite) SetupSuite() {
+ suite.testTokens = testrig.NewTestTokens()
+ suite.testClients = testrig.NewTestClients()
+ suite.testApplications = testrig.NewTestApplications()
+ suite.testUsers = testrig.NewTestUsers()
+ suite.testAccounts = testrig.NewTestAccounts()
+ suite.testAttachments = testrig.NewTestAttachments()
+ suite.testStatuses = testrig.NewTestStatuses()
+ suite.testTags = testrig.NewTestTags()
+ suite.testMentions = testrig.NewTestMentions()
+}
+
+func (suite *BasicTestSuite) SetupTest() {
+ suite.config = testrig.NewTestConfig()
+ suite.db = testrig.NewTestDB()
+ suite.log = testrig.NewTestLog()
+
+ testrig.StandardDBSetup(suite.db, suite.testAccounts)
+}
+
+func (suite *BasicTestSuite) TearDownTest() {
+ testrig.StandardDBTeardown(suite.db)
+}
+
+func (suite *BasicTestSuite) TestGetAccountByID() {
+ testAccount := suite.testAccounts["local_account_1"]
+
+ a := &gtsmodel.Account{}
+ err := suite.db.GetByID(context.Background(), testAccount.ID, a)
+ suite.NoError(err)
+}
+
+func TestBasicTestSuite(t *testing.T) {
+ suite.Run(t, new(BasicTestSuite))
+}
diff --git a/internal/db/pg/pg.go b/internal/db/bundb/bundb.go
index 0437baf02..49ed09cbd 100644
--- a/internal/db/pg/pg.go
+++ b/internal/db/bundb/bundb.go
@@ -16,12 +16,13 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
"crypto/tls"
"crypto/x509"
+ "database/sql"
"encoding/pem"
"errors"
"fmt"
@@ -29,14 +30,20 @@ import (
"strings"
"time"
- "github.com/go-pg/pg/extra/pgdebug"
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
+ "github.com/jackc/pgx/v4"
+ "github.com/jackc/pgx/v4/stdlib"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/uptrace/bun"
+ "github.com/uptrace/bun/dialect/pgdialect"
+)
+
+const (
+ dbTypePostgres = "postgres"
+ dbTypeSqlite = "sqlite"
)
var registerTables []interface{} = []interface{}{
@@ -44,8 +51,8 @@ var registerTables []interface{} = []interface{}{
&gtsmodel.StatusToTag{},
}
-// postgresService satisfies the DB interface
-type postgresService struct {
+// bunDBService satisfies the DB interface
+type bunDBService struct {
db.Account
db.Admin
db.Basic
@@ -55,130 +62,115 @@ type postgresService struct {
db.Mention
db.Notification
db.Relationship
+ db.Session
db.Status
db.Timeline
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
-// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
-// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
-func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
- for _, t := range registerTables {
- // https://pg.uptrace.dev/orm/many-to-many-relation/
- orm.RegisterTable(t)
- }
-
- opts, err := derivePGOptions(c)
- if err != nil {
- return nil, fmt.Errorf("could not create postgres service: %s", err)
- }
- log.Debugf("using pg options: %+v", opts)
-
- // create a connection
- pgCtx, cancel := context.WithCancel(ctx)
- conn := pg.Connect(opts).WithContext(pgCtx)
-
- // this will break the logfmt format we normally log in,
- // since we can't choose where pg outputs to and it defaults to
- // stdout. So use this option with care!
- if log.GetLevel() >= logrus.TraceLevel {
- conn.AddQueryHook(pgdebug.DebugHook{
- // Print all queries.
- Verbose: true,
- })
+// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
+// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
+func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
+ var sqldb *sql.DB
+ var conn *bun.DB
+
+ // depending on the database type we're trying to create, we need to use a different driver...
+ switch strings.ToLower(c.DBConfig.Type) {
+ case dbTypePostgres:
+ // POSTGRES
+ opts, err := deriveBunDBPGOptions(c)
+ if err != nil {
+ return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
+ }
+ sqldb = stdlib.OpenDB(*opts)
+ conn = bun.NewDB(sqldb, pgdialect.New())
+ case dbTypeSqlite:
+ // SQLITE
+ // TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite
+ default:
+ return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type))
}
// actually *begin* the connection so that we can tell if the db is there and listening
- if err := conn.Ping(ctx); err != nil {
- cancel()
+ if err := conn.Ping(); err != nil {
return nil, fmt.Errorf("db connection error: %s", err)
}
+ log.Info("connected to database")
- // print out discovered postgres version
- var version string
- if _, err = conn.QueryOneContext(ctx, pg.Scan(&version), "SELECT version()"); err != nil {
- cancel()
- return nil, fmt.Errorf("db connection error: %s", err)
+ for _, t := range registerTables {
+ // https://bun.uptrace.dev/orm/many-to-many-relation/
+ conn.RegisterModel(t)
}
- log.Infof("connected to postgres version: %s", version)
- ps := &postgresService{
+ ps := &bunDBService{
Account: &accountDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Admin: &adminDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Basic: &basicDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Domain: &domainDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Instance: &instanceDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Media: &mediaDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Mention: &mentionDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Notification: &notificationDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Relationship: &relationshipDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
+ },
+ Session: &sessionDB{
+ config: c,
+ conn: conn,
+ log: log,
},
Status: &statusDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Timeline: &timelineDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
config: c,
conn: conn,
log: log,
- cancel: cancel,
}
- // we can confidently return this useable postgres service now
+ // we can confidently return this useable service now
return ps, nil
}
@@ -186,9 +178,9 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
HANDY STUFF
*/
-// derivePGOptions takes an application config and returns either a ready-to-use *pg.Options
+// deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options
// with sensible defaults, or an error if it's not satisfied by the provided config.
-func derivePGOptions(c *config.Config) (*pg.Options, error) {
+func deriveBunDBPGOptions(c *config.Config) (*pgx.ConnConfig, error) {
if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres {
return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type)
}
@@ -266,18 +258,16 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
tlsConfig.RootCAs = certPool
}
- // We can rely on the pg library we're using to set
- // sensible defaults for everything we don't set here.
- options := &pg.Options{
- Addr: fmt.Sprintf("%s:%d", c.DBConfig.Address, c.DBConfig.Port),
- User: c.DBConfig.User,
- Password: c.DBConfig.Password,
- Database: c.DBConfig.Database,
- ApplicationName: c.ApplicationName,
- TLSConfig: tlsConfig,
- }
+ cfg, _ := pgx.ParseConfig("")
+ cfg.Host = c.DBConfig.Address
+ cfg.Port = uint16(c.DBConfig.Port)
+ cfg.User = c.DBConfig.User
+ cfg.Password = c.DBConfig.Password
+ cfg.TLSConfig = tlsConfig
+ cfg.Database = c.DBConfig.Database
+ cfg.PreferSimpleProtocol = true
- return options, nil
+ return cfg, nil
}
/*
@@ -286,9 +276,9 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
// TODO: move these to the type converter, it's bananas that they're here and not there
-func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
+func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
ogAccount := &gtsmodel.Account{}
- if err := ps.conn.Model(ogAccount).Where("id = ?", originAccountID).Select(); err != nil {
+ if err := ps.conn.NewSelect().Model(ogAccount).Where("id = ?", originAccountID).Scan(ctx); err != nil {
return nil, err
}
@@ -333,14 +323,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori
// match username + account, case insensitive
if local {
// local user -- should have a null domain
- err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("? IS NULL", pg.Ident("domain")).Select()
+ err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("? IS NULL", bun.Ident("domain")).Scan(ctx)
} else {
// remote user -- should have domain defined
- err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("LOWER(?) = LOWER(?)", pg.Ident("domain"), domain).Select()
+ err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("LOWER(?) = LOWER(?)", bun.Ident("domain"), domain).Scan(ctx)
}
if err != nil {
- if err == pg.ErrNoRows {
+ if err == sql.ErrNoRows {
// no result found for this username/domain so just don't include it as a mencho and carry on about our business
ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain)
continue
@@ -364,14 +354,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori
return menchies, nil
}
-func (ps *postgresService) TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
+func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
newTags := []*gtsmodel.Tag{}
for _, t := range tags {
tag := &gtsmodel.Tag{}
// we can use selectorinsert here to create the new tag if it doesn't exist already
// inserted will be true if this is a new tag we just created
- if err := ps.conn.Model(tag).Where("LOWER(?) = LOWER(?)", pg.Ident("name"), t).Select(); err != nil {
- if err == pg.ErrNoRows {
+ if err := ps.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil {
+ if err == sql.ErrNoRows {
// tag doesn't exist yet so populate it
newID, err := id.NewRandomULID()
if err != nil {
@@ -400,13 +390,13 @@ func (ps *postgresService) TagStringsToTags(tags []string, originAccountID strin
return newTags, nil
}
-func (ps *postgresService) EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
+func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
newEmojis := []*gtsmodel.Emoji{}
for _, e := range emojis {
emoji := &gtsmodel.Emoji{}
- err := ps.conn.Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Select()
+ err := ps.conn.NewSelect().Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Scan(ctx)
if err != nil {
- if err == pg.ErrNoRows {
+ if err == sql.ErrNoRows {
// no result found for this username/domain so just don't include it as an emoji and carry on about our business
ps.log.Debugf("no emoji found with shortcode %s, skipping it", e)
continue
diff --git a/internal/db/pg/pg_test.go b/internal/db/bundb/bundb_test.go
index c1e10abdf..b789375af 100644
--- a/internal/db/pg/pg_test.go
+++ b/internal/db/bundb/bundb_test.go
@@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg_test
+package bundb_test
import (
"github.com/sirupsen/logrus"
@@ -27,7 +27,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
-type PGStandardTestSuite struct {
+type BunDBStandardTestSuite struct {
// standard suite interfaces
suite.Suite
config *config.Config
diff --git a/internal/db/pg/domain.go b/internal/db/bundb/domain.go
index 4e9b2ab48..6aa2b8ffe 100644
--- a/internal/db/pg/domain.go
+++ b/internal/db/bundb/domain.go
@@ -16,48 +16,46 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
"net/url"
- "github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
+ "github.com/uptrace/bun"
)
type domainDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
-func (d *domainDB) IsDomainBlocked(domain string) (bool, db.Error) {
+func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) {
if domain == "" {
return false, nil
}
- blocked, err := d.conn.
+ q := d.conn.
+ NewSelect().
Model(&gtsmodel.DomainBlock{}).
Where("LOWER(domain) = LOWER(?)", domain).
- Exists()
+ Limit(1)
- err = processErrorResponse(err)
-
- return blocked, err
+ return exists(ctx, q)
}
-func (d *domainDB) AreDomainsBlocked(domains []string) (bool, db.Error) {
+func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {
// filter out any doubles
uniqueDomains := util.UniqueStrings(domains)
for _, domain := range uniqueDomains {
- if blocked, err := d.IsDomainBlocked(domain); err != nil {
+ if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil {
return false, err
} else if blocked {
return blocked, nil
@@ -68,16 +66,16 @@ func (d *domainDB) AreDomainsBlocked(domains []string) (bool, db.Error) {
return false, nil
}
-func (d *domainDB) IsURIBlocked(uri *url.URL) (bool, db.Error) {
+func (d *domainDB) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, db.Error) {
domain := uri.Hostname()
- return d.IsDomainBlocked(domain)
+ return d.IsDomainBlocked(ctx, domain)
}
-func (d *domainDB) AreURIsBlocked(uris []*url.URL) (bool, db.Error) {
+func (d *domainDB) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, db.Error) {
domains := []string{}
for _, uri := range uris {
domains = append(domains, uri.Hostname())
}
- return d.AreDomainsBlocked(domains)
+ return d.AreDomainsBlocked(ctx, domains)
}
diff --git a/internal/db/pg/instance.go b/internal/db/bundb/instance.go
index 968832ca5..f9364346e 100644
--- a/internal/db/pg/instance.go
+++ b/internal/db/bundb/instance.go
@@ -16,43 +16,50 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
- "github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type instanceDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
-func (i *instanceDB) CountInstanceUsers(domain string) (int, db.Error) {
- q := i.conn.Model(&[]*gtsmodel.Account{})
+func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
+ q := i.conn.
+ NewSelect().
+ Model(&[]*gtsmodel.Account{})
if domain == i.config.Host {
// if the domain is *this* domain, just count where the domain field is null
- q = q.Where("? IS NULL", pg.Ident("domain"))
+ q = q.Where("? IS NULL", bun.Ident("domain"))
} else {
q = q.Where("domain = ?", domain)
}
// don't count the instance account or suspended users
- q = q.Where("username != ?", domain).Where("? IS NULL", pg.Ident("suspended_at"))
+ q = q.
+ Where("username != ?", domain).
+ Where("? IS NULL", bun.Ident("suspended_at"))
- return q.Count()
+ count, err := q.Count(ctx)
+
+ return count, processErrorResponse(err)
}
-func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) {
- q := i.conn.Model(&[]*gtsmodel.Status{})
+func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
+ q := i.conn.
+ NewSelect().
+ Model(&[]*gtsmodel.Status{})
if domain == i.config.Host {
// if the domain is *this* domain, just count where local is true
@@ -63,30 +70,39 @@ func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) {
Where("account.domain = ?", domain)
}
- return q.Count()
+ count, err := q.Count(ctx)
+
+ return count, processErrorResponse(err)
}
-func (i *instanceDB) CountInstanceDomains(domain string) (int, db.Error) {
- q := i.conn.Model(&[]*gtsmodel.Instance{})
+func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
+ q := i.conn.
+ NewSelect().
+ Model(&[]*gtsmodel.Instance{})
if domain == i.config.Host {
// if the domain is *this* domain, just count other instances it knows about
// exclude domains that are blocked
- q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at"))
+ q = q.Where("domain != ?", domain).Where("? IS NULL", bun.Ident("suspended_at"))
} else {
// TODO: implement federated domain counting properly for remote domains
return 0, nil
}
- return q.Count()
+ count, err := q.Count(ctx)
+
+ return count, processErrorResponse(err)
}
-func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
+func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
i.log.Debug("GetAccountsForInstance")
accounts := []*gtsmodel.Account{}
- q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
+ q := i.conn.NewSelect().
+ Model(&accounts).
+ Where("domain = ?", domain).
+ Order("id DESC")
if maxID != "" {
q = q.Where("id < ?", maxID)
@@ -96,17 +112,7 @@ func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int)
q = q.Limit(limit)
}
- err := q.Select()
- if err != nil {
- if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries
- }
- return nil, err
- }
-
- if len(accounts) == 0 {
- return nil, db.ErrNoEntries
- }
+ err := processErrorResponse(q.Scan(ctx))
- return accounts, nil
+ return accounts, err
}
diff --git a/internal/db/pg/media.go b/internal/db/bundb/media.go
index 618030af3..04e55ca62 100644
--- a/internal/db/pg/media.go
+++ b/internal/db/bundb/media.go
@@ -16,38 +16,38 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type mediaDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
-func (m *mediaDB) newMediaQ(i interface{}) *orm.Query {
- return m.conn.Model(i).
+func (m *mediaDB) newMediaQ(i interface{}) *bun.SelectQuery {
+ return m.conn.
+ NewSelect().
+ Model(i).
Relation("Account")
}
-func (m *mediaDB) GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, db.Error) {
+func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, db.Error) {
attachment := &gtsmodel.MediaAttachment{}
q := m.newMediaQ(attachment).
Where("media_attachment.id = ?", id)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return attachment, err
}
diff --git a/internal/db/pg/mention.go b/internal/db/bundb/mention.go
index b31f07b67..a444f9b5f 100644
--- a/internal/db/pg/mention.go
+++ b/internal/db/bundb/mention.go
@@ -16,25 +16,23 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type mentionDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
cache cache.Cache
}
@@ -67,14 +65,16 @@ func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) {
return mention, true
}
-func (m *mentionDB) newMentionQ(i interface{}) *orm.Query {
- return m.conn.Model(i).
+func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
+ return m.conn.
+ NewSelect().
+ Model(i).
Relation("Status").
Relation("OriginAccount").
Relation("TargetAccount")
}
-func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) {
+func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
if mention, cached := m.mentionCached(id); cached {
return mention, nil
}
@@ -84,7 +84,7 @@ func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) {
q := m.newMentionQ(mention).
Where("mention.id = ?", id)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
if err == nil && mention != nil {
m.cacheMention(id, mention)
@@ -93,11 +93,11 @@ func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) {
return mention, err
}
-func (m *mentionDB) GetMentions(ids []string) ([]*gtsmodel.Mention, db.Error) {
+func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) {
mentions := []*gtsmodel.Mention{}
for _, i := range ids {
- mention, err := m.GetMention(i)
+ mention, err := m.GetMention(ctx, i)
if err != nil {
return nil, processErrorResponse(err)
}
diff --git a/internal/db/pg/notification.go b/internal/db/bundb/notification.go
index 281a76d85..1c30837ec 100644
--- a/internal/db/pg/notification.go
+++ b/internal/db/bundb/notification.go
@@ -16,25 +16,23 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type notificationDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
cache cache.Cache
}
@@ -67,14 +65,16 @@ func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification,
return notification, true
}
-func (n *notificationDB) newNotificationQ(i interface{}) *orm.Query {
- return n.conn.Model(i).
+func (n *notificationDB) newNotificationQ(i interface{}) *bun.SelectQuery {
+ return n.conn.
+ NewSelect().
+ Model(i).
Relation("OriginAccount").
Relation("TargetAccount").
Relation("Status")
}
-func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db.Error) {
+func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
if notification, cached := n.notificationCached(id); cached {
return notification, nil
}
@@ -84,7 +84,7 @@ func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db.
q := n.newNotificationQ(notification).
Where("notification.id = ?", id)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
if err == nil && notification != nil {
n.cacheNotification(id, notification)
@@ -93,10 +93,11 @@ func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db.
return notification, err
}
-func (n *notificationDB) GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
+func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
// begin by selecting just the IDs
notifIDs := []*gtsmodel.Notification{}
q := n.conn.
+ NewSelect().
Model(&notifIDs).
Column("id").
Where("target_account_id = ?", accountID).
@@ -114,7 +115,7 @@ func (n *notificationDB) GetNotifications(accountID string, limit int, maxID str
q = q.Limit(limit)
}
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
if err != nil {
return nil, err
}
@@ -123,7 +124,7 @@ func (n *notificationDB) GetNotifications(accountID string, limit int, maxID str
// reason for this is that for each notif, we can instead get it from our cache if it's cached
notifications := []*gtsmodel.Notification{}
for _, notifID := range notifIDs {
- notif, err := n.GetNotification(notifID.ID)
+ notif, err := n.GetNotification(ctx, notifID.ID)
errP := processErrorResponse(err)
if errP != nil {
return nil, errP
diff --git a/internal/db/pg/relationship.go b/internal/db/bundb/relationship.go
index 76bd50c76..ccc604baf 100644
--- a/internal/db/pg/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -16,44 +16,49 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
+ "database/sql"
"fmt"
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type relationshipDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
-func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *orm.Query {
- return r.conn.Model(block).
+func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery {
+ return r.conn.
+ NewSelect().
+ Model(block).
Relation("Account").
Relation("TargetAccount")
}
-func (r *relationshipDB) newFollowQ(follow interface{}) *orm.Query {
- return r.conn.Model(follow).
+func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery {
+ return r.conn.
+ NewSelect().
+ Model(follow).
Relation("Account").
Relation("TargetAccount")
}
-func (r *relationshipDB) IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
+func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
q := r.conn.
+ NewSelect().
Model(&gtsmodel.Block{}).
Where("account_id = ?", account1).
- Where("target_account_id = ?", account2)
+ Where("target_account_id = ?", account2).
+ Limit(1)
if eitherDirection {
q = q.
@@ -61,30 +66,36 @@ func (r *relationshipDB) IsBlocked(account1 string, account2 string, eitherDirec
Where("account_id = ?", account2)
}
- return q.Exists()
+ return exists(ctx, q)
}
-func (r *relationshipDB) GetBlock(account1 string, account2 string) (*gtsmodel.Block, db.Error) {
+func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
block := &gtsmodel.Block{}
q := r.newBlockQ(block).
Where("block.account_id = ?", account1).
Where("block.target_account_id = ?", account2)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return block, err
}
-func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
+func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
rel := &gtsmodel.Relationship{
ID: targetAccount,
}
// check if the requesting account follows the target account
follow := &gtsmodel.Follow{}
- if err := r.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
- if err != pg.ErrNoRows {
+ if err := r.conn.
+ NewSelect().
+ Model(follow).
+ Where("account_id = ?", requestingAccount).
+ Where("target_account_id = ?", targetAccount).
+ Limit(1).
+ Scan(ctx); err != nil {
+ if err != sql.ErrNoRows {
// a proper error
return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
}
@@ -100,75 +111,101 @@ func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount
}
// check if the target account follows the requesting account
- followedBy, err := r.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
+ count, err := r.conn.
+ NewSelect().
+ Model(&gtsmodel.Follow{}).
+ Where("account_id = ?", targetAccount).
+ Where("target_account_id = ?", requestingAccount).
+ Limit(1).
+ Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
}
- rel.FollowedBy = followedBy
+ rel.FollowedBy = count > 0
// check if the requesting account blocks the target account
- blocking, err := r.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
+ count, err = r.conn.NewSelect().
+ Model(&gtsmodel.Block{}).
+ Where("account_id = ?", requestingAccount).
+ Where("target_account_id = ?", targetAccount).
+ Limit(1).
+ Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
}
- rel.Blocking = blocking
+ rel.Blocking = count > 0
// check if the target account blocks the requesting account
- blockedBy, err := r.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
+ count, err = r.conn.
+ NewSelect().
+ Model(&gtsmodel.Block{}).
+ Where("account_id = ?", targetAccount).
+ Where("target_account_id = ?", requestingAccount).
+ Limit(1).
+ Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
- rel.BlockedBy = blockedBy
+ rel.BlockedBy = count > 0
// check if there's a pending following request from requesting account to target account
- requested, err := r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
+ count, err = r.conn.
+ NewSelect().
+ Model(&gtsmodel.FollowRequest{}).
+ Where("account_id = ?", requestingAccount).
+ Where("target_account_id = ?", targetAccount).
+ Limit(1).
+ Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
- rel.Requested = requested
+ rel.Requested = count > 0
return rel, nil
}
-func (r *relationshipDB) IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
+func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
q := r.conn.
+ NewSelect().
Model(&gtsmodel.Follow{}).
Where("account_id = ?", sourceAccount.ID).
- Where("target_account_id = ?", targetAccount.ID)
+ Where("target_account_id = ?", targetAccount.ID).
+ Limit(1)
- return q.Exists()
+ return exists(ctx, q)
}
-func (r *relationshipDB) IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
+func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
q := r.conn.
+ NewSelect().
Model(&gtsmodel.FollowRequest{}).
Where("account_id = ?", sourceAccount.ID).
Where("target_account_id = ?", targetAccount.ID)
- return q.Exists()
+ return exists(ctx, q)
}
-func (r *relationshipDB) IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
+func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
if account1 == nil || account2 == nil {
return false, nil
}
// make sure account 1 follows account 2
- f1, err := r.IsFollowing(account1, account2)
+ f1, err := r.IsFollowing(ctx, account1, account2)
if err != nil {
return false, processErrorResponse(err)
}
// make sure account 2 follows account 1
- f2, err := r.IsFollowing(account2, account1)
+ f2, err := r.IsFollowing(ctx, account2, account1)
if err != nil {
return false, processErrorResponse(err)
}
@@ -176,14 +213,16 @@ func (r *relationshipDB) IsMutualFollowing(account1 *gtsmodel.Account, account2
return f1 && f2, nil
}
-func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
+func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
// make sure the original follow request exists
fr := &gtsmodel.FollowRequest{}
- if err := r.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
- if err == pg.ErrMultiRows {
- return nil, db.ErrNoEntries
- }
- return nil, err
+ if err := r.conn.
+ NewSelect().
+ Model(fr).
+ Where("account_id = ?", originAccountID).
+ Where("target_account_id = ?", targetAccountID).
+ Scan(ctx); err != nil {
+ return nil, processErrorResponse(err)
}
// create a new follow to 'replace' the request with
@@ -195,82 +234,95 @@ func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccou
}
// if the follow already exists, just update the URI -- we don't need to do anything else
- if _, err := r.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
- return nil, err
+ if _, err := r.conn.
+ NewInsert().
+ Model(follow).
+ On("CONFLICT CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).
+ Exec(ctx); err != nil {
+ return nil, processErrorResponse(err)
}
// now remove the follow request
- if _, err := r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
- return nil, err
+ if _, err := r.conn.
+ NewDelete().
+ Model(&gtsmodel.FollowRequest{}).
+ Where("account_id = ?", originAccountID).
+ Where("target_account_id = ?", targetAccountID).
+ Exec(ctx); err != nil {
+ return nil, processErrorResponse(err)
}
return follow, nil
}
-func (r *relationshipDB) GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
+func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
followRequests := []*gtsmodel.FollowRequest{}
q := r.newFollowQ(&followRequests).
Where("target_account_id = ?", accountID)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return followRequests, err
}
-func (r *relationshipDB) GetAccountFollows(accountID string) ([]*gtsmodel.Follow, db.Error) {
+func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, db.Error) {
follows := []*gtsmodel.Follow{}
q := r.newFollowQ(&follows).
Where("account_id = ?", accountID)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return follows, err
}
-func (r *relationshipDB) CountAccountFollows(accountID string, localOnly bool) (int, db.Error) {
+func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
return r.conn.
+ NewSelect().
Model(&[]*gtsmodel.Follow{}).
Where("account_id = ?", accountID).
- Count()
+ Count(ctx)
}
-func (r *relationshipDB) GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
+func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
follows := []*gtsmodel.Follow{}
- q := r.conn.Model(&follows)
+ q := r.conn.
+ NewSelect().
+ Model(&follows)
if localOnly {
// for local accounts let's get where domain is null OR where domain is an empty string, just to be safe
- whereGroup := func(q *pg.Query) (*pg.Query, error) {
+ whereGroup := func(q *bun.SelectQuery) *bun.SelectQuery {
q = q.
- WhereOr("? IS NULL", pg.Ident("a.domain")).
+ WhereOr("? IS NULL", bun.Ident("a.domain")).
WhereOr("a.domain = ?", "")
- return q, nil
+ return q
}
q = q.ColumnExpr("follow.*").
Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)").
Where("follow.target_account_id = ?", accountID).
- WhereGroup(whereGroup)
+ WhereGroup(" AND ", whereGroup)
} else {
q = q.Where("target_account_id = ?", accountID)
}
- if err := q.Select(); err != nil {
- if err == pg.ErrNoRows {
+ if err := q.Scan(ctx); err != nil {
+ if err == sql.ErrNoRows {
return follows, nil
}
- return nil, err
+ return nil, processErrorResponse(err)
}
return follows, nil
}
-func (r *relationshipDB) CountAccountFollowedBy(accountID string, localOnly bool) (int, db.Error) {
+func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
return r.conn.
+ NewSelect().
Model(&[]*gtsmodel.Follow{}).
Where("target_account_id = ?", accountID).
- Count()
+ Count(ctx)
}
diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go
new file mode 100644
index 000000000..87e20673d
--- /dev/null
+++ b/internal/db/bundb/session.go
@@ -0,0 +1,85 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+ "crypto/rand"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/uptrace/bun"
+)
+
+type sessionDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
+ rs := new(gtsmodel.RouterSession)
+
+ q := s.conn.
+ NewSelect().
+ Model(rs).
+ Limit(1)
+
+ _, err := q.Exec(ctx)
+
+ err = processErrorResponse(err)
+
+ return rs, err
+}
+
+func (s *sessionDB) CreateSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
+ auth := make([]byte, 32)
+ crypt := make([]byte, 32)
+
+ if _, err := rand.Read(auth); err != nil {
+ return nil, err
+ }
+ if _, err := rand.Read(crypt); err != nil {
+ return nil, err
+ }
+
+ rid, err := id.NewULID()
+ if err != nil {
+ return nil, err
+ }
+
+ rs := &gtsmodel.RouterSession{
+ ID: rid,
+ Auth: auth,
+ Crypt: crypt,
+ }
+
+ q := s.conn.
+ NewInsert().
+ Model(rs)
+
+ _, err = q.Exec(ctx)
+
+ err = processErrorResponse(err)
+
+ return rs, err
+}
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
new file mode 100644
index 000000000..da8d8ca41
--- /dev/null
+++ b/internal/db/bundb/status.go
@@ -0,0 +1,375 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "container/list"
+ "context"
+ "errors"
+ "time"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type statusDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+ cache cache.Cache
+}
+
+func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
+ if s.cache == nil {
+ s.cache = cache.New()
+ }
+
+ if err := s.cache.Store(id, status); err != nil {
+ s.log.Panicf("statusDB: error storing in cache: %s", err)
+ }
+}
+
+func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
+ if s.cache == nil {
+ s.cache = cache.New()
+ return nil, false
+ }
+
+ sI, err := s.cache.Fetch(id)
+ if err != nil || sI == nil {
+ return nil, false
+ }
+
+ status, ok := sI.(*gtsmodel.Status)
+ if !ok {
+ s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
+ }
+
+ return status, true
+}
+
+func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
+ return s.conn.
+ NewSelect().
+ Model(status).
+ Relation("Attachments").
+ Relation("Tags").
+ Relation("Mentions").
+ Relation("Emojis").
+ Relation("Account").
+ Relation("InReplyToAccount").
+ Relation("BoostOfAccount").
+ Relation("CreatedWithApplication")
+}
+
+func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status {
+ if status.InReplyToID != "" && status.InReplyTo == nil {
+ if inReplyTo, cached := s.statusCached(status.InReplyToID); cached {
+ status.InReplyTo = inReplyTo
+ } else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil {
+ status.InReplyTo = inReplyTo
+ }
+ }
+
+ if status.BoostOfID != "" && status.BoostOf == nil {
+ if boostOf, cached := s.statusCached(status.BoostOfID); cached {
+ status.BoostOf = boostOf
+ } else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil {
+ status.BoostOf = boostOf
+ }
+ }
+
+ return status
+}
+
+func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
+ return s.conn.
+ NewSelect().
+ Model(faves).
+ Relation("Account").
+ Relation("TargetAccount").
+ Relation("Status")
+}
+
+func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
+ if status, cached := s.statusCached(id); cached {
+ return status, nil
+ }
+
+ status := new(gtsmodel.Status)
+
+ q := s.newStatusQ(status).
+ Where("status.id = ?", id)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ if err != nil {
+ return nil, err
+ }
+
+ if status != nil {
+ s.cacheStatus(id, status)
+ }
+
+ return s.getAttachedStatuses(ctx, status), err
+}
+
+func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
+ if status, cached := s.statusCached(uri); cached {
+ return status, nil
+ }
+
+ status := &gtsmodel.Status{}
+
+ q := s.newStatusQ(status).
+ Where("LOWER(status.uri) = LOWER(?)", uri)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ if err != nil {
+ return nil, err
+ }
+
+ if status != nil {
+ s.cacheStatus(uri, status)
+ }
+
+ return s.getAttachedStatuses(ctx, status), err
+}
+
+func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
+ if status, cached := s.statusCached(uri); cached {
+ return status, nil
+ }
+
+ status := &gtsmodel.Status{}
+
+ q := s.newStatusQ(status).
+ Where("LOWER(status.url) = LOWER(?)", uri)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ if err != nil {
+ return nil, err
+ }
+
+ if status != nil {
+ s.cacheStatus(uri, status)
+ }
+
+ return s.getAttachedStatuses(ctx, status), err
+}
+
+func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
+ transaction := func(ctx context.Context, tx bun.Tx) error {
+ // create links between this status and any emojis it uses
+ for _, i := range status.EmojiIDs {
+ if _, err := tx.NewInsert().Model(&gtsmodel.StatusToEmoji{
+ StatusID: status.ID,
+ EmojiID: i,
+ }).Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ // create links between this status and any tags it uses
+ for _, i := range status.TagIDs {
+ if _, err := tx.NewInsert().Model(&gtsmodel.StatusToTag{
+ StatusID: status.ID,
+ TagID: i,
+ }).Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ // change the status ID of the media attachments to the new status
+ for _, a := range status.Attachments {
+ a.StatusID = status.ID
+ a.UpdatedAt = time.Now()
+ if _, err := s.conn.NewUpdate().Model(a).
+ Where("id = ?", a.ID).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ _, err := tx.NewInsert().Model(status).Exec(ctx)
+ return err
+ }
+
+ return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction))
+}
+
+func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
+ parents := []*gtsmodel.Status{}
+ s.statusParent(ctx, status, &parents, onlyDirect)
+
+ return parents, nil
+}
+
+func (s *statusDB) statusParent(ctx context.Context, status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
+ if status.InReplyToID == "" {
+ return
+ }
+
+ parentStatus, err := s.GetStatusByID(ctx, status.InReplyToID)
+ if err == nil {
+ *foundStatuses = append(*foundStatuses, parentStatus)
+ }
+
+ if onlyDirect {
+ return
+ }
+
+ s.statusParent(ctx, parentStatus, foundStatuses, false)
+}
+
+func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
+ foundStatuses := &list.List{}
+ foundStatuses.PushFront(status)
+ s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID)
+
+ children := []*gtsmodel.Status{}
+ for e := foundStatuses.Front(); e != nil; e = e.Next() {
+ entry, ok := e.Value.(*gtsmodel.Status)
+ if !ok {
+ panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
+ }
+
+ // only append children, not the overall parent status
+ if entry.ID != status.ID {
+ children = append(children, entry)
+ }
+ }
+
+ return children, nil
+}
+
+func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
+ immediateChildren := []*gtsmodel.Status{}
+
+ q := s.conn.
+ NewSelect().
+ Model(&immediateChildren).
+ Where("in_reply_to_id = ?", status.ID)
+ if minID != "" {
+ q = q.Where("status.id > ?", minID)
+ }
+
+ if err := q.Scan(ctx); err != nil {
+ return
+ }
+
+ for _, child := range immediateChildren {
+ insertLoop:
+ for e := foundStatuses.Front(); e != nil; e = e.Next() {
+ entry, ok := e.Value.(*gtsmodel.Status)
+ if !ok {
+ panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
+ }
+
+ if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
+ foundStatuses.InsertAfter(child, e)
+ break insertLoop
+ }
+ }
+
+ // only do one loop if we only want direct children
+ if onlyDirect {
+ return
+ }
+ s.statusChildren(ctx, child, foundStatuses, false, minID)
+ }
+}
+
+func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
+ return s.conn.NewSelect().Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count(ctx)
+}
+
+func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
+ return s.conn.NewSelect().Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count(ctx)
+}
+
+func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
+ return s.conn.NewSelect().Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count(ctx)
+}
+
+func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ q := s.conn.
+ NewSelect().
+ Model(&gtsmodel.StatusFave{}).
+ Where("status_id = ?", status.ID).
+ Where("account_id = ?", accountID)
+
+ return exists(ctx, q)
+}
+
+func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ q := s.conn.
+ NewSelect().
+ Model(&gtsmodel.Status{}).
+ Where("boost_of_id = ?", status.ID).
+ Where("account_id = ?", accountID)
+
+ return exists(ctx, q)
+}
+
+func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ q := s.conn.
+ NewSelect().
+ Model(&gtsmodel.StatusMute{}).
+ Where("status_id = ?", status.ID).
+ Where("account_id = ?", accountID)
+
+ return exists(ctx, q)
+}
+
+func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ q := s.conn.
+ NewSelect().
+ Model(&gtsmodel.StatusBookmark{}).
+ Where("status_id = ?", status.ID).
+ Where("account_id = ?", accountID)
+
+ return exists(ctx, q)
+}
+
+func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
+ faves := []*gtsmodel.StatusFave{}
+
+ q := s.newFaveQ(&faves).
+ Where("status_id = ?", status.ID)
+
+ err := processErrorResponse(q.Scan(ctx))
+ return faves, err
+}
+
+func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
+ reblogs := []*gtsmodel.Status{}
+
+ q := s.newStatusQ(&reblogs).
+ Where("boost_of_id = ?", status.ID)
+
+ err := processErrorResponse(q.Scan(ctx))
+ return reblogs, err
+}
diff --git a/internal/db/pg/status_test.go b/internal/db/bundb/status_test.go
index 8a185757c..513000577 100644
--- a/internal/db/pg/status_test.go
+++ b/internal/db/bundb/status_test.go
@@ -16,9 +16,10 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg_test
+package bundb_test
import (
+ "context"
"fmt"
"testing"
"time"
@@ -28,7 +29,7 @@ import (
)
type StatusTestSuite struct {
- PGStandardTestSuite
+ BunDBStandardTestSuite
}
func (suite *StatusTestSuite) SetupSuite() {
@@ -56,8 +57,9 @@ func (suite *StatusTestSuite) TearDownTest() {
}
func (suite *StatusTestSuite) TestGetStatusByID() {
- status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID)
+ status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_1_status_1"].ID)
if err != nil {
+ fmt.Println(err.Error())
suite.FailNow(err.Error())
}
suite.NotNil(status)
@@ -70,7 +72,7 @@ func (suite *StatusTestSuite) TestGetStatusByID() {
}
func (suite *StatusTestSuite) TestGetStatusByURI() {
- status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
+ status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
if err != nil {
suite.FailNow(err.Error())
}
@@ -84,7 +86,7 @@ func (suite *StatusTestSuite) TestGetStatusByURI() {
}
func (suite *StatusTestSuite) TestGetStatusWithExtras() {
- status, err := suite.db.GetStatusByID(suite.testStatuses["admin_account_status_1"].ID)
+ status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["admin_account_status_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
@@ -97,7 +99,7 @@ func (suite *StatusTestSuite) TestGetStatusWithExtras() {
}
func (suite *StatusTestSuite) TestGetStatusWithMention() {
- status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_2_status_5"].ID)
+ status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_2_status_5"].ID)
if err != nil {
suite.FailNow(err.Error())
}
@@ -112,18 +114,18 @@ func (suite *StatusTestSuite) TestGetStatusWithMention() {
func (suite *StatusTestSuite) TestGetStatusTwice() {
before1 := time.Now()
- _, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
+ _, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
suite.NoError(err)
after1 := time.Now()
duration1 := after1.Sub(before1)
- fmt.Println(duration1.Nanoseconds())
+ fmt.Println(duration1.Milliseconds())
before2 := time.Now()
- _, err = suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
+ _, err = suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
suite.NoError(err)
after2 := time.Now()
duration2 := after2.Sub(before2)
- fmt.Println(duration2.Nanoseconds())
+ fmt.Println(duration2.Milliseconds())
// second retrieval should be several orders faster since it will be cached now
suite.Less(duration2, duration1)
diff --git a/internal/db/pg/timeline.go b/internal/db/bundb/timeline.go
index fa8b07aab..b62ad4c50 100644
--- a/internal/db/pg/timeline.go
+++ b/internal/db/bundb/timeline.go
@@ -16,43 +16,35 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
+ "database/sql"
"sort"
- "github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type timelineDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
-func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
+func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
- q := t.conn.Model(&statuses)
+ q := t.conn.
+ NewSelect().
+ Model(&statuses)
q = q.ColumnExpr("status.*").
// Find out who accountID follows.
Join("LEFT JOIN follows AS f ON f.target_account_id = status.account_id").
- // Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows,
- // OR statuses posted by accountID itself (since a user should be able to see their own statuses).
- //
- // This is equivalent to something like WHERE ... AND (... OR ...)
- // See: https://pg.uptrace.dev/queries/#select
- WhereGroup(func(q *pg.Query) (*pg.Query, error) {
- q = q.WhereOr("f.account_id = ?", accountID).
- WhereOr("status.account_id = ?", accountID)
- return q, nil
- }).
// Sort by highest ID (newest) to lowest ID (oldest)
Order("status.id DESC")
@@ -81,29 +73,32 @@ func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID str
q = q.Limit(limit)
}
- err := q.Select()
- if err != nil {
- if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries
- }
- return nil, err
+ // Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows,
+ // OR statuses posted by accountID itself (since a user should be able to see their own statuses).
+ //
+ // This is equivalent to something like WHERE ... AND (... OR ...)
+ // See: https://bun.uptrace.dev/guide/queries.html#select
+ whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
+ return q.
+ WhereOr("f.account_id = ?", accountID).
+ WhereOr("status.account_id = ?", accountID)
}
- if len(statuses) == 0 {
- return nil, db.ErrNoEntries
- }
+ q = q.WhereGroup(" AND ", whereGroup)
- return statuses, nil
+ return statuses, processErrorResponse(q.Scan(ctx))
}
-func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
+func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
- q := t.conn.Model(&statuses).
+ q := t.conn.
+ NewSelect().
+ Model(&statuses).
Where("visibility = ?", gtsmodel.VisibilityPublic).
- Where("? IS NULL", pg.Ident("in_reply_to_id")).
- Where("? IS NULL", pg.Ident("in_reply_to_uri")).
- Where("? IS NULL", pg.Ident("boost_of_id")).
+ Where("? IS NULL", bun.Ident("in_reply_to_id")).
+ Where("? IS NULL", bun.Ident("in_reply_to_uri")).
+ Where("? IS NULL", bun.Ident("boost_of_id")).
Order("status.id DESC")
if maxID != "" {
@@ -126,28 +121,18 @@ func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID s
q = q.Limit(limit)
}
- err := q.Select()
- if err != nil {
- if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries
- }
- return nil, err
- }
-
- if len(statuses) == 0 {
- return nil, db.ErrNoEntries
- }
-
- return statuses, nil
+ return statuses, processErrorResponse(q.Scan(ctx))
}
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
-func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
+func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
faves := []*gtsmodel.StatusFave{}
- fq := t.conn.Model(&faves).
+ fq := t.conn.
+ NewSelect().
+ Model(&faves).
Where("account_id = ?", accountID).
Order("id DESC")
@@ -163,9 +148,9 @@ func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID stri
fq = fq.Limit(limit)
}
- err := fq.Select()
+ err := fq.Scan(ctx)
if err != nil {
- if err == pg.ErrNoRows {
+ if err == sql.ErrNoRows {
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
@@ -185,9 +170,13 @@ func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID stri
}
statuses := []*gtsmodel.Status{}
- err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
+ err = t.conn.
+ NewSelect().
+ Model(&statuses).
+ Where("id IN (?)", bun.In(in)).
+ Scan(ctx)
if err != nil {
- if err == pg.ErrNoRows {
+ if err == sql.ErrNoRows {
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go
new file mode 100644
index 000000000..115d18de2
--- /dev/null
+++ b/internal/db/bundb/util.go
@@ -0,0 +1,78 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+ "strings"
+
+ "database/sql"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/uptrace/bun"
+)
+
+// processErrorResponse parses the given error and returns an appropriate DBError.
+func processErrorResponse(err error) db.Error {
+ switch err {
+ case nil:
+ return nil
+ case sql.ErrNoRows:
+ return db.ErrNoEntries
+ default:
+ if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
+ return db.ErrAlreadyExists
+ }
+ return err
+ }
+}
+
+func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
+ count, err := q.Count(ctx)
+
+ exists := count != 0
+
+ err = processErrorResponse(err)
+
+ if err != nil {
+ if err == db.ErrNoEntries {
+ return false, nil
+ }
+ return false, err
+ }
+
+ return exists, nil
+}
+
+func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
+ count, err := q.Count(ctx)
+
+ notExists := count == 0
+
+ err = processErrorResponse(err)
+
+ if err != nil {
+ if err == db.ErrNoEntries {
+ return true, nil
+ }
+ return false, err
+ }
+
+ return notExists, nil
+}
diff --git a/internal/db/db.go b/internal/db/db.go
index d6ac883e4..ec94fcfe7 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -19,6 +19,8 @@
package db
import (
+ "context"
+
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@@ -38,6 +40,7 @@ type DB interface {
Mention
Notification
Relationship
+ Session
Status
Timeline
@@ -52,7 +55,7 @@ type DB interface {
//
// Note: this func doesn't/shouldn't do any manipulation of the accounts in the DB, it's just for checking
// if they exist in the db and conveniently returning them if they do.
- MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error)
+ MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error)
// TagStringsToTags takes a slice of deduplicated, lowercase tags in the form "somehashtag", which have been
// used in a status. It takes the id of the account that wrote the status, and the id of the status itself, and then
@@ -61,7 +64,7 @@ type DB interface {
//
// Note: this func doesn't/shouldn't do any manipulation of the tags in the DB, it's just for checking
// if they exist in the db already, and conveniently returning them, or creating new tag structs.
- TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error)
+ TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error)
// EmojiStringsToEmojis takes a slice of deduplicated, lowercase emojis in the form ":emojiname:", which have been
// used in a status. It takes the id of the account that wrote the status, and the id of the status itself, and then
@@ -69,5 +72,5 @@ type DB interface {
//
// Note: this func doesn't/shouldn't do any manipulation of the emoji in the DB, it's just for checking
// if they exist in the db and conveniently returning them if they do.
- EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error)
+ EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error)
}
diff --git a/internal/db/domain.go b/internal/db/domain.go
index a6583c80c..df50a6770 100644
--- a/internal/db/domain.go
+++ b/internal/db/domain.go
@@ -18,19 +18,22 @@
package db
-import "net/url"
+import (
+ "context"
+ "net/url"
+)
// Domain contains DB functions related to domains and domain blocks.
type Domain interface {
// IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`).
- IsDomainBlocked(domain string) (bool, Error)
+ IsDomainBlocked(ctx context.Context, domain string) (bool, Error)
// AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found.
- AreDomainsBlocked(domains []string) (bool, Error)
+ AreDomainsBlocked(ctx context.Context, domains []string) (bool, Error)
// IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`).
- IsURIBlocked(uri *url.URL) (bool, Error)
+ IsURIBlocked(ctx context.Context, uri *url.URL) (bool, Error)
// AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found.
- AreURIsBlocked(uris []*url.URL) (bool, Error)
+ AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, Error)
}
diff --git a/internal/db/instance.go b/internal/db/instance.go
index 1f7c83e4f..dcd978a81 100644
--- a/internal/db/instance.go
+++ b/internal/db/instance.go
@@ -18,19 +18,23 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Instance contains functions for instance-level actions (counting instance users etc.).
type Instance interface {
// CountInstanceUsers returns the number of known accounts registered with the given domain.
- CountInstanceUsers(domain string) (int, Error)
+ CountInstanceUsers(ctx context.Context, domain string) (int, Error)
// CountInstanceStatuses returns the number of known statuses posted from the given domain.
- CountInstanceStatuses(domain string) (int, Error)
+ CountInstanceStatuses(ctx context.Context, domain string) (int, Error)
// CountInstanceDomains returns the number of known instances known that the given domain federates with.
- CountInstanceDomains(domain string) (int, Error)
+ CountInstanceDomains(ctx context.Context, domain string) (int, Error)
// GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID.
- GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)
+ GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)
}
diff --git a/internal/db/media.go b/internal/db/media.go
index db4db3411..b779dd276 100644
--- a/internal/db/media.go
+++ b/internal/db/media.go
@@ -18,10 +18,14 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Media contains functions related to creating/getting/removing media attachments.
type Media interface {
// GetAttachmentByID gets a single attachment by its ID
- GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, Error)
+ GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, Error)
}
diff --git a/internal/db/mention.go b/internal/db/mention.go
index cb1c56dc1..b9b45546a 100644
--- a/internal/db/mention.go
+++ b/internal/db/mention.go
@@ -18,13 +18,17 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Mention contains functions for getting/creating mentions in the database.
type Mention interface {
// GetMention gets a single mention by ID
- GetMention(id string) (*gtsmodel.Mention, Error)
+ GetMention(ctx context.Context, id string) (*gtsmodel.Mention, Error)
// GetMentions gets multiple mentions.
- GetMentions(ids []string) ([]*gtsmodel.Mention, Error)
+ GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, Error)
}
diff --git a/internal/db/notification.go b/internal/db/notification.go
index 326f0f149..09c17f031 100644
--- a/internal/db/notification.go
+++ b/internal/db/notification.go
@@ -18,14 +18,18 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Notification contains functions for creating and getting notifications.
type Notification interface {
// GetNotifications returns a slice of notifications that pertain to the given accountID.
//
// Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest).
- GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
+ GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
// GetNotification returns one notification according to its id.
- GetNotification(id string) (*gtsmodel.Notification, Error)
+ GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, Error)
}
diff --git a/internal/db/pg/basic.go b/internal/db/pg/basic.go
deleted file mode 100644
index 6e76b4450..000000000
--- a/internal/db/pg/basic.go
+++ /dev/null
@@ -1,205 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see <http://www.gnu.org/licenses/>.
-*/
-
-package pg
-
-import (
- "context"
- "errors"
- "fmt"
- "strings"
-
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
- "github.com/sirupsen/logrus"
- "github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
-)
-
-type basicDB struct {
- config *config.Config
- conn *pg.DB
- log *logrus.Logger
- cancel context.CancelFunc
-}
-
-func (b *basicDB) Put(i interface{}) db.Error {
- _, err := b.conn.Model(i).Insert(i)
- if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
- return db.ErrAlreadyExists
- }
- return err
-}
-
-func (b *basicDB) GetByID(id string, i interface{}) db.Error {
- if err := b.conn.Model(i).Where("id = ?", id).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries
- }
- return err
-
- }
- return nil
-}
-
-func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.Error {
- if len(where) == 0 {
- return errors.New("no queries provided")
- }
-
- q := b.conn.Model(i)
- for _, w := range where {
-
- if w.Value == nil {
- q = q.Where("? IS NULL", pg.Ident(w.Key))
- } else {
- if w.CaseInsensitive {
- q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
- } else {
- q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
- }
- }
- }
-
- if err := q.Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries
- }
- return err
- }
- return nil
-}
-
-func (b *basicDB) GetAll(i interface{}) db.Error {
- if err := b.conn.Model(i).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries
- }
- return err
- }
- return nil
-}
-
-func (b *basicDB) DeleteByID(id string, i interface{}) db.Error {
- if _, err := b.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
- // if there are no rows *anyway* then that's fine
- // just return err if there's an actual error
- if err != pg.ErrNoRows {
- return err
- }
- }
- return nil
-}
-
-func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.Error {
- if len(where) == 0 {
- return errors.New("no queries provided")
- }
-
- q := b.conn.Model(i)
- for _, w := range where {
- q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
- }
-
- if _, err := q.Delete(); err != nil {
- // if there are no rows *anyway* then that's fine
- // just return err if there's an actual error
- if err != pg.ErrNoRows {
- return err
- }
- }
- return nil
-}
-
-func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.Error {
- if _, err := b.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries
- }
- return err
- }
- return nil
-}
-
-func (b *basicDB) UpdateByID(id string, i interface{}) db.Error {
- if _, err := b.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries
- }
- return err
- }
- return nil
-}
-
-func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.Error {
- _, err := b.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
- return err
-}
-
-func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.Error {
- q := b.conn.Model(i)
-
- for _, w := range where {
- if w.Value == nil {
- q = q.Where("? IS NULL", pg.Ident(w.Key))
- } else {
- if w.CaseInsensitive {
- q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
- } else {
- q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
- }
- }
- }
-
- q = q.Set("? = ?", pg.Safe(key), value)
-
- _, err := q.Update()
-
- return err
-}
-
-func (b *basicDB) CreateTable(i interface{}) db.Error {
- return b.conn.Model(i).CreateTable(&orm.CreateTableOptions{
- IfNotExists: true,
- })
-}
-
-func (b *basicDB) DropTable(i interface{}) db.Error {
- return b.conn.Model(i).DropTable(&orm.DropTableOptions{
- IfExists: true,
- })
-}
-
-func (b *basicDB) RegisterTable(i interface{}) db.Error {
- orm.RegisterTable(i)
- return nil
-}
-
-func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
- return b.conn.Ping(ctx)
-}
-
-func (b *basicDB) Stop(ctx context.Context) db.Error {
- b.log.Info("closing db connection")
- if err := b.conn.Close(); err != nil {
- // only cancel if there's a problem closing the db
- b.cancel()
- return err
- }
- return nil
-}
diff --git a/internal/db/pg/status.go b/internal/db/pg/status.go
deleted file mode 100644
index 99790428e..000000000
--- a/internal/db/pg/status.go
+++ /dev/null
@@ -1,318 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see <http://www.gnu.org/licenses/>.
-*/
-
-package pg
-
-import (
- "container/list"
- "context"
- "errors"
- "time"
-
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
- "github.com/sirupsen/logrus"
- "github.com/superseriousbusiness/gotosocial/internal/cache"
- "github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
-)
-
-type statusDB struct {
- config *config.Config
- conn *pg.DB
- log *logrus.Logger
- cancel context.CancelFunc
- cache cache.Cache
-}
-
-func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
- if s.cache == nil {
- s.cache = cache.New()
- }
-
- if err := s.cache.Store(id, status); err != nil {
- s.log.Panicf("statusDB: error storing in cache: %s", err)
- }
-}
-
-func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
- if s.cache == nil {
- s.cache = cache.New()
- return nil, false
- }
-
- sI, err := s.cache.Fetch(id)
- if err != nil || sI == nil {
- return nil, false
- }
-
- status, ok := sI.(*gtsmodel.Status)
- if !ok {
- s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
- }
-
- return status, true
-}
-
-func (s *statusDB) newStatusQ(status interface{}) *orm.Query {
- return s.conn.Model(status).
- Relation("Attachments").
- Relation("Tags").
- Relation("Mentions").
- Relation("Emojis").
- Relation("Account").
- Relation("InReplyTo").
- Relation("InReplyToAccount").
- Relation("BoostOf").
- Relation("BoostOfAccount").
- Relation("CreatedWithApplication")
-}
-
-func (s *statusDB) newFaveQ(faves interface{}) *orm.Query {
- return s.conn.Model(faves).
- Relation("Account").
- Relation("TargetAccount").
- Relation("Status")
-}
-
-func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.statusCached(id); cached {
- return status, nil
- }
-
- status := &gtsmodel.Status{}
-
- q := s.newStatusQ(status).
- Where("status.id = ?", id)
-
- err := processErrorResponse(q.Select())
-
- if err == nil && status != nil {
- s.cacheStatus(id, status)
- }
-
- return status, err
-}
-
-func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.statusCached(uri); cached {
- return status, nil
- }
-
- status := &gtsmodel.Status{}
-
- q := s.newStatusQ(status).
- Where("LOWER(status.uri) = LOWER(?)", uri)
-
- err := processErrorResponse(q.Select())
-
- if err == nil && status != nil {
- s.cacheStatus(uri, status)
- }
-
- return status, err
-}
-
-func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.statusCached(uri); cached {
- return status, nil
- }
-
- status := &gtsmodel.Status{}
-
- q := s.newStatusQ(status).
- Where("LOWER(status.url) = LOWER(?)", uri)
-
- err := processErrorResponse(q.Select())
-
- if err == nil && status != nil {
- s.cacheStatus(uri, status)
- }
-
- return status, err
-}
-
-func (s *statusDB) PutStatus(status *gtsmodel.Status) db.Error {
- transaction := func(tx *pg.Tx) error {
- // create links between this status and any emojis it uses
- for _, i := range status.EmojiIDs {
- if _, err := tx.Model(&gtsmodel.StatusToEmoji{
- StatusID: status.ID,
- EmojiID: i,
- }).Insert(); err != nil {
- return err
- }
- }
-
- // create links between this status and any tags it uses
- for _, i := range status.TagIDs {
- if _, err := tx.Model(&gtsmodel.StatusToTag{
- StatusID: status.ID,
- TagID: i,
- }).Insert(); err != nil {
- return err
- }
- }
-
- // change the status ID of the media attachments to the new status
- for _, a := range status.Attachments {
- a.StatusID = status.ID
- a.UpdatedAt = time.Now()
- if _, err := s.conn.Model(a).
- Where("id = ?", a.ID).
- Update(); err != nil {
- return err
- }
- }
-
- _, err := tx.Model(status).Insert()
- return err
- }
-
- return processErrorResponse(s.conn.RunInTransaction(context.Background(), transaction))
-}
-
-func (s *statusDB) GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
- parents := []*gtsmodel.Status{}
- s.statusParent(status, &parents, onlyDirect)
-
- return parents, nil
-}
-
-func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
- if status.InReplyToID == "" {
- return
- }
-
- parentStatus, err := s.GetStatusByID(status.InReplyToID)
- if err == nil {
- *foundStatuses = append(*foundStatuses, parentStatus)
- }
-
- if onlyDirect {
- return
- }
-
- s.statusParent(parentStatus, foundStatuses, false)
-}
-
-func (s *statusDB) GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
- foundStatuses := &list.List{}
- foundStatuses.PushFront(status)
- s.statusChildren(status, foundStatuses, onlyDirect, minID)
-
- children := []*gtsmodel.Status{}
- for e := foundStatuses.Front(); e != nil; e = e.Next() {
- entry, ok := e.Value.(*gtsmodel.Status)
- if !ok {
- panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
- }
-
- // only append children, not the overall parent status
- if entry.ID != status.ID {
- children = append(children, entry)
- }
- }
-
- return children, nil
-}
-
-func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
- immediateChildren := []*gtsmodel.Status{}
-
- q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
- if minID != "" {
- q = q.Where("status.id > ?", minID)
- }
-
- if err := q.Select(); err != nil {
- return
- }
-
- for _, child := range immediateChildren {
- insertLoop:
- for e := foundStatuses.Front(); e != nil; e = e.Next() {
- entry, ok := e.Value.(*gtsmodel.Status)
- if !ok {
- panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
- }
-
- if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
- foundStatuses.InsertAfter(child, e)
- break insertLoop
- }
- }
-
- // only do one loop if we only want direct children
- if onlyDirect {
- return
- }
- s.statusChildren(child, foundStatuses, false, minID)
- }
-}
-
-func (s *statusDB) CountStatusReplies(status *gtsmodel.Status) (int, db.Error) {
- return s.conn.Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
-}
-
-func (s *statusDB) CountStatusReblogs(status *gtsmodel.Status) (int, db.Error) {
- return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
-}
-
-func (s *statusDB) CountStatusFaves(status *gtsmodel.Status) (int, db.Error) {
- return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
-}
-
-func (s *statusDB) IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
- return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
-}
-
-func (s *statusDB) IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
- return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
-}
-
-func (s *statusDB) IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
- return s.conn.Model(&gtsmodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
-}
-
-func (s *statusDB) IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
- return s.conn.Model(&gtsmodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
-}
-
-func (s *statusDB) GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
- faves := []*gtsmodel.StatusFave{}
-
- q := s.newFaveQ(&faves).
- Where("status_id = ?", status.ID)
-
- err := processErrorResponse(q.Select())
-
- return faves, err
-}
-
-func (s *statusDB) GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
- reblogs := []*gtsmodel.Status{}
-
- q := s.newStatusQ(&reblogs).
- Where("boost_of_id = ?", status.ID)
-
- err := processErrorResponse(q.Select())
-
- return reblogs, err
-}
diff --git a/internal/db/pg/util.go b/internal/db/pg/util.go
deleted file mode 100644
index 17c09b720..000000000
--- a/internal/db/pg/util.go
+++ /dev/null
@@ -1,25 +0,0 @@
-package pg
-
-import (
- "strings"
-
- "github.com/go-pg/pg/v10"
- "github.com/superseriousbusiness/gotosocial/internal/db"
-)
-
-// processErrorResponse parses the given error and returns an appropriate DBError.
-func processErrorResponse(err error) db.Error {
- switch err {
- case nil:
- return nil
- case pg.ErrNoRows:
- return db.ErrNoEntries
- case pg.ErrMultiRows:
- return db.ErrMultipleEntries
- default:
- if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
- return db.ErrAlreadyExists
- }
- return err
- }
-}
diff --git a/internal/db/relationship.go b/internal/db/relationship.go
index 85f64d72b..804526425 100644
--- a/internal/db/relationship.go
+++ b/internal/db/relationship.go
@@ -18,54 +18,58 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Relationship contains functions for getting or modifying the relationship between two accounts.
type Relationship interface {
// IsBlocked checks whether account 1 has a block in place against block2.
// If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1.
- IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, Error)
+ IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, Error)
// GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't.
//
// Because this is slower than Blocked, only use it if you need the actual Block struct for some reason,
// not if you're just checking for the existence of a block.
- GetBlock(account1 string, account2 string) (*gtsmodel.Block, Error)
+ GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, Error)
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
- GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
+ GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
// IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
- IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
+ IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
- IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
+ IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
- IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
+ IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
// In other words, it should create the follow, and delete the existing follow request.
//
// It will return the newly created follow for further processing.
- AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
+ AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
// GetAccountFollowRequests returns all follow requests targeting the given account.
- GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, Error)
+ GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, Error)
// GetAccountFollows returns a slice of follows owned by the given accountID.
- GetAccountFollows(accountID string) ([]*gtsmodel.Follow, Error)
+ GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, Error)
// CountAccountFollows returns the amount of accounts that the given accountID is following.
//
// If localOnly is set to true, then only follows from *this instance* will be returned.
- CountAccountFollows(accountID string, localOnly bool) (int, Error)
+ CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, Error)
// GetAccountFollowedBy fetches follows that target given accountID.
//
// If localOnly is set to true, then only follows from *this instance* will be returned.
- GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, Error)
+ GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, Error)
// CountAccountFollowedBy returns the amounts that the given ID is followed by.
- CountAccountFollowedBy(accountID string, localOnly bool) (int, Error)
+ CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, Error)
}
diff --git a/internal/federation/dereferencing/blocked.go b/internal/db/session.go
index c8a4c6ade..ae13dccce 100644
--- a/internal/federation/dereferencing/blocked.go
+++ b/internal/db/session.go
@@ -16,26 +16,16 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package dereferencing
+package db
import (
- "github.com/superseriousbusiness/gotosocial/internal/db"
+ "context"
+
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (d *deref) blockedDomain(host string) (bool, error) {
- b := &gtsmodel.DomainBlock{}
- err := d.db.GetWhere([]db.Where{{Key: "domain", Value: host, CaseInsensitive: true}}, b)
- if err == nil {
- // block exists
- return true, nil
- }
-
- if err == db.ErrNoEntries {
- // there are no entries so there's no block
- return false, nil
- }
-
- // there's an actual error
- return false, err
+// Session handles getting/creation of router sessions.
+type Session interface {
+ GetSession(ctx context.Context) (*gtsmodel.RouterSession, Error)
+ CreateSession(ctx context.Context) (*gtsmodel.RouterSession, Error)
}
diff --git a/internal/db/status.go b/internal/db/status.go
index 9d206c198..7430433c4 100644
--- a/internal/db/status.go
+++ b/internal/db/status.go
@@ -18,58 +18,62 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses.
type Status interface {
// GetStatusByID returns one status from the database, with all rel fields populated (if possible).
- GetStatusByID(id string) (*gtsmodel.Status, Error)
+ GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error)
// GetStatusByURI returns one status from the database, with all rel fields populated (if possible).
- GetStatusByURI(uri string) (*gtsmodel.Status, Error)
+ GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error)
// GetStatusByURL returns one status from the database, with all rel fields populated (if possible).
- GetStatusByURL(uri string) (*gtsmodel.Status, Error)
+ GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error)
// PutStatus stores one status in the database.
- PutStatus(status *gtsmodel.Status) Error
+ PutStatus(ctx context.Context, status *gtsmodel.Status) Error
// CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong
- CountStatusReplies(status *gtsmodel.Status) (int, Error)
+ CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, Error)
// CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong
- CountStatusReblogs(status *gtsmodel.Status) (int, Error)
+ CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, Error)
// CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong
- CountStatusFaves(status *gtsmodel.Status) (int, Error)
+ CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, Error)
// GetStatusParents gets the parent statuses of a given status.
//
// If onlyDirect is true, only the immediate parent will be returned.
- GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error)
+ GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error)
// GetStatusChildren gets the child statuses of a given status.
//
// If onlyDirect is true, only the immediate children will be returned.
- GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error)
+ GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error)
// IsStatusFavedBy checks if a given status has been faved by a given account ID
- IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+ IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID
- IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+ IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusMutedBy checks if a given status has been muted by a given account ID
- IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+ IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
- IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+ IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// GetStatusFaves returns a slice of faves/likes of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
- GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error)
+ GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error)
// GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
- GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, Error)
+ GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, Error)
}
diff --git a/internal/db/timeline.go b/internal/db/timeline.go
index 74aa5c781..83fb3a959 100644
--- a/internal/db/timeline.go
+++ b/internal/db/timeline.go
@@ -18,20 +18,24 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Timeline contains functionality for retrieving home/public/faved etc timelines for an account.
type Timeline interface {
// GetHomeTimeline returns a slice of statuses from accounts that are followed by the given account id.
//
// Statuses should be returned in descending order of when they were created (newest first).
- GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
+ GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
// GetPublicTimeline fetches the account's PUBLIC timeline -- ie., posts and replies that are public.
// It will use the given filters and try to return as many statuses as possible up to the limit.
//
// Statuses should be returned in descending order of when they were created (newest first).
- GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
+ GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
// GetFavedTimeline fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved.
// It will use the given filters and try to return as many statuses as possible up to the limit.
@@ -40,5 +44,5 @@ type Timeline interface {
// In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created.
//
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
- GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
+ GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
}
diff --git a/internal/federation/authenticate.go b/internal/federation/authenticate.go
index 699691ca6..81ac84544 100644
--- a/internal/federation/authenticate.go
+++ b/internal/federation/authenticate.go
@@ -148,7 +148,7 @@ func (f *federator) AuthenticateFederatedRequest(ctx context.Context, requestedU
// LOCAL ACCOUNT REQUEST
// the request is coming from INSIDE THE HOUSE so skip the remote dereferencing
l.Tracef("proceeding without dereference for local public key %s", requestingPublicKeyID)
- if err := f.db.GetWhere([]db.Where{{Key: "public_key_uri", Value: requestingPublicKeyID.String()}}, requestingLocalAccount); err != nil {
+ if err := f.db.GetWhere(ctx, []db.Where{{Key: "public_key_uri", Value: requestingPublicKeyID.String()}}, requestingLocalAccount); err != nil {
return nil, false, fmt.Errorf("couldn't get local account with public key uri %s from the database: %s", requestingPublicKeyID.String(), err)
}
publicKey = requestingLocalAccount.PublicKey
@@ -156,7 +156,7 @@ func (f *federator) AuthenticateFederatedRequest(ctx context.Context, requestedU
if err != nil {
return nil, false, fmt.Errorf("error parsing url %s: %s", requestingLocalAccount.URI, err)
}
- } else if err := f.db.GetWhere([]db.Where{{Key: "public_key_uri", Value: requestingPublicKeyID.String()}}, requestingRemoteAccount); err == nil {
+ } else if err := f.db.GetWhere(ctx, []db.Where{{Key: "public_key_uri", Value: requestingPublicKeyID.String()}}, requestingRemoteAccount); err == nil {
// REMOTE ACCOUNT REQUEST WITH KEY CACHED LOCALLY
// this is a remote account and we already have the public key for it so use that
l.Tracef("proceeding without dereference for cached public key %s", requestingPublicKeyID)
@@ -170,7 +170,7 @@ func (f *federator) AuthenticateFederatedRequest(ctx context.Context, requestedU
// the request is remote and we don't have the public key yet,
// so we need to authenticate the request properly by dereferencing the remote key
l.Tracef("proceeding with dereference for uncached public key %s", requestingPublicKeyID)
- transport, err := f.transportController.NewTransportForUsername(requestedUsername)
+ transport, err := f.transportController.NewTransportForUsername(ctx, requestedUsername)
if err != nil {
return nil, false, fmt.Errorf("transport err: %s", err)
}
diff --git a/internal/federation/dereference.go b/internal/federation/dereference.go
index 96a662e32..a09f0f84b 100644
--- a/internal/federation/dereference.go
+++ b/internal/federation/dereference.go
@@ -19,36 +19,37 @@
package federation
import (
+ "context"
"net/url"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (f *federator) GetRemoteAccount(username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) {
- return f.dereferencer.GetRemoteAccount(username, remoteAccountID, refresh)
+func (f *federator) GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) {
+ return f.dereferencer.GetRemoteAccount(ctx, username, remoteAccountID, refresh)
}
-func (f *federator) EnrichRemoteAccount(username string, account *gtsmodel.Account) (*gtsmodel.Account, error) {
- return f.dereferencer.EnrichRemoteAccount(username, account)
+func (f *federator) EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) {
+ return f.dereferencer.EnrichRemoteAccount(ctx, username, account)
}
-func (f *federator) GetRemoteStatus(username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) {
- return f.dereferencer.GetRemoteStatus(username, remoteStatusID, refresh)
+func (f *federator) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) {
+ return f.dereferencer.GetRemoteStatus(ctx, username, remoteStatusID, refresh)
}
-func (f *federator) EnrichRemoteStatus(username string, status *gtsmodel.Status) (*gtsmodel.Status, error) {
- return f.dereferencer.EnrichRemoteStatus(username, status)
+func (f *federator) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) {
+ return f.dereferencer.EnrichRemoteStatus(ctx, username, status)
}
-func (f *federator) DereferenceRemoteThread(username string, statusIRI *url.URL) error {
- return f.dereferencer.DereferenceThread(username, statusIRI)
+func (f *federator) DereferenceRemoteThread(ctx context.Context, username string, statusIRI *url.URL) error {
+ return f.dereferencer.DereferenceThread(ctx, username, statusIRI)
}
-func (f *federator) GetRemoteInstance(username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) {
- return f.dereferencer.GetRemoteInstance(username, remoteInstanceURI)
+func (f *federator) GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) {
+ return f.dereferencer.GetRemoteInstance(ctx, username, remoteInstanceURI)
}
-func (f *federator) DereferenceAnnounce(announce *gtsmodel.Status, requestingUsername string) error {
- return f.dereferencer.DereferenceAnnounce(announce, requestingUsername)
+func (f *federator) DereferenceAnnounce(ctx context.Context, announce *gtsmodel.Status, requestingUsername string) error {
+ return f.dereferencer.DereferenceAnnounce(ctx, announce, requestingUsername)
}
diff --git a/internal/federation/dereferencing/account.go b/internal/federation/dereferencing/account.go
index ba6766061..2eee0645d 100644
--- a/internal/federation/dereferencing/account.go
+++ b/internal/federation/dereferencing/account.go
@@ -24,6 +24,7 @@ import (
"errors"
"fmt"
"net/url"
+ "strings"
"github.com/go-fed/activity/streams"
"github.com/go-fed/activity/streams/vocab"
@@ -34,18 +35,33 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/transport"
)
+func instanceAccount(account *gtsmodel.Account) bool {
+ return strings.EqualFold(account.Username, account.Domain) ||
+ account.FollowersURI == "" ||
+ account.FollowingURI == "" ||
+ (account.Username == "internal.fetch" && strings.Contains(account.Note, "internal service actor"))
+}
+
// EnrichRemoteAccount takes an account that's already been inserted into the database in a minimal form,
// and populates it with additional fields, media, etc.
//
// EnrichRemoteAccount is mostly useful for calling after an account has been initially created by
// the federatingDB's Create function, or during the federated authorization flow.
-func (d *deref) EnrichRemoteAccount(username string, account *gtsmodel.Account) (*gtsmodel.Account, error) {
- if err := d.PopulateAccountFields(account, username, false); err != nil {
+func (d *deref) EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) {
+
+ // if we're dealing with an instance account, we don't need to update anything
+ if instanceAccount(account) {
+ return account, nil
+ }
+
+ if err := d.PopulateAccountFields(ctx, account, username, false); err != nil {
return nil, err
}
- if err := d.db.UpdateByID(account.ID, account); err != nil {
- return nil, fmt.Errorf("EnrichRemoteAccount: error updating account: %s", err)
+ var err error
+ account, err = d.db.UpdateAccount(ctx, account)
+ if err != nil {
+ d.log.Errorf("EnrichRemoteAccount: error updating account: %s", err)
}
return account, nil
@@ -60,27 +76,27 @@ func (d *deref) EnrichRemoteAccount(username string, account *gtsmodel.Account)
// the remote instance again.
//
// SIDE EFFECTS: remote account will be stored in the database, or updated if it already exists (and refresh is true).
-func (d *deref) GetRemoteAccount(username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) {
+func (d *deref) GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error) {
new := true
// check if we already have the account in our db
- maybeAccount, err := d.db.GetAccountByURI(remoteAccountID.String())
+ maybeAccount, err := d.db.GetAccountByURI(ctx, remoteAccountID.String())
if err == nil {
// we've seen this account before so it's not new
new = false
if !refresh {
// we're not being asked to refresh, but just in case we don't have the avatar/header cached yet....
- maybeAccount, err = d.EnrichRemoteAccount(username, maybeAccount)
+ maybeAccount, err = d.EnrichRemoteAccount(ctx, username, maybeAccount)
return maybeAccount, new, err
}
}
- accountable, err := d.dereferenceAccountable(username, remoteAccountID)
+ accountable, err := d.dereferenceAccountable(ctx, username, remoteAccountID)
if err != nil {
return nil, new, fmt.Errorf("FullyDereferenceAccount: error dereferencing accountable: %s", err)
}
- gtsAccount, err := d.typeConverter.ASRepresentationToAccount(accountable, refresh)
+ gtsAccount, err := d.typeConverter.ASRepresentationToAccount(ctx, accountable, refresh)
if err != nil {
return nil, new, fmt.Errorf("FullyDereferenceAccount: error converting accountable to account: %s", err)
}
@@ -93,23 +109,24 @@ func (d *deref) GetRemoteAccount(username string, remoteAccountID *url.URL, refr
}
gtsAccount.ID = ulid
- if err := d.PopulateAccountFields(gtsAccount, username, refresh); err != nil {
+ if err := d.PopulateAccountFields(ctx, gtsAccount, username, refresh); err != nil {
return nil, new, fmt.Errorf("FullyDereferenceAccount: error populating further account fields: %s", err)
}
- if err := d.db.Put(gtsAccount); err != nil {
+ if err := d.db.Put(ctx, gtsAccount); err != nil {
return nil, new, fmt.Errorf("FullyDereferenceAccount: error putting new account: %s", err)
}
} else {
// take the id we already have and do an update
gtsAccount.ID = maybeAccount.ID
- if err := d.PopulateAccountFields(gtsAccount, username, refresh); err != nil {
+ if err := d.PopulateAccountFields(ctx, gtsAccount, username, refresh); err != nil {
return nil, new, fmt.Errorf("FullyDereferenceAccount: error populating further account fields: %s", err)
}
- if err := d.db.UpdateByID(gtsAccount.ID, gtsAccount); err != nil {
- return nil, new, fmt.Errorf("FullyDereferenceAccount: error updating existing account: %s", err)
+ gtsAccount, err = d.db.UpdateAccount(ctx, gtsAccount)
+ if err != nil {
+ return nil, false, fmt.Errorf("EnrichRemoteAccount: error updating account: %s", err)
}
}
@@ -120,15 +137,15 @@ func (d *deref) GetRemoteAccount(username string, remoteAccountID *url.URL, refr
// it finds as something that an account model can be constructed out of.
//
// Will work for Person, Application, or Service models.
-func (d *deref) dereferenceAccountable(username string, remoteAccountID *url.URL) (ap.Accountable, error) {
+func (d *deref) dereferenceAccountable(ctx context.Context, username string, remoteAccountID *url.URL) (ap.Accountable, error) {
d.startHandshake(username, remoteAccountID)
defer d.stopHandshake(username, remoteAccountID)
- if blocked, err := d.blockedDomain(remoteAccountID.Host); blocked || err != nil {
+ if blocked, err := d.db.IsDomainBlocked(ctx, remoteAccountID.Host); blocked || err != nil {
return nil, fmt.Errorf("DereferenceAccountable: domain %s is blocked", remoteAccountID.Host)
}
- transport, err := d.transportController.NewTransportForUsername(username)
+ transport, err := d.transportController.NewTransportForUsername(ctx, username)
if err != nil {
return nil, fmt.Errorf("DereferenceAccountable: transport err: %s", err)
}
@@ -174,7 +191,7 @@ func (d *deref) dereferenceAccountable(username string, remoteAccountID *url.URL
// PopulateAccountFields populates any fields on the given account that weren't populated by the initial
// dereferencing. This includes things like header and avatar etc.
-func (d *deref) PopulateAccountFields(account *gtsmodel.Account, requestingUsername string, refresh bool) error {
+func (d *deref) PopulateAccountFields(ctx context.Context, account *gtsmodel.Account, requestingUsername string, refresh bool) error {
l := d.log.WithFields(logrus.Fields{
"func": "PopulateAccountFields",
"requestingUsername": requestingUsername,
@@ -184,17 +201,17 @@ func (d *deref) PopulateAccountFields(account *gtsmodel.Account, requestingUsern
if err != nil {
return fmt.Errorf("PopulateAccountFields: couldn't parse account URI %s: %s", account.URI, err)
}
- if blocked, err := d.blockedDomain(accountURI.Host); blocked || err != nil {
+ if blocked, err := d.db.IsDomainBlocked(ctx, accountURI.Host); blocked || err != nil {
return fmt.Errorf("PopulateAccountFields: domain %s is blocked", accountURI.Host)
}
- t, err := d.transportController.NewTransportForUsername(requestingUsername)
+ t, err := d.transportController.NewTransportForUsername(ctx, requestingUsername)
if err != nil {
return fmt.Errorf("PopulateAccountFields: error getting transport for user: %s", err)
}
// fetch the header and avatar
- if err := d.fetchHeaderAndAviForAccount(account, t, refresh); err != nil {
+ if err := d.fetchHeaderAndAviForAccount(ctx, account, t, refresh); err != nil {
// if this doesn't work, just skip it -- we can do it later
l.Debugf("error fetching header/avi for account: %s", err)
}
@@ -208,17 +225,17 @@ func (d *deref) PopulateAccountFields(account *gtsmodel.Account, requestingUsern
// targetAccount's AvatarMediaAttachmentID and HeaderMediaAttachmentID will be updated as necessary.
//
// SIDE EFFECTS: remote header and avatar will be stored in local storage.
-func (d *deref) fetchHeaderAndAviForAccount(targetAccount *gtsmodel.Account, t transport.Transport, refresh bool) error {
+func (d *deref) fetchHeaderAndAviForAccount(ctx context.Context, targetAccount *gtsmodel.Account, t transport.Transport, refresh bool) error {
accountURI, err := url.Parse(targetAccount.URI)
if err != nil {
return fmt.Errorf("fetchHeaderAndAviForAccount: couldn't parse account URI %s: %s", targetAccount.URI, err)
}
- if blocked, err := d.blockedDomain(accountURI.Host); blocked || err != nil {
+ if blocked, err := d.db.IsDomainBlocked(ctx, accountURI.Host); blocked || err != nil {
return fmt.Errorf("fetchHeaderAndAviForAccount: domain %s is blocked", accountURI.Host)
}
if targetAccount.AvatarRemoteURL != "" && (targetAccount.AvatarMediaAttachmentID == "" || refresh) {
- a, err := d.mediaHandler.ProcessRemoteHeaderOrAvatar(t, &gtsmodel.MediaAttachment{
+ a, err := d.mediaHandler.ProcessRemoteHeaderOrAvatar(ctx, t, &gtsmodel.MediaAttachment{
RemoteURL: targetAccount.AvatarRemoteURL,
Avatar: true,
}, targetAccount.ID)
@@ -229,7 +246,7 @@ func (d *deref) fetchHeaderAndAviForAccount(targetAccount *gtsmodel.Account, t t
}
if targetAccount.HeaderRemoteURL != "" && (targetAccount.HeaderMediaAttachmentID == "" || refresh) {
- a, err := d.mediaHandler.ProcessRemoteHeaderOrAvatar(t, &gtsmodel.MediaAttachment{
+ a, err := d.mediaHandler.ProcessRemoteHeaderOrAvatar(ctx, t, &gtsmodel.MediaAttachment{
RemoteURL: targetAccount.HeaderRemoteURL,
Header: true,
}, targetAccount.ID)
diff --git a/internal/federation/dereferencing/announce.go b/internal/federation/dereferencing/announce.go
index 6773db425..33af74ebe 100644
--- a/internal/federation/dereferencing/announce.go
+++ b/internal/federation/dereferencing/announce.go
@@ -19,6 +19,7 @@
package dereferencing
import (
+ "context"
"errors"
"fmt"
"net/url"
@@ -26,7 +27,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (d *deref) DereferenceAnnounce(announce *gtsmodel.Status, requestingUsername string) error {
+func (d *deref) DereferenceAnnounce(ctx context.Context, announce *gtsmodel.Status, requestingUsername string) error {
if announce.BoostOf == nil || announce.BoostOf.URI == "" {
// we can't do anything unfortunately
return errors.New("DereferenceAnnounce: no URI to dereference")
@@ -36,16 +37,16 @@ func (d *deref) DereferenceAnnounce(announce *gtsmodel.Status, requestingUsernam
if err != nil {
return fmt.Errorf("DereferenceAnnounce: couldn't parse boosted status URI %s: %s", announce.BoostOf.URI, err)
}
- if blocked, err := d.blockedDomain(boostedStatusURI.Host); blocked || err != nil {
+ if blocked, err := d.db.IsDomainBlocked(ctx, boostedStatusURI.Host); blocked || err != nil {
return fmt.Errorf("DereferenceAnnounce: domain %s is blocked", boostedStatusURI.Host)
}
// dereference statuses in the thread of the boosted status
- if err := d.DereferenceThread(requestingUsername, boostedStatusURI); err != nil {
+ if err := d.DereferenceThread(ctx, requestingUsername, boostedStatusURI); err != nil {
return fmt.Errorf("DereferenceAnnounce: error dereferencing thread of boosted status: %s", err)
}
- boostedStatus, _, _, err := d.GetRemoteStatus(requestingUsername, boostedStatusURI, false)
+ boostedStatus, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, boostedStatusURI, false)
if err != nil {
return fmt.Errorf("DereferenceAnnounce: error dereferencing remote status with id %s: %s", announce.BoostOf.URI, err)
}
diff --git a/internal/federation/dereferencing/collectionpage.go b/internal/federation/dereferencing/collectionpage.go
index 5feadc1ad..6f0beeaf6 100644
--- a/internal/federation/dereferencing/collectionpage.go
+++ b/internal/federation/dereferencing/collectionpage.go
@@ -32,12 +32,12 @@ import (
)
// DereferenceCollectionPage returns the activitystreams CollectionPage at the specified IRI, or an error if something goes wrong.
-func (d *deref) DereferenceCollectionPage(username string, pageIRI *url.URL) (ap.CollectionPageable, error) {
- if blocked, err := d.blockedDomain(pageIRI.Host); blocked || err != nil {
+func (d *deref) DereferenceCollectionPage(ctx context.Context, username string, pageIRI *url.URL) (ap.CollectionPageable, error) {
+ if blocked, err := d.db.IsDomainBlocked(ctx, pageIRI.Host); blocked || err != nil {
return nil, fmt.Errorf("DereferenceCollectionPage: domain %s is blocked", pageIRI.Host)
}
- transport, err := d.transportController.NewTransportForUsername(username)
+ transport, err := d.transportController.NewTransportForUsername(ctx, username)
if err != nil {
return nil, fmt.Errorf("DereferenceCollectionPage: error creating transport: %s", err)
}
diff --git a/internal/federation/dereferencing/dereferencer.go b/internal/federation/dereferencing/dereferencer.go
index 03b90569a..71625ed88 100644
--- a/internal/federation/dereferencing/dereferencer.go
+++ b/internal/federation/dereferencing/dereferencer.go
@@ -19,6 +19,7 @@
package dereferencing
import (
+ "context"
"net/url"
"sync"
@@ -34,18 +35,18 @@ import (
// Dereferencer wraps logic and functionality for doing dereferencing of remote accounts, statuses, etc, from federated instances.
type Dereferencer interface {
- GetRemoteAccount(username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error)
- EnrichRemoteAccount(username string, account *gtsmodel.Account) (*gtsmodel.Account, error)
+ GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error)
+ EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error)
- GetRemoteStatus(username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error)
- EnrichRemoteStatus(username string, status *gtsmodel.Status) (*gtsmodel.Status, error)
+ GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error)
+ EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error)
- GetRemoteInstance(username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error)
+ GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error)
- DereferenceAnnounce(announce *gtsmodel.Status, requestingUsername string) error
- DereferenceThread(username string, statusIRI *url.URL) error
+ DereferenceAnnounce(ctx context.Context, announce *gtsmodel.Status, requestingUsername string) error
+ DereferenceThread(ctx context.Context, username string, statusIRI *url.URL) error
- Handshaking(username string, remoteAccountID *url.URL) bool
+ Handshaking(ctx context.Context, username string, remoteAccountID *url.URL) bool
}
type deref struct {
diff --git a/internal/federation/dereferencing/handshake.go b/internal/federation/dereferencing/handshake.go
index cda8eafd0..17003be84 100644
--- a/internal/federation/dereferencing/handshake.go
+++ b/internal/federation/dereferencing/handshake.go
@@ -18,9 +18,12 @@
package dereferencing
-import "net/url"
+import (
+ "context"
+ "net/url"
+)
-func (d *deref) Handshaking(username string, remoteAccountID *url.URL) bool {
+func (d *deref) Handshaking(ctx context.Context, username string, remoteAccountID *url.URL) bool {
d.handshakeSync.Lock()
defer d.handshakeSync.Unlock()
diff --git a/internal/federation/dereferencing/instance.go b/internal/federation/dereferencing/instance.go
index 80f626662..ec3c3f13d 100644
--- a/internal/federation/dereferencing/instance.go
+++ b/internal/federation/dereferencing/instance.go
@@ -26,12 +26,12 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (d *deref) GetRemoteInstance(username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) {
- if blocked, err := d.blockedDomain(remoteInstanceURI.Host); blocked || err != nil {
+func (d *deref) GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) {
+ if blocked, err := d.db.IsDomainBlocked(ctx, remoteInstanceURI.Host); blocked || err != nil {
return nil, fmt.Errorf("GetRemoteInstance: domain %s is blocked", remoteInstanceURI.Host)
}
- transport, err := d.transportController.NewTransportForUsername(username)
+ transport, err := d.transportController.NewTransportForUsername(ctx, username)
if err != nil {
return nil, fmt.Errorf("transport err: %s", err)
}
diff --git a/internal/federation/dereferencing/status.go b/internal/federation/dereferencing/status.go
index 68693c021..93ead6523 100644
--- a/internal/federation/dereferencing/status.go
+++ b/internal/federation/dereferencing/status.go
@@ -39,12 +39,12 @@ import (
//
// EnrichRemoteStatus is mostly useful for calling after a status has been initially created by
// the federatingDB's Create function, but additional dereferencing is needed on it.
-func (d *deref) EnrichRemoteStatus(username string, status *gtsmodel.Status) (*gtsmodel.Status, error) {
- if err := d.populateStatusFields(status, username); err != nil {
+func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) {
+ if err := d.populateStatusFields(ctx, status, username); err != nil {
return nil, err
}
- if err := d.db.UpdateByID(status.ID, status); err != nil {
+ if err := d.db.UpdateByID(ctx, status.ID, status); err != nil {
return nil, fmt.Errorf("EnrichRemoteStatus: error updating status: %s", err)
}
@@ -62,11 +62,11 @@ func (d *deref) EnrichRemoteStatus(username string, status *gtsmodel.Status) (*g
// If a dereference was performed, then the function also returns the ap.Statusable representation for further processing.
//
// SIDE EFFECTS: remote status will be stored in the database, and the remote status owner will also be stored.
-func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) {
+func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) {
new := true
// check if we already have the status in our db
- maybeStatus, err := d.db.GetStatusByURI(remoteStatusID.String())
+ maybeStatus, err := d.db.GetStatusByURI(ctx, remoteStatusID.String())
if err == nil {
// we've seen this status before so it's not new
new = false
@@ -77,7 +77,7 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres
}
}
- statusable, err := d.dereferenceStatusable(username, remoteStatusID)
+ statusable, err := d.dereferenceStatusable(ctx, username, remoteStatusID)
if err != nil {
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error dereferencing statusable: %s", err)
}
@@ -88,12 +88,12 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres
}
// do this so we know we have the remote account of the status in the db
- _, _, err = d.GetRemoteAccount(username, accountURI, false)
+ _, _, err = d.GetRemoteAccount(ctx, username, accountURI, false)
if err != nil {
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: couldn't derive status author: %s", err)
}
- gtsStatus, err := d.typeConverter.ASStatusToStatus(statusable)
+ gtsStatus, err := d.typeConverter.ASStatusToStatus(ctx, statusable)
if err != nil {
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error converting statusable to status: %s", err)
}
@@ -105,21 +105,21 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres
}
gtsStatus.ID = ulid
- if err := d.populateStatusFields(gtsStatus, username); err != nil {
+ if err := d.populateStatusFields(ctx, gtsStatus, username); err != nil {
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err)
}
- if err := d.db.PutStatus(gtsStatus); err != nil {
+ if err := d.db.PutStatus(ctx, gtsStatus); err != nil {
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error putting new status: %s", err)
}
} else {
gtsStatus.ID = maybeStatus.ID
- if err := d.populateStatusFields(gtsStatus, username); err != nil {
+ if err := d.populateStatusFields(ctx, gtsStatus, username); err != nil {
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err)
}
- if err := d.db.UpdateByID(gtsStatus.ID, gtsStatus); err != nil {
+ if err := d.db.UpdateByID(ctx, gtsStatus.ID, gtsStatus); err != nil {
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error updating status: %s", err)
}
}
@@ -127,12 +127,12 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres
return gtsStatus, statusable, new, nil
}
-func (d *deref) dereferenceStatusable(username string, remoteStatusID *url.URL) (ap.Statusable, error) {
- if blocked, err := d.blockedDomain(remoteStatusID.Host); blocked || err != nil {
+func (d *deref) dereferenceStatusable(ctx context.Context, username string, remoteStatusID *url.URL) (ap.Statusable, error) {
+ if blocked, err := d.db.IsDomainBlocked(ctx, remoteStatusID.Host); blocked || err != nil {
return nil, fmt.Errorf("DereferenceStatusable: domain %s is blocked", remoteStatusID.Host)
}
- transport, err := d.transportController.NewTransportForUsername(username)
+ transport, err := d.transportController.NewTransportForUsername(ctx, username)
if err != nil {
return nil, fmt.Errorf("DereferenceStatusable: transport err: %s", err)
}
@@ -236,7 +236,7 @@ func (d *deref) dereferenceStatusable(username string, remoteStatusID *url.URL)
// This function will deference all of the above, insert them in the database as necessary,
// and attach them to the status. The status itself will not be added to the database yet,
// that's up the caller to do.
-func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername string) error {
+func (d *deref) populateStatusFields(ctx context.Context, status *gtsmodel.Status, requestingUsername string) error {
l := d.log.WithFields(logrus.Fields{
"func": "dereferenceStatusFields",
"status": fmt.Sprintf("%+v", status),
@@ -248,12 +248,12 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
if err != nil {
return fmt.Errorf("DereferenceStatusFields: couldn't parse status URI %s: %s", status.URI, err)
}
- if blocked, err := d.blockedDomain(statusURI.Host); blocked || err != nil {
+ if blocked, err := d.db.IsDomainBlocked(ctx, statusURI.Host); blocked || err != nil {
return fmt.Errorf("DereferenceStatusFields: domain %s is blocked", statusURI.Host)
}
// we can continue -- create a new transport here because we'll probably need it
- t, err := d.transportController.NewTransportForUsername(requestingUsername)
+ t, err := d.transportController.NewTransportForUsername(ctx, requestingUsername)
if err != nil {
return fmt.Errorf("error creating transport: %s", err)
}
@@ -281,7 +281,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
// it might have been processed elsewhere so check first if it's already in the database or not
maybeAttachment := &gtsmodel.MediaAttachment{}
- err := d.db.GetWhere([]db.Where{{Key: "remote_url", Value: a.RemoteURL}}, maybeAttachment)
+ err := d.db.GetWhere(ctx, []db.Where{{Key: "remote_url", Value: a.RemoteURL}}, maybeAttachment)
if err == nil {
// we already have it in the db, dereferenced, no need to do it again
l.Tracef("attachment already exists with id %s", maybeAttachment.ID)
@@ -294,7 +294,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
}
// it just doesn't exist yet so carry on
l.Debug("attachment doesn't exist yet, calling ProcessRemoteAttachment", a)
- deferencedAttachment, err := d.mediaHandler.ProcessRemoteAttachment(t, a, status.AccountID)
+ deferencedAttachment, err := d.mediaHandler.ProcessRemoteAttachment(ctx, t, a, status.AccountID)
if err != nil {
l.Errorf("error dereferencing status attachment: %s", err)
continue
@@ -302,7 +302,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
l.Debugf("dereferenced attachment: %+v", deferencedAttachment)
deferencedAttachment.StatusID = status.ID
deferencedAttachment.Description = a.Description
- if err := d.db.Put(deferencedAttachment); err != nil {
+ if err := d.db.Put(ctx, deferencedAttachment); err != nil {
return fmt.Errorf("error inserting dereferenced attachment with remote url %s: %s", a.RemoteURL, err)
}
attachmentIDs = append(attachmentIDs, deferencedAttachment.ID)
@@ -338,9 +338,9 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
}
var targetAccount *gtsmodel.Account
- if a, err := d.db.GetAccountByURL(targetAccountURI.String()); err == nil {
+ if a, err := d.db.GetAccountByURL(ctx, targetAccountURI.String()); err == nil {
targetAccount = a
- } else if a, _, err := d.GetRemoteAccount(requestingUsername, targetAccountURI, false); err == nil {
+ } else if a, _, err := d.GetRemoteAccount(ctx, requestingUsername, targetAccountURI, false); err == nil {
targetAccount = a
} else {
// we can't find the target account so bail
@@ -369,7 +369,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
TargetAccountURL: targetAccount.URL,
}
- if err := d.db.Put(m); err != nil {
+ if err := d.db.Put(ctx, m); err != nil {
return fmt.Errorf("error creating mention: %s", err)
}
mentionIDs = append(mentionIDs, m.ID)
@@ -382,13 +382,13 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
if err != nil {
return err
}
- if replyToStatus, err := d.db.GetStatusByURI(status.InReplyToURI); err == nil {
+ if replyToStatus, err := d.db.GetStatusByURI(ctx, status.InReplyToURI); err == nil {
// we have the status
status.InReplyToID = replyToStatus.ID
status.InReplyTo = replyToStatus
status.InReplyToAccountID = replyToStatus.AccountID
status.InReplyToAccount = replyToStatus.Account
- } else if replyToStatus, _, _, err := d.GetRemoteStatus(requestingUsername, statusURI, false); err == nil {
+ } else if replyToStatus, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, statusURI, false); err == nil {
// we got the status
status.InReplyToID = replyToStatus.ID
status.InReplyTo = replyToStatus
diff --git a/internal/federation/dereferencing/thread.go b/internal/federation/dereferencing/thread.go
index 2a407f923..f9dd9aa09 100644
--- a/internal/federation/dereferencing/thread.go
+++ b/internal/federation/dereferencing/thread.go
@@ -19,12 +19,12 @@
package dereferencing
import (
+ "context"
"fmt"
"net/url"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/ap"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
@@ -34,7 +34,7 @@ import (
// This process involves working up and down the chain of replies, and parsing through the collections of IDs
// presented by remote instances as part of their replies collections, and will likely involve making several calls to
// multiple different hosts.
-func (d *deref) DereferenceThread(username string, statusIRI *url.URL) error {
+func (d *deref) DereferenceThread(ctx context.Context, username string, statusIRI *url.URL) error {
l := d.log.WithFields(logrus.Fields{
"func": "DereferenceThread",
"username": username,
@@ -49,18 +49,18 @@ func (d *deref) DereferenceThread(username string, statusIRI *url.URL) error {
}
// first make sure we have this status in our db
- _, statusable, _, err := d.GetRemoteStatus(username, statusIRI, true)
+ _, statusable, _, err := d.GetRemoteStatus(ctx, username, statusIRI, true)
if err != nil {
return fmt.Errorf("DereferenceThread: error getting status with id %s: %s", statusIRI.String(), err)
}
// first iterate up through ancestors, dereferencing if necessary as we go
- if err := d.iterateAncestors(username, *statusIRI); err != nil {
+ if err := d.iterateAncestors(ctx, username, *statusIRI); err != nil {
return fmt.Errorf("error iterating ancestors of status %s: %s", statusIRI.String(), err)
}
// now iterate down through descendants, again dereferencing as we go
- if err := d.iterateDescendants(username, *statusIRI, statusable); err != nil {
+ if err := d.iterateDescendants(ctx, username, *statusIRI, statusable); err != nil {
return fmt.Errorf("error iterating descendants of status %s: %s", statusIRI.String(), err)
}
@@ -68,7 +68,7 @@ func (d *deref) DereferenceThread(username string, statusIRI *url.URL) error {
}
// iterateAncestors has the goal of reaching the oldest ancestor of a given status, and stashing all statuses along the way.
-func (d *deref) iterateAncestors(username string, statusIRI url.URL) error {
+func (d *deref) iterateAncestors(ctx context.Context, username string, statusIRI url.URL) error {
l := d.log.WithFields(logrus.Fields{
"func": "iterateAncestors",
"username": username,
@@ -86,8 +86,8 @@ func (d *deref) iterateAncestors(username string, statusIRI url.URL) error {
return err
}
- status := &gtsmodel.Status{}
- if err := d.db.GetByID(id, status); err != nil {
+ status, err := d.db.GetStatusByID(ctx, id)
+ if err != nil {
return err
}
@@ -99,12 +99,12 @@ func (d *deref) iterateAncestors(username string, statusIRI url.URL) error {
if err != nil {
return err
}
- return d.iterateAncestors(username, *nextIRI)
+ return d.iterateAncestors(ctx, username, *nextIRI)
}
// If we reach here, we're looking at a remote status -- make sure we have it in our db by calling GetRemoteStatus
// We call it with refresh to true because we want the statusable representation to parse inReplyTo from.
- status, statusable, _, err := d.GetRemoteStatus(username, &statusIRI, true)
+ status, statusable, _, err := d.GetRemoteStatus(ctx, username, &statusIRI, true)
if err != nil {
l.Debugf("error getting remote status: %s", err)
return nil
@@ -117,22 +117,22 @@ func (d *deref) iterateAncestors(username string, statusIRI url.URL) error {
}
// get the ancestor status into our database if we don't have it yet
- if _, _, _, err := d.GetRemoteStatus(username, inReplyTo, false); err != nil {
+ if _, _, _, err := d.GetRemoteStatus(ctx, username, inReplyTo, false); err != nil {
l.Debugf("error getting remote status: %s", err)
return nil
}
// now enrich the current status, since we should have the ancestor in the db
- if _, err := d.EnrichRemoteStatus(username, status); err != nil {
+ if _, err := d.EnrichRemoteStatus(ctx, username, status); err != nil {
l.Debugf("error enriching remote status: %s", err)
return nil
}
// now move up to the next ancestor
- return d.iterateAncestors(username, *inReplyTo)
+ return d.iterateAncestors(ctx, username, *inReplyTo)
}
-func (d *deref) iterateDescendants(username string, statusIRI url.URL, statusable ap.Statusable) error {
+func (d *deref) iterateDescendants(ctx context.Context, username string, statusIRI url.URL, statusable ap.Statusable) error {
l := d.log.WithFields(logrus.Fields{
"func": "iterateDescendants",
"username": username,
@@ -182,7 +182,7 @@ func (d *deref) iterateDescendants(username string, statusIRI url.URL, statusabl
pageLoop:
for {
l.Debugf("dereferencing page %s", currentPageIRI)
- nextPage, err := d.DereferenceCollectionPage(username, currentPageIRI)
+ nextPage, err := d.DereferenceCollectionPage(ctx, username, currentPageIRI)
if err != nil {
return nil
}
@@ -226,10 +226,10 @@ pageLoop:
foundReplies = foundReplies + 1
// get the remote statusable and put it in the db
- _, statusable, new, err := d.GetRemoteStatus(username, itemURI, false)
+ _, statusable, new, err := d.GetRemoteStatus(ctx, username, itemURI, false)
if new && err == nil && statusable != nil {
// now iterate descendants of *that* status
- if err := d.iterateDescendants(username, *itemURI, statusable); err != nil {
+ if err := d.iterateDescendants(ctx, username, *itemURI, statusable); err != nil {
continue
}
}
diff --git a/internal/federation/federatingdb/accept.go b/internal/federation/federatingdb/accept.go
index 91d9df86f..0b14e8a6a 100644
--- a/internal/federation/federatingdb/accept.go
+++ b/internal/federation/federatingdb/accept.go
@@ -86,7 +86,7 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
if util.IsFollowPath(acceptedObjectIRI) {
// ACCEPT FOLLOW
gtsFollowRequest := &gtsmodel.FollowRequest{}
- if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: acceptedObjectIRI.String()}}, gtsFollowRequest); err != nil {
+ if err := f.db.GetWhere(ctx, []db.Where{{Key: "uri", Value: acceptedObjectIRI.String()}}, gtsFollowRequest); err != nil {
return fmt.Errorf("ACCEPT: couldn't get follow request with id %s from the database: %s", acceptedObjectIRI.String(), err)
}
@@ -94,7 +94,7 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
if gtsFollowRequest.AccountID != targetAcct.ID {
return errors.New("ACCEPT: follow object account and inbox account were not the same")
}
- follow, err := f.db.AcceptFollowRequest(gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID)
+ follow, err := f.db.AcceptFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID)
if err != nil {
return err
}
@@ -123,7 +123,7 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
return errors.New("ACCEPT: couldn't parse follow into vocab.ActivityStreamsFollow")
}
// convert the follow to something we can understand
- gtsFollow, err := f.typeConverter.ASFollowToFollow(asFollow)
+ gtsFollow, err := f.typeConverter.ASFollowToFollow(ctx, asFollow)
if err != nil {
return fmt.Errorf("ACCEPT: error converting asfollow to gtsfollow: %s", err)
}
@@ -131,7 +131,7 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
if gtsFollow.AccountID != targetAcct.ID {
return errors.New("ACCEPT: follow object account and inbox account were not the same")
}
- follow, err := f.db.AcceptFollowRequest(gtsFollow.AccountID, gtsFollow.TargetAccountID)
+ follow, err := f.db.AcceptFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID)
if err != nil {
return err
}
diff --git a/internal/federation/federatingdb/announce.go b/internal/federation/federatingdb/announce.go
index 981eaf1ef..5cd34285e 100644
--- a/internal/federation/federatingdb/announce.go
+++ b/internal/federation/federatingdb/announce.go
@@ -71,7 +71,7 @@ func (f *federatingDB) Announce(ctx context.Context, announce vocab.ActivityStre
return nil
}
- boost, isNew, err := f.typeConverter.ASAnnounceToStatus(announce)
+ boost, isNew, err := f.typeConverter.ASAnnounceToStatus(ctx, announce)
if err != nil {
return fmt.Errorf("ANNOUNCE: error converting announce to boost: %s", err)
}
diff --git a/internal/federation/federatingdb/create.go b/internal/federation/federatingdb/create.go
index fb4353cd4..8ea549c5a 100644
--- a/internal/federation/federatingdb/create.go
+++ b/internal/federation/federatingdb/create.go
@@ -100,7 +100,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error {
case gtsmodel.ActivityStreamsNote:
// CREATE A NOTE
note := objectIter.GetActivityStreamsNote()
- status, err := f.typeConverter.ASStatusToStatus(note)
+ status, err := f.typeConverter.ASStatusToStatus(ctx, note)
if err != nil {
return fmt.Errorf("CREATE: error converting note to status: %s", err)
}
@@ -112,7 +112,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error {
}
status.ID = statusID
- if err := f.db.PutStatus(status); err != nil {
+ if err := f.db.PutStatus(ctx, status); err != nil {
if err == db.ErrAlreadyExists {
// the status already exists in the database, which means we've already handled everything else,
// so we can just return nil here and be done with it.
@@ -137,7 +137,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error {
return errors.New("CREATE: could not convert type to follow")
}
- followRequest, err := f.typeConverter.ASFollowToFollowRequest(follow)
+ followRequest, err := f.typeConverter.ASFollowToFollowRequest(ctx, follow)
if err != nil {
return fmt.Errorf("CREATE: could not convert Follow to follow request: %s", err)
}
@@ -148,7 +148,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error {
}
followRequest.ID = newID
- if err := f.db.Put(followRequest); err != nil {
+ if err := f.db.Put(ctx, followRequest); err != nil {
return fmt.Errorf("CREATE: database error inserting follow request: %s", err)
}
@@ -165,7 +165,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error {
return errors.New("CREATE: could not convert type to like")
}
- fave, err := f.typeConverter.ASLikeToFave(like)
+ fave, err := f.typeConverter.ASLikeToFave(ctx, like)
if err != nil {
return fmt.Errorf("CREATE: could not convert Like to fave: %s", err)
}
@@ -176,7 +176,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error {
}
fave.ID = newID
- if err := f.db.Put(fave); err != nil {
+ if err := f.db.Put(ctx, fave); err != nil {
return fmt.Errorf("CREATE: database error inserting fave: %s", err)
}
@@ -193,7 +193,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error {
return errors.New("CREATE: could not convert type to block")
}
- block, err := f.typeConverter.ASBlockToBlock(blockable)
+ block, err := f.typeConverter.ASBlockToBlock(ctx, blockable)
if err != nil {
return fmt.Errorf("CREATE: could not convert Block to gts model block")
}
@@ -204,7 +204,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error {
}
block.ID = newID
- if err := f.db.Put(block); err != nil {
+ if err := f.db.Put(ctx, block); err != nil {
return fmt.Errorf("CREATE: database error inserting block: %s", err)
}
diff --git a/internal/federation/federatingdb/delete.go b/internal/federation/federatingdb/delete.go
index ee9310789..11b818168 100644
--- a/internal/federation/federatingdb/delete.go
+++ b/internal/federation/federatingdb/delete.go
@@ -69,11 +69,11 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {
// in a delete we only get the URI, we can't know if we have a status or a profile or something else,
// so we have to try a few different things...
- s, err := f.db.GetStatusByURI(id.String())
+ s, err := f.db.GetStatusByURI(ctx, id.String())
if err == nil {
// it's a status
l.Debugf("uri is for status with id: %s", s.ID)
- if err := f.db.DeleteByID(s.ID, &gtsmodel.Status{}); err != nil {
+ if err := f.db.DeleteByID(ctx, s.ID, &gtsmodel.Status{}); err != nil {
return fmt.Errorf("DELETE: err deleting status: %s", err)
}
fromFederatorChan <- gtsmodel.FromFederator{
@@ -84,11 +84,11 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {
}
}
- a, err := f.db.GetAccountByURI(id.String())
+ a, err := f.db.GetAccountByURI(ctx, id.String())
if err == nil {
// it's an account
- l.Debugf("uri is for an account with id: %s", s.ID)
- if err := f.db.DeleteByID(a.ID, &gtsmodel.Account{}); err != nil {
+ l.Debugf("uri is for an account with id: %s", a.ID)
+ if err := f.db.DeleteByID(ctx, a.ID, &gtsmodel.Account{}); err != nil {
return fmt.Errorf("DELETE: err deleting account: %s", err)
}
fromFederatorChan <- gtsmodel.FromFederator{
diff --git a/internal/federation/federatingdb/followers.go b/internal/federation/federatingdb/followers.go
index 241362fc1..c7f636a12 100644
--- a/internal/federation/federatingdb/followers.go
+++ b/internal/federation/federatingdb/followers.go
@@ -19,7 +19,7 @@ import (
// If modified, the library will then call Update.
//
// The library makes this call only after acquiring a lock first.
-func (f *federatingDB) Followers(c context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
+func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
l := f.log.WithFields(
logrus.Fields{
"func": "Followers",
@@ -31,19 +31,19 @@ func (f *federatingDB) Followers(c context.Context, actorIRI *url.URL) (follower
acct := &gtsmodel.Account{}
if util.IsUserPath(actorIRI) {
- acct, err = f.db.GetAccountByURI(actorIRI.String())
+ acct, err = f.db.GetAccountByURI(ctx, actorIRI.String())
if err != nil {
return nil, fmt.Errorf("FOLLOWERS: db error getting account with uri %s: %s", actorIRI.String(), err)
}
} else if util.IsFollowersPath(actorIRI) {
- if err := f.db.GetWhere([]db.Where{{Key: "followers_uri", Value: actorIRI.String()}}, acct); err != nil {
+ if err := f.db.GetWhere(ctx, []db.Where{{Key: "followers_uri", Value: actorIRI.String()}}, acct); err != nil {
return nil, fmt.Errorf("FOLLOWERS: db error getting account with followers uri %s: %s", actorIRI.String(), err)
}
} else {
return nil, fmt.Errorf("FOLLOWERS: could not parse actor IRI %s as users or followers path", actorIRI.String())
}
- acctFollowers, err := f.db.GetAccountFollowedBy(acct.ID, false)
+ acctFollowers, err := f.db.GetAccountFollowedBy(ctx, acct.ID, false)
if err != nil {
return nil, fmt.Errorf("FOLLOWERS: db error getting followers for account id %s: %s", acct.ID, err)
}
@@ -51,13 +51,17 @@ func (f *federatingDB) Followers(c context.Context, actorIRI *url.URL) (follower
followers = streams.NewActivityStreamsCollection()
items := streams.NewActivityStreamsItemsProperty()
for _, follow := range acctFollowers {
- gtsFollower := &gtsmodel.Account{}
- if err := f.db.GetByID(follow.AccountID, gtsFollower); err != nil {
- return nil, fmt.Errorf("FOLLOWERS: db error getting account id %s: %s", follow.AccountID, err)
+ if follow.Account == nil {
+ followAccount, err := f.db.GetAccountByID(ctx, follow.AccountID)
+ if err != nil {
+ return nil, fmt.Errorf("FOLLOWERS: db error getting account id %s: %s", follow.AccountID, err)
+ }
+ follow.Account = followAccount
}
- uri, err := url.Parse(gtsFollower.URI)
+
+ uri, err := url.Parse(follow.Account.URI)
if err != nil {
- return nil, fmt.Errorf("FOLLOWERS: error parsing %s as url: %s", gtsFollower.URI, err)
+ return nil, fmt.Errorf("FOLLOWERS: error parsing %s as url: %s", follow.Account.URI, err)
}
items.AppendIRI(uri)
}
diff --git a/internal/federation/federatingdb/following.go b/internal/federation/federatingdb/following.go
index 45785c671..9d5c0693c 100644
--- a/internal/federation/federatingdb/following.go
+++ b/internal/federation/federatingdb/following.go
@@ -18,7 +18,7 @@ import (
// If modified, the library will then call Update.
//
// The library makes this call only after acquiring a lock first.
-func (f *federatingDB) Following(c context.Context, actorIRI *url.URL) (following vocab.ActivityStreamsCollection, err error) {
+func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (following vocab.ActivityStreamsCollection, err error) {
l := f.log.WithFields(
logrus.Fields{
"func": "Following",
@@ -34,7 +34,7 @@ func (f *federatingDB) Following(c context.Context, actorIRI *url.URL) (followin
return nil, fmt.Errorf("FOLLOWING: error parsing user path: %s", err)
}
- a, err := f.db.GetLocalAccountByUsername(username)
+ a, err := f.db.GetLocalAccountByUsername(ctx, username)
if err != nil {
return nil, fmt.Errorf("FOLLOWING: db error getting account with uri %s: %s", actorIRI.String(), err)
}
@@ -46,7 +46,7 @@ func (f *federatingDB) Following(c context.Context, actorIRI *url.URL) (followin
return nil, fmt.Errorf("FOLLOWING: error parsing following path: %s", err)
}
- a, err := f.db.GetLocalAccountByUsername(username)
+ a, err := f.db.GetLocalAccountByUsername(ctx, username)
if err != nil {
return nil, fmt.Errorf("FOLLOWING: db error getting account with following uri %s: %s", actorIRI.String(), err)
}
@@ -56,7 +56,7 @@ func (f *federatingDB) Following(c context.Context, actorIRI *url.URL) (followin
return nil, fmt.Errorf("FOLLOWING: could not parse actor IRI %s as users or following path", actorIRI.String())
}
- acctFollowing, err := f.db.GetAccountFollows(acct.ID)
+ acctFollowing, err := f.db.GetAccountFollows(ctx, acct.ID)
if err != nil {
return nil, fmt.Errorf("FOLLOWING: db error getting following for account id %s: %s", acct.ID, err)
}
@@ -64,13 +64,17 @@ func (f *federatingDB) Following(c context.Context, actorIRI *url.URL) (followin
following = streams.NewActivityStreamsCollection()
items := streams.NewActivityStreamsItemsProperty()
for _, follow := range acctFollowing {
- gtsFollowing := &gtsmodel.Account{}
- if err := f.db.GetByID(follow.AccountID, gtsFollowing); err != nil {
- return nil, fmt.Errorf("FOLLOWING: db error getting account id %s: %s", follow.AccountID, err)
+ if follow.Account == nil {
+ followAccount, err := f.db.GetAccountByID(ctx, follow.AccountID)
+ if err != nil {
+ return nil, fmt.Errorf("FOLLOWING: db error getting account id %s: %s", follow.AccountID, err)
+ }
+ follow.Account = followAccount
}
- uri, err := url.Parse(gtsFollowing.URI)
+
+ uri, err := url.Parse(follow.Account.URI)
if err != nil {
- return nil, fmt.Errorf("FOLLOWING: error parsing %s as url: %s", gtsFollowing.URI, err)
+ return nil, fmt.Errorf("FOLLOWING: error parsing %s as url: %s", follow.Account.URI, err)
}
items.AppendIRI(uri)
}
diff --git a/internal/federation/federatingdb/get.go b/internal/federation/federatingdb/get.go
index 0265080f9..cc04dd851 100644
--- a/internal/federation/federatingdb/get.go
+++ b/internal/federation/federatingdb/get.go
@@ -33,7 +33,7 @@ import (
// Get returns the database entry for the specified id.
//
// The library makes this call only after acquiring a lock first.
-func (f *federatingDB) Get(c context.Context, id *url.URL) (value vocab.Type, err error) {
+func (f *federatingDB) Get(ctx context.Context, id *url.URL) (value vocab.Type, err error) {
l := f.log.WithFields(
logrus.Fields{
"func": "Get",
@@ -43,17 +43,17 @@ func (f *federatingDB) Get(c context.Context, id *url.URL) (value vocab.Type, er
l.Debug("entering GET function")
if util.IsUserPath(id) {
- acct, err := f.db.GetAccountByURI(id.String())
+ acct, err := f.db.GetAccountByURI(ctx, id.String())
if err != nil {
return nil, err
}
l.Debug("is user path! returning account")
- return f.typeConverter.AccountToAS(acct)
+ return f.typeConverter.AccountToAS(ctx, acct)
}
if util.IsFollowersPath(id) {
acct := &gtsmodel.Account{}
- if err := f.db.GetWhere([]db.Where{{Key: "followers_uri", Value: id.String()}}, acct); err != nil {
+ if err := f.db.GetWhere(ctx, []db.Where{{Key: "followers_uri", Value: id.String()}}, acct); err != nil {
return nil, err
}
@@ -62,12 +62,12 @@ func (f *federatingDB) Get(c context.Context, id *url.URL) (value vocab.Type, er
return nil, err
}
- return f.Followers(c, followersURI)
+ return f.Followers(ctx, followersURI)
}
if util.IsFollowingPath(id) {
acct := &gtsmodel.Account{}
- if err := f.db.GetWhere([]db.Where{{Key: "following_uri", Value: id.String()}}, acct); err != nil {
+ if err := f.db.GetWhere(ctx, []db.Where{{Key: "following_uri", Value: id.String()}}, acct); err != nil {
return nil, err
}
@@ -76,7 +76,7 @@ func (f *federatingDB) Get(c context.Context, id *url.URL) (value vocab.Type, er
return nil, err
}
- return f.Following(c, followingURI)
+ return f.Following(ctx, followingURI)
}
return nil, errors.New("could not get")
diff --git a/internal/federation/federatingdb/outbox.go b/internal/federation/federatingdb/outbox.go
index 849014432..81b90aae2 100644
--- a/internal/federation/federatingdb/outbox.go
+++ b/internal/federation/federatingdb/outbox.go
@@ -35,7 +35,7 @@ import (
// at the specified IRI, for prepending new items.
//
// The library makes this call only after acquiring a lock first.
-func (f *federatingDB) GetOutbox(c context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
+func (f *federatingDB) GetOutbox(ctx context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
l := f.log.WithFields(
logrus.Fields{
"func": "GetOutbox",
@@ -51,7 +51,7 @@ func (f *federatingDB) GetOutbox(c context.Context, outboxIRI *url.URL) (inbox v
// database entries. Separate calls to Create will do that.
//
// The library makes this call only after acquiring a lock first.
-func (f *federatingDB) SetOutbox(c context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error {
+func (f *federatingDB) SetOutbox(ctx context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error {
l := f.log.WithFields(
logrus.Fields{
"func": "SetOutbox",
@@ -66,7 +66,7 @@ func (f *federatingDB) SetOutbox(c context.Context, outbox vocab.ActivityStreams
// actor's inbox IRI.
//
// The library makes this call only after acquiring a lock first.
-func (f *federatingDB) OutboxForInbox(c context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) {
+func (f *federatingDB) OutboxForInbox(ctx context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) {
l := f.log.WithFields(
logrus.Fields{
"func": "OutboxForInbox",
@@ -79,7 +79,7 @@ func (f *federatingDB) OutboxForInbox(c context.Context, inboxIRI *url.URL) (out
return nil, fmt.Errorf("%s is not an inbox URI", inboxIRI.String())
}
acct := &gtsmodel.Account{}
- if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil {
+ if err := f.db.GetWhere(ctx, []db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String())
}
diff --git a/internal/federation/federatingdb/owns.go b/internal/federation/federatingdb/owns.go
index 0a65397ff..1c1f2512d 100644
--- a/internal/federation/federatingdb/owns.go
+++ b/internal/federation/federatingdb/owns.go
@@ -32,7 +32,7 @@ import (
// Owns returns true if the IRI belongs to this instance, and if
// the database has an entry for the IRI.
// The library makes this call only after acquiring a lock first.
-func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
+func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {
l := f.log.WithFields(
logrus.Fields{
"func": "Owns",
@@ -54,7 +54,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
- status, err := f.db.GetStatusByURI(uid)
+ status, err := f.db.GetStatusByURI(ctx, uid)
if err != nil {
if err == db.ErrNoEntries {
// there are no entries for this status
@@ -71,7 +71,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
- if _, err := f.db.GetLocalAccountByUsername(username); err != nil {
+ if _, err := f.db.GetLocalAccountByUsername(ctx, username); err != nil {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
@@ -88,7 +88,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
- if _, err := f.db.GetLocalAccountByUsername(username); err != nil {
+ if _, err := f.db.GetLocalAccountByUsername(ctx, username); err != nil {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
@@ -105,7 +105,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
- if _, err := f.db.GetLocalAccountByUsername(username); err != nil {
+ if _, err := f.db.GetLocalAccountByUsername(ctx, username); err != nil {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
@@ -122,7 +122,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing like path for url %s: %s", id.String(), err)
}
- if _, err := f.db.GetLocalAccountByUsername(username); err != nil {
+ if _, err := f.db.GetLocalAccountByUsername(ctx, username); err != nil {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
@@ -130,7 +130,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
// an actual error happened
return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)
}
- if err := f.db.GetByID(likeID, &gtsmodel.StatusFave{}); err != nil {
+ if err := f.db.GetByID(ctx, likeID, &gtsmodel.StatusFave{}); err != nil {
if err == db.ErrNoEntries {
// there are no entries
return false, nil
@@ -147,7 +147,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing block path for url %s: %s", id.String(), err)
}
- if _, err := f.db.GetLocalAccountByUsername(username); err != nil {
+ if _, err := f.db.GetLocalAccountByUsername(ctx, username); err != nil {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
@@ -155,7 +155,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
// an actual error happened
return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)
}
- if err := f.db.GetByID(blockID, &gtsmodel.Block{}); err != nil {
+ if err := f.db.GetByID(ctx, blockID, &gtsmodel.Block{}); err != nil {
if err == db.ErrNoEntries {
// there are no entries
return false, nil
diff --git a/internal/federation/federatingdb/undo.go b/internal/federation/federatingdb/undo.go
index c527833b4..0fa38114d 100644
--- a/internal/federation/federatingdb/undo.go
+++ b/internal/federation/federatingdb/undo.go
@@ -83,7 +83,7 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo)
return errors.New("UNDO: follow actor and activity actor not the same")
}
// convert the follow to something we can understand
- gtsFollow, err := f.typeConverter.ASFollowToFollow(ASFollow)
+ gtsFollow, err := f.typeConverter.ASFollowToFollow(ctx, ASFollow)
if err != nil {
return fmt.Errorf("UNDO: error converting asfollow to gtsfollow: %s", err)
}
@@ -92,11 +92,11 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo)
return errors.New("UNDO: follow object account and inbox account were not the same")
}
// delete any existing FOLLOW
- if err := f.db.DeleteWhere([]db.Where{{Key: "uri", Value: gtsFollow.URI}}, &gtsmodel.Follow{}); err != nil {
+ if err := f.db.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, &gtsmodel.Follow{}); err != nil {
return fmt.Errorf("UNDO: db error removing follow: %s", err)
}
// delete any existing FOLLOW REQUEST
- if err := f.db.DeleteWhere([]db.Where{{Key: "uri", Value: gtsFollow.URI}}, &gtsmodel.FollowRequest{}); err != nil {
+ if err := f.db.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, &gtsmodel.FollowRequest{}); err != nil {
return fmt.Errorf("UNDO: db error removing follow request: %s", err)
}
l.Debug("follow undone")
@@ -116,7 +116,7 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo)
return errors.New("UNDO: block actor and activity actor not the same")
}
// convert the block to something we can understand
- gtsBlock, err := f.typeConverter.ASBlockToBlock(ASBlock)
+ gtsBlock, err := f.typeConverter.ASBlockToBlock(ctx, ASBlock)
if err != nil {
return fmt.Errorf("UNDO: error converting asblock to gtsblock: %s", err)
}
@@ -125,7 +125,7 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo)
return errors.New("UNDO: block object account and inbox account were not the same")
}
// delete any existing BLOCK
- if err := f.db.DeleteWhere([]db.Where{{Key: "uri", Value: gtsBlock.URI}}, &gtsmodel.Block{}); err != nil {
+ if err := f.db.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsBlock.URI}}, &gtsmodel.Block{}); err != nil {
return fmt.Errorf("UNDO: db error removing block: %s", err)
}
l.Debug("block undone")
diff --git a/internal/federation/federatingdb/update.go b/internal/federation/federatingdb/update.go
index 88ffc23b4..e9dfe5315 100644
--- a/internal/federation/federatingdb/update.go
+++ b/internal/federation/federatingdb/update.go
@@ -136,7 +136,7 @@ func (f *federatingDB) Update(ctx context.Context, asType vocab.Type) error {
accountable = i
}
- updatedAcct, err := f.typeConverter.ASRepresentationToAccount(accountable, true)
+ updatedAcct, err := f.typeConverter.ASRepresentationToAccount(ctx, accountable, true)
if err != nil {
return fmt.Errorf("UPDATE: error converting to account: %s", err)
}
@@ -152,7 +152,8 @@ func (f *federatingDB) Update(ctx context.Context, asType vocab.Type) error {
}
updatedAcct.ID = requestingAcct.ID // set this here so the db will update properly instead of trying to PUT this and getting constraint issues
- if err := f.db.UpdateByID(requestingAcct.ID, updatedAcct); err != nil {
+ updatedAcct, err = f.db.UpdateAccount(ctx, updatedAcct)
+ if err != nil {
return fmt.Errorf("UPDATE: database error inserting updated account: %s", err)
}
diff --git a/internal/federation/federatingdb/util.go b/internal/federation/federatingdb/util.go
index eac70d85c..b5befc613 100644
--- a/internal/federation/federatingdb/util.go
+++ b/internal/federation/federatingdb/util.go
@@ -60,7 +60,7 @@ func sameActor(activityActor vocab.ActivityStreamsActorProperty, followActor voc
//
// The go-fed library will handle setting the 'id' property on the
// activity or object provided with the value returned.
-func (f *federatingDB) NewID(c context.Context, t vocab.Type) (idURL *url.URL, err error) {
+func (f *federatingDB) NewID(ctx context.Context, t vocab.Type) (idURL *url.URL, err error) {
l := f.log.WithFields(
logrus.Fields{
"func": "NewID",
@@ -98,7 +98,7 @@ func (f *federatingDB) NewID(c context.Context, t vocab.Type) (idURL *url.URL, e
// take the IRI of the first actor we can find (there should only be one)
if iter.IsIRI() {
// if there's an error here, just use the fallback behavior -- we don't need to return an error here
- if actorAccount, err := f.db.GetAccountByURI(iter.GetIRI().String()); err == nil {
+ if actorAccount, err := f.db.GetAccountByURI(ctx, iter.GetIRI().String()); err == nil {
newID, err := id.NewRandomULID()
if err != nil {
return nil, err
@@ -199,7 +199,7 @@ func (f *federatingDB) NewID(c context.Context, t vocab.Type) (idURL *url.URL, e
// ActorForOutbox fetches the actor's IRI for the given outbox IRI.
//
// The library makes this call only after acquiring a lock first.
-func (f *federatingDB) ActorForOutbox(c context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) {
+func (f *federatingDB) ActorForOutbox(ctx context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) {
l := f.log.WithFields(
logrus.Fields{
"func": "ActorForOutbox",
@@ -212,7 +212,7 @@ func (f *federatingDB) ActorForOutbox(c context.Context, outboxIRI *url.URL) (ac
return nil, fmt.Errorf("%s is not an outbox URI", outboxIRI.String())
}
acct := &gtsmodel.Account{}
- if err := f.db.GetWhere([]db.Where{{Key: "outbox_uri", Value: outboxIRI.String()}}, acct); err != nil {
+ if err := f.db.GetWhere(ctx, []db.Where{{Key: "outbox_uri", Value: outboxIRI.String()}}, acct); err != nil {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to outbox %s", outboxIRI.String())
}
@@ -224,7 +224,7 @@ func (f *federatingDB) ActorForOutbox(c context.Context, outboxIRI *url.URL) (ac
// ActorForInbox fetches the actor's IRI for the given outbox IRI.
//
// The library makes this call only after acquiring a lock first.
-func (f *federatingDB) ActorForInbox(c context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) {
+func (f *federatingDB) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) {
l := f.log.WithFields(
logrus.Fields{
"func": "ActorForInbox",
@@ -237,7 +237,7 @@ func (f *federatingDB) ActorForInbox(c context.Context, inboxIRI *url.URL) (acto
return nil, fmt.Errorf("%s is not an inbox URI", inboxIRI.String())
}
acct := &gtsmodel.Account{}
- if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil {
+ if err := f.db.GetWhere(ctx, []db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String())
}
diff --git a/internal/federation/federatingprotocol.go b/internal/federation/federatingprotocol.go
index 5da68afd3..7f8958111 100644
--- a/internal/federation/federatingprotocol.go
+++ b/internal/federation/federatingprotocol.go
@@ -113,7 +113,7 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr
return nil, false, errors.New("username was empty")
}
- requestedAccount, err := f.db.GetLocalAccountByUsername(username)
+ requestedAccount, err := f.db.GetLocalAccountByUsername(ctx, username)
if err != nil {
return nil, false, fmt.Errorf("could not fetch requested account with username %s: %s", username, err)
}
@@ -131,14 +131,14 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr
// authentication has passed, so add an instance entry for this instance if it hasn't been done already
i := &gtsmodel.Instance{}
- if err := f.db.GetWhere([]db.Where{{Key: "domain", Value: publicKeyOwnerURI.Host, CaseInsensitive: true}}, i); err != nil {
+ if err := f.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: publicKeyOwnerURI.Host, CaseInsensitive: true}}, i); err != nil {
if err != db.ErrNoEntries {
// there's been an actual error
return ctx, false, fmt.Errorf("error getting requesting account with public key id %s: %s", publicKeyOwnerURI.String(), err)
}
// we don't have an entry for this instance yet so dereference it
- i, err = f.GetRemoteInstance(username, &url.URL{
+ i, err = f.GetRemoteInstance(ctx, username, &url.URL{
Scheme: publicKeyOwnerURI.Scheme,
Host: publicKeyOwnerURI.Host,
})
@@ -147,12 +147,12 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr
}
// and put it in the db
- if err := f.db.Put(i); err != nil {
+ if err := f.db.Put(ctx, i); err != nil {
return nil, false, fmt.Errorf("error inserting newly dereferenced instance %s: %s", publicKeyOwnerURI.Host, err)
}
}
- requestingAccount, _, err := f.GetRemoteAccount(username, publicKeyOwnerURI, false)
+ requestingAccount, _, err := f.GetRemoteAccount(ctx, username, publicKeyOwnerURI, false)
if err != nil {
return nil, false, fmt.Errorf("couldn't get remote account: %s", err)
}
@@ -189,7 +189,7 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
return false, errors.New("requested account not set on request context, so couldn't determine blocks")
}
- blocked, err := f.db.AreURIsBlocked(actorIRIs)
+ blocked, err := f.db.AreURIsBlocked(ctx, actorIRIs)
if err != nil {
return false, fmt.Errorf("error checking domain blocks: %s", err)
}
@@ -198,7 +198,7 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
}
for _, uri := range actorIRIs {
- requestingAccount, err := f.db.GetAccountByURI(uri.String())
+ requestingAccount, err := f.db.GetAccountByURI(ctx, uri.String())
if err != nil {
if err == db.ErrNoEntries {
// we don't have an entry for this account so it's not blocked
@@ -208,7 +208,7 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
return false, fmt.Errorf("error getting account with uri %s: %s", uri.String(), err)
}
- blocked, err = f.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true)
+ blocked, err = f.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
if err != nil {
return false, fmt.Errorf("error checking account block: %s", err)
}
diff --git a/internal/federation/federator.go b/internal/federation/federator.go
index 1b5f5441a..5eddcbb99 100644
--- a/internal/federation/federator.go
+++ b/internal/federation/federator.go
@@ -54,21 +54,21 @@ type Federator interface {
// FingerRemoteAccount performs a webfinger lookup for a remote account, using the .well-known path. It will return the ActivityPub URI for that
// account, or an error if it doesn't exist or can't be retrieved.
- FingerRemoteAccount(requestingUsername string, targetUsername string, targetDomain string) (*url.URL, error)
+ FingerRemoteAccount(ctx context.Context, requestingUsername string, targetUsername string, targetDomain string) (*url.URL, error)
- DereferenceRemoteThread(username string, statusURI *url.URL) error
- DereferenceAnnounce(announce *gtsmodel.Status, requestingUsername string) error
+ DereferenceRemoteThread(ctx context.Context, username string, statusURI *url.URL) error
+ DereferenceAnnounce(ctx context.Context, announce *gtsmodel.Status, requestingUsername string) error
- GetRemoteAccount(username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error)
- EnrichRemoteAccount(username string, account *gtsmodel.Account) (*gtsmodel.Account, error)
+ GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error)
+ EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error)
- GetRemoteStatus(username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error)
- EnrichRemoteStatus(username string, status *gtsmodel.Status) (*gtsmodel.Status, error)
+ GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error)
+ EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error)
- GetRemoteInstance(username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error)
+ GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error)
// Handshaking returns true if the given username is currently in the process of dereferencing the remoteAccountID.
- Handshaking(username string, remoteAccountID *url.URL) bool
+ Handshaking(ctx context.Context, username string, remoteAccountID *url.URL) bool
pub.CommonBehavior
pub.FederatingProtocol
}
diff --git a/internal/federation/finger.go b/internal/federation/finger.go
index a5a4fa0e7..5cdd4c04d 100644
--- a/internal/federation/finger.go
+++ b/internal/federation/finger.go
@@ -29,12 +29,12 @@ import (
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
)
-func (f *federator) FingerRemoteAccount(requestingUsername string, targetUsername string, targetDomain string) (*url.URL, error) {
- if blocked, err := f.db.IsDomainBlocked(targetDomain); blocked || err != nil {
+func (f *federator) FingerRemoteAccount(ctx context.Context, requestingUsername string, targetUsername string, targetDomain string) (*url.URL, error) {
+ if blocked, err := f.db.IsDomainBlocked(ctx, targetDomain); blocked || err != nil {
return nil, fmt.Errorf("FingerRemoteAccount: domain %s is blocked", targetDomain)
}
- t, err := f.transportController.NewTransportForUsername(requestingUsername)
+ t, err := f.transportController.NewTransportForUsername(ctx, requestingUsername)
if err != nil {
return nil, fmt.Errorf("FingerRemoteAccount: error getting transport for username %s while dereferencing @%s@%s: %s", requestingUsername, targetUsername, targetDomain, err)
}
diff --git a/internal/federation/handshake.go b/internal/federation/handshake.go
index 0671e78a9..b973680b3 100644
--- a/internal/federation/handshake.go
+++ b/internal/federation/handshake.go
@@ -18,8 +18,11 @@
package federation
-import "net/url"
+import (
+ "context"
+ "net/url"
+)
-func (f *federator) Handshaking(username string, remoteAccountID *url.URL) bool {
- return f.dereferencer.Handshaking(username, remoteAccountID)
+func (f *federator) Handshaking(ctx context.Context, username string, remoteAccountID *url.URL) bool {
+ return f.dereferencer.Handshaking(ctx, username, remoteAccountID)
}
diff --git a/internal/federation/transport.go b/internal/federation/transport.go
index 20aee964b..9e2e38e19 100644
--- a/internal/federation/transport.go
+++ b/internal/federation/transport.go
@@ -68,5 +68,5 @@ func (f *federator) NewTransport(ctx context.Context, actorBoxIRI *url.URL, gofe
return nil, fmt.Errorf("id %s was neither an inbox path nor an outbox path", actorBoxIRI.String())
}
- return f.transportController.NewTransportForUsername(username)
+ return f.transportController.NewTransportForUsername(ctx, username)
}
diff --git a/internal/gtsmodel/account.go b/internal/gtsmodel/account.go
index 9673dcb2c..98d2dcfc9 100644
--- a/internal/gtsmodel/account.go
+++ b/internal/gtsmodel/account.go
@@ -18,8 +18,8 @@
// Package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database.
// These types should never be serialized and/or sent out via public APIs, as they contain sensitive information.
-// The annotation used on these structs is for handling them via the go-pg ORM (hence why they're in this db subdir).
-// See here for more info on go-pg model annotations: https://pg.uptrace.dev/models/
+// The annotation used on these structs is for handling them via the bun-db ORM.
+// See here for more info on bun model annotations: https://bun.uptrace.dev/guide/models.html
package gtsmodel
import (
@@ -34,24 +34,24 @@ type Account struct {
*/
// id of this account in the local database
- ID string `pg:"type:CHAR(26),pk,notnull,unique"`
+ ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"`
// Username of the account, should just be a string of [a-z0-9_]. Can be added to domain to create the full username in the form ``[username]@[domain]`` eg., ``user_96@example.org``
- Username string `pg:",notnull,unique:userdomain"` // username and domain should be unique *with* each other
+ Username string `bun:",notnull,unique:userdomain,nullzero"` // username and domain should be unique *with* each other
// Domain of the account, will be null if this is a local account, otherwise something like ``example.org`` or ``mastodon.social``. Should be unique with username.
- Domain string `pg:",unique:userdomain"` // username and domain should be unique *with* each other
+ Domain string `bun:",unique:userdomain,nullzero"` // username and domain should be unique *with* each other
/*
ACCOUNT METADATA
*/
// ID of the avatar as a media attachment
- AvatarMediaAttachmentID string `pg:"type:CHAR(26)"`
- AvatarMediaAttachment *MediaAttachment `pg:"rel:has-one"`
+ AvatarMediaAttachmentID string `bun:"type:CHAR(26),nullzero"`
+ AvatarMediaAttachment *MediaAttachment `bun:"rel:belongs-to"`
// For a non-local account, where can the header be fetched?
AvatarRemoteURL string
// ID of the header as a media attachment
- HeaderMediaAttachmentID string `pg:"type:CHAR(26)"`
- HeaderMediaAttachment *MediaAttachment `pg:"rel:has-one"`
+ HeaderMediaAttachmentID string `bun:"type:CHAR(26),nullzero"`
+ HeaderMediaAttachment *MediaAttachment `bun:"rel:belongs-to"`
// For a non-local account, where can the header be fetched?
HeaderRemoteURL string
// DisplayName for this account. Can be empty, then just the Username will be used for display purposes.
@@ -63,11 +63,11 @@ type Account struct {
// Is this a memorial account, ie., has the user passed away?
Memorial bool
// This account has moved this account id in the database
- MovedToAccountID string `pg:"type:CHAR(26)"`
+ MovedToAccountID string `bun:"type:CHAR(26),nullzero"`
// When was this account created?
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// When was this account last updated?
- UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// Does this account identify itself as a bot?
Bot bool
// What reason was given for signing up when this account was created?
@@ -78,36 +78,36 @@ type Account struct {
*/
// Does this account need an approval for new followers?
- Locked bool `pg:",default:true,use_zero"`
+ Locked bool `bun:",default:true"`
// Should this account be shown in the instance's profile directory?
- Discoverable bool `pg:",default:false"`
+ Discoverable bool `bun:",default:false"`
// Default post privacy for this account
- Privacy Visibility `pg:",default:'public'"`
+ Privacy Visibility `bun:",default:'public'"`
// Set posts from this account to sensitive by default?
- Sensitive bool `pg:",default:false"`
+ Sensitive bool `bun:",default:false"`
// What language does this account post in?
- Language string `pg:",default:'en'"`
+ Language string `bun:",default:'en'"`
/*
ACTIVITYPUB THINGS
*/
// What is the activitypub URI for this account discovered by webfinger?
- URI string `pg:",unique"`
+ URI string `bun:",unique,nullzero"`
// At which URL can we see the user account in a web browser?
- URL string `pg:",unique"`
+ URL string `bun:",unique,nullzero"`
// Last time this account was located using the webfinger API.
- LastWebfingeredAt time.Time `pg:"type:timestamp"`
+ LastWebfingeredAt time.Time `bun:",nullzero"`
// Address of this account's activitypub inbox, for sending activity to
- InboxURI string `pg:",unique"`
+ InboxURI string `bun:",unique,nullzero"`
// Address of this account's activitypub outbox
- OutboxURI string `pg:",unique"`
+ OutboxURI string `bun:",unique,nullzero"`
// URI for getting the following list of this account
- FollowingURI string `pg:",unique"`
+ FollowingURI string `bun:",unique,nullzero"`
// URI for getting the followers list of this account
- FollowersURI string `pg:",unique"`
+ FollowersURI string `bun:",unique,nullzero"`
// URL for getting the featured collection list of this account
- FeaturedCollectionURI string `pg:",unique"`
+ FeaturedCollectionURI string `bun:",unique,nullzero"`
// What type of activitypub actor is this account?
ActorType string
// This account is associated with x account id
@@ -129,15 +129,15 @@ type Account struct {
*/
// When was this account set to have all its media shown as sensitive?
- SensitizedAt time.Time `pg:"type:timestamp"`
+ SensitizedAt time.Time `bun:",nullzero"`
// When was this account silenced (eg., statuses only visible to followers, not public)?
- SilencedAt time.Time `pg:"type:timestamp"`
+ SilencedAt time.Time `bun:",nullzero"`
// When was this account suspended (eg., don't allow it to log in/post, don't accept media/posts from this account)
- SuspendedAt time.Time `pg:"type:timestamp"`
+ SuspendedAt time.Time `bun:",nullzero"`
// Should we hide this account's collections?
HideCollections bool
// id of the database entry that caused this account to become suspended -- can be an account ID or a domain block ID
- SuspensionOrigin string `pg:"type:CHAR(26)"`
+ SuspensionOrigin string `bun:"type:CHAR(26),nullzero"`
}
// Field represents a key value field on an account, for things like pronouns, website, etc.
@@ -146,5 +146,5 @@ type Account struct {
type Field struct {
Name string
Value string
- VerifiedAt time.Time `pg:"type:timestamp"`
+ VerifiedAt time.Time `bun:",nullzero"`
}
diff --git a/internal/gtsmodel/application.go b/internal/gtsmodel/application.go
index 91287dff3..a6976eafd 100644
--- a/internal/gtsmodel/application.go
+++ b/internal/gtsmodel/application.go
@@ -22,7 +22,7 @@ package gtsmodel
// It is used to authorize tokens etc, and is associated with an oauth client id in the database.
type Application struct {
// id of this application in the db
- ID string `pg:"type:CHAR(26),pk,notnull"`
+ ID string `bun:"type:CHAR(26),pk,notnull"`
// name of the application given when it was created (eg., 'tusky')
Name string
// website for the application given when it was created (eg., 'https://tusky.app')
@@ -30,7 +30,7 @@ type Application struct {
// redirect uri requested by the application for oauth2 flow
RedirectURI string
// id of the associated oauth client entity in the db
- ClientID string `pg:"type:CHAR(26)"`
+ ClientID string `bun:"type:CHAR(26)"`
// secret of the associated oauth client entity in the db
ClientSecret string
// scopes requested when this app was created
diff --git a/internal/gtsmodel/block.go b/internal/gtsmodel/block.go
index 32afede55..0c762837d 100644
--- a/internal/gtsmodel/block.go
+++ b/internal/gtsmodel/block.go
@@ -5,17 +5,17 @@ import "time"
// Block refers to the blocking of one account by another.
type Block struct {
// id of this block in the database
- ID string `pg:"type:CHAR(26),pk,notnull"`
+ ID string `bun:"type:CHAR(26),pk,notnull"`
// When was this block created
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// When was this block updated
- UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// Who created this block?
- AccountID string `pg:"type:CHAR(26),notnull"`
- Account *Account `pg:"rel:has-one"`
+ AccountID string `bun:"type:CHAR(26),notnull"`
+ Account *Account `bun:"rel:belongs-to"`
// Who is targeted by this block?
- TargetAccountID string `pg:"type:CHAR(26),notnull"`
- TargetAccount *Account `pg:"rel:has-one"`
+ TargetAccountID string `bun:"type:CHAR(26),notnull"`
+ TargetAccount *Account `bun:"rel:belongs-to"`
// Activitypub URI for this block
- URI string `pg:",notnull"`
+ URI string `bun:",notnull"`
}
diff --git a/internal/gtsmodel/domainblock.go b/internal/gtsmodel/domainblock.go
index 1bed86d8f..03d5ab0af 100644
--- a/internal/gtsmodel/domainblock.go
+++ b/internal/gtsmodel/domainblock.go
@@ -23,16 +23,16 @@ import "time"
// DomainBlock represents a federation block against a particular domain
type DomainBlock struct {
// ID of this block in the database
- ID string `pg:"type:CHAR(26),pk,notnull,unique"`
+ ID string `bun:"type:CHAR(26),pk,notnull,unique"`
// blocked domain
- Domain string `pg:",pk,notnull,unique"`
+ Domain string `bun:",pk,notnull,unique"`
// When was this block created
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// When was this block updated
- UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// Account ID of the creator of this block
- CreatedByAccountID string `pg:"type:CHAR(26),notnull"`
- CreatedByAccount *Account `pg:"rel:belongs-to"`
+ CreatedByAccountID string `bun:"type:CHAR(26),notnull"`
+ CreatedByAccount *Account `bun:"rel:belongs-to"`
// Private comment on this block, viewable to admins
PrivateComment string
// Public comment on this block, viewable (optionally) by everyone
@@ -40,5 +40,5 @@ type DomainBlock struct {
// whether the domain name should appear obfuscated when displaying it publicly
Obfuscate bool
// if this block was created through a subscription, what's the subscription ID?
- SubscriptionID string `pg:"type:CHAR(26)"`
+ SubscriptionID string `bun:"type:CHAR(26),nullzero"`
}
diff --git a/internal/gtsmodel/emaildomainblock.go b/internal/gtsmodel/emaildomainblock.go
index 374454374..1919172fa 100644
--- a/internal/gtsmodel/emaildomainblock.go
+++ b/internal/gtsmodel/emaildomainblock.go
@@ -23,14 +23,14 @@ import "time"
// EmailDomainBlock represents a domain that the server should automatically reject sign-up requests from.
type EmailDomainBlock struct {
// ID of this block in the database
- ID string `pg:"type:CHAR(26),pk,notnull,unique"`
+ ID string `bun:"type:CHAR(26),pk,notnull,unique"`
// Email domain to block. Eg. 'gmail.com' or 'hotmail.com'
- Domain string `pg:",notnull"`
+ Domain string `bun:",notnull"`
// When was this block created
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// When was this block updated
- UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// Account ID of the creator of this block
- CreatedByAccountID string `pg:"type:CHAR(26),notnull"`
- CreatedByAccount *Account `pg:"rel:belongs-to"`
+ CreatedByAccountID string `bun:"type:CHAR(26),notnull"`
+ CreatedByAccount *Account `bun:"rel:belongs-to"`
}
diff --git a/internal/gtsmodel/emoji.go b/internal/gtsmodel/emoji.go
index f0996d1a3..3b02c14e7 100644
--- a/internal/gtsmodel/emoji.go
+++ b/internal/gtsmodel/emoji.go
@@ -23,16 +23,16 @@ import "time"
// Emoji represents a custom emoji that's been uploaded through the admin UI, and is useable by instance denizens.
type Emoji struct {
// database ID of this emoji
- ID string `pg:"type:CHAR(26),pk,notnull"`
+ ID string `bun:"type:CHAR(26),pk,notnull"`
// String shortcode for this emoji -- the part that's between colons. This should be lowercase a-z_
// eg., 'blob_hug' 'purple_heart' Must be unique with domain.
- Shortcode string `pg:",notnull,unique:shortcodedomain"`
+ Shortcode string `bun:",notnull,unique:shortcodedomain"`
// Origin domain of this emoji, eg 'example.org', 'queer.party'. empty string for local emojis.
- Domain string `pg:",notnull,default:'',use_zero,unique:shortcodedomain"`
+ Domain string `bun:",notnull,default:'',unique:shortcodedomain"`
// When was this emoji created. Must be unique with shortcode.
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// When was this emoji updated
- UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// Where can this emoji be retrieved remotely? Null for local emojis.
// For remote emojis, it'll be something like:
// https://hackers.town/system/custom_emojis/images/000/049/842/original/1b74481204feabfd.png
@@ -51,28 +51,27 @@ type Emoji struct {
ImageStaticURL string
// Path of the emoji image in the server storage system. Will be something like:
// '/gotosocial/storage/6339820e-ef65-4166-a262-5a9f46adb1a7/emoji/original/bfa6c9c5-6c25-4ea4-98b4-d78b8126fb52.png'
- ImagePath string `pg:",notnull"`
+ ImagePath string `bun:",notnull"`
// Path of a static version of the emoji image in the server storage system. Will be something like:
// '/gotosocial/storage/6339820e-ef65-4166-a262-5a9f46adb1a7/emoji/small/bfa6c9c5-6c25-4ea4-98b4-d78b8126fb52.png'
- ImageStaticPath string `pg:",notnull"`
+ ImageStaticPath string `bun:",notnull"`
// MIME content type of the emoji image
// Probably "image/png"
- ImageContentType string `pg:",notnull"`
+ ImageContentType string `bun:",notnull"`
// MIME content type of the static version of the emoji image.
- ImageStaticContentType string `pg:",notnull"`
+ ImageStaticContentType string `bun:",notnull"`
// Size of the emoji image file in bytes, for serving purposes.
- ImageFileSize int `pg:",notnull"`
+ ImageFileSize int `bun:",notnull"`
// Size of the static version of the emoji image file in bytes, for serving purposes.
- ImageStaticFileSize int `pg:",notnull"`
+ ImageStaticFileSize int `bun:",notnull"`
// When was the emoji image last updated?
- ImageUpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ ImageUpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// Has a moderation action disabled this emoji from being shown?
- Disabled bool `pg:",notnull,default:false"`
+ Disabled bool `bun:",notnull,default:false"`
// ActivityStreams uri of this emoji. Something like 'https://example.org/emojis/1234'
- URI string `pg:",notnull,unique"`
+ URI string `bun:",notnull,unique"`
// Is this emoji visible in the admin emoji picker?
- VisibleInPicker bool `pg:",notnull,default:true"`
+ VisibleInPicker bool `bun:",notnull,default:true"`
// In which emoji category is this emoji visible?
- CategoryID string `pg:"type:CHAR(26)"`
- Status *Status `pg:"rel:belongs-to"`
+ CategoryID string `bun:"type:CHAR(26),nullzero"`
}
diff --git a/internal/gtsmodel/follow.go b/internal/gtsmodel/follow.go
index 8f169f8c4..3d3eb1f1b 100644
--- a/internal/gtsmodel/follow.go
+++ b/internal/gtsmodel/follow.go
@@ -23,21 +23,21 @@ import "time"
// Follow represents one account following another, and the metadata around that follow.
type Follow struct {
// id of this follow in the database
- ID string `pg:"type:CHAR(26),pk,notnull,unique"`
+ ID string `bun:"type:CHAR(26),pk,notnull,unique"`
// When was this follow created?
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// When was this follow last updated?
- UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// Who does this follow belong to?
- AccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"`
- Account *Account `pg:"rel:belongs-to"`
+ AccountID string `bun:"type:CHAR(26),unique:srctarget,notnull"`
+ Account *Account `bun:"rel:belongs-to"`
// Who does AccountID follow?
- TargetAccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"`
- TargetAccount *Account `pg:"rel:has-one"`
+ TargetAccountID string `bun:"type:CHAR(26),unique:srctarget,notnull"`
+ TargetAccount *Account `bun:"rel:belongs-to"`
// Does this follow also want to see reblogs and not just posts?
- ShowReblogs bool `pg:"default:true"`
+ ShowReblogs bool `bun:"default:true"`
// What is the activitypub URI of this follow?
- URI string `pg:",unique"`
+ URI string `bun:",unique"`
// does the following account want to be notified when the followed account posts?
Notify bool
}
diff --git a/internal/gtsmodel/followrequest.go b/internal/gtsmodel/followrequest.go
index 752c7d0a2..5a6cb5e02 100644
--- a/internal/gtsmodel/followrequest.go
+++ b/internal/gtsmodel/followrequest.go
@@ -23,21 +23,21 @@ import "time"
// FollowRequest represents one account requesting to follow another, and the metadata around that request.
type FollowRequest struct {
// id of this follow request in the database
- ID string `pg:"type:CHAR(26),pk,notnull,unique"`
+ ID string `bun:"type:CHAR(26),pk,notnull,unique"`
// When was this follow request created?
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// When was this follow request last updated?
- UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// Who does this follow request originate from?
- AccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"`
- Account Account `pg:"rel:has-one"`
+ AccountID string `bun:"type:CHAR(26),unique:frsrctarget,notnull"`
+ Account *Account `bun:"rel:belongs-to"`
// Who is the target of this follow request?
- TargetAccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"`
- TargetAccount Account `pg:"rel:has-one"`
+ TargetAccountID string `bun:"type:CHAR(26),unique:frsrctarget,notnull"`
+ TargetAccount *Account `bun:"rel:belongs-to"`
// Does this follow also want to see reblogs and not just posts?
- ShowReblogs bool `pg:"default:true"`
+ ShowReblogs bool `bun:"default:true"`
// What is the activitypub URI of this follow request?
- URI string `pg:",unique"`
+ URI string `bun:",unique,nullzero"`
// does the following account want to be notified when the followed account posts?
Notify bool
}
diff --git a/internal/gtsmodel/instance.go b/internal/gtsmodel/instance.go
index 7b453a0b3..5bfe942f7 100644
--- a/internal/gtsmodel/instance.go
+++ b/internal/gtsmodel/instance.go
@@ -5,22 +5,22 @@ import "time"
// Instance represents a federated instance, either local or remote.
type Instance struct {
// ID of this instance in the database
- ID string `pg:"type:CHAR(26),pk,notnull,unique"`
+ ID string `bun:"type:CHAR(26),pk,notnull,unique"`
// Instance domain eg example.org
- Domain string `pg:",pk,notnull,unique"`
+ Domain string `bun:",pk,notnull,unique"`
// Title of this instance as it would like to be displayed.
Title string
// base URI of this instance eg https://example.org
- URI string `pg:",notnull,unique"`
+ URI string `bun:",notnull,unique"`
// When was this instance created in the db?
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// When was this instance last updated in the db?
- UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// When was this instance suspended, if at all?
- SuspendedAt time.Time
+ SuspendedAt time.Time `bun:",nullzero"`
// ID of any existing domain block for this instance in the database
- DomainBlockID string `pg:"type:CHAR(26)"`
- DomainBlock *DomainBlock `pg:"rel:has-one"`
+ DomainBlockID string `bun:"type:CHAR(26),nullzero"`
+ DomainBlock *DomainBlock `bun:"rel:belongs-to"`
// Short description of this instance
ShortDescription string
// Longer description of this instance
@@ -32,10 +32,10 @@ type Instance struct {
// Username of the contact account for this instance
ContactAccountUsername string
// Contact account ID in the database for this instance
- ContactAccountID string `pg:"type:CHAR(26)"`
- ContactAccount *Account `pg:"rel:has-one"`
+ ContactAccountID string `bun:"type:CHAR(26),nullzero"`
+ ContactAccount *Account `bun:"rel:belongs-to"`
// Reputation score of this instance
- Reputation int64 `pg:",notnull,default:0"`
+ Reputation int64 `bun:",notnull,default:0"`
// Version of the software used on this instance
Version string
}
diff --git a/internal/gtsmodel/mediaattachment.go b/internal/gtsmodel/mediaattachment.go
index 0f12caaad..b767e538c 100644
--- a/internal/gtsmodel/mediaattachment.go
+++ b/internal/gtsmodel/mediaattachment.go
@@ -26,28 +26,28 @@ import (
// somewhere in storage and that can be retrieved and served by the router.
type MediaAttachment struct {
// ID of the attachment in the database
- ID string `pg:"type:CHAR(26),pk,notnull,unique"`
+ ID string `bun:"type:CHAR(26),pk,notnull,unique"`
// ID of the status to which this is attached
- StatusID string `pg:"type:CHAR(26)"`
+ StatusID string `bun:"type:CHAR(26),nullzero"`
// Where can the attachment be retrieved on *this* server
URL string
// Where can the attachment be retrieved on a remote server (empty for local media)
RemoteURL string
// When was the attachment created
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// When was the attachment last updated
- UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// Type of file (image/gif/audio/video)
- Type FileType `pg:",notnull"`
+ Type FileType `bun:",notnull"`
// Metadata about the file
FileMeta FileMeta
// To which account does this attachment belong
- AccountID string `pg:"type:CHAR(26),notnull"`
- Account *Account `pg:"rel:belongs-to"`
+ AccountID string `bun:"type:CHAR(26),notnull"`
+ Account *Account `bun:"rel:has-one"`
// Description of the attachment (for screenreaders)
Description string
// To which scheduled status does this attachment belong
- ScheduledStatusID string `pg:"type:CHAR(26)"`
+ ScheduledStatusID string `bun:"type:CHAR(26),nullzero"`
// What is the generated blurhash of this attachment
Blurhash string
// What is the processing status of this attachment
@@ -71,7 +71,7 @@ type File struct {
// What is the size of the file in bytes.
FileSize int
// When was the file last updated.
- UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ UpdatedAt time.Time `bun:"type:timestamp,notnull,default:current_timestamp"`
}
// Thumbnail refers to a small image thumbnail derived from a larger image, video, or audio file.
@@ -83,7 +83,7 @@ type Thumbnail struct {
// What is the size of the file in bytes
FileSize int
// When was the file last updated
- UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ UpdatedAt time.Time `bun:"type:timestamp,notnull,default:current_timestamp"`
// What is the URL of the thumbnail on the local server
URL string
// What is the remote URL of the thumbnail (empty for local media)
diff --git a/internal/gtsmodel/mention.go b/internal/gtsmodel/mention.go
index 931e681db..ce5977659 100644
--- a/internal/gtsmodel/mention.go
+++ b/internal/gtsmodel/mention.go
@@ -23,22 +23,22 @@ import "time"
// Mention refers to the 'tagging' or 'mention' of a user within a status.
type Mention struct {
// ID of this mention in the database
- ID string `pg:"type:CHAR(26),pk,notnull,unique"`
+ ID string `bun:"type:CHAR(26),pk,notnull,unique"`
// ID of the status this mention originates from
- StatusID string `pg:"type:CHAR(26),notnull"`
- Status *Status `pg:"rel:belongs-to"`
+ StatusID string `bun:"type:CHAR(26),notnull"`
+ Status *Status `bun:"rel:belongs-to"`
// When was this mention created?
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// When was this mention last updated?
- UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// What's the internal account ID of the originator of the mention?
- OriginAccountID string `pg:"type:CHAR(26),notnull"`
- OriginAccount *Account `pg:"rel:has-one"`
+ OriginAccountID string `bun:"type:CHAR(26),notnull"`
+ OriginAccount *Account `bun:"rel:belongs-to"`
// What's the AP URI of the originator of the mention?
- OriginAccountURI string `pg:",notnull"`
+ OriginAccountURI string `bun:",notnull"`
// What's the internal account ID of the mention target?
- TargetAccountID string `pg:"type:CHAR(26),notnull"`
- TargetAccount *Account `pg:"rel:has-one"`
+ TargetAccountID string `bun:"type:CHAR(26),notnull"`
+ TargetAccount *Account `bun:"rel:belongs-to"`
// Prevent this mention from generating a notification?
Silent bool
@@ -54,15 +54,15 @@ type Mention struct {
// @whatever_username@example.org
//
// This will not be put in the database, it's just for convenience.
- NameString string `pg:"-"`
+ NameString string `bun:"-"`
// TargetAccountURI is the AP ID (uri) of the user mentioned.
//
// This will not be put in the database, it's just for convenience.
- TargetAccountURI string `pg:"-"`
+ TargetAccountURI string `bun:"-"`
// TargetAccountURL is the web url of the user mentioned.
//
// This will not be put in the database, it's just for convenience.
- TargetAccountURL string `pg:"-"`
+ TargetAccountURL string `bun:"-"`
// A pointer to the gtsmodel account of the mentioned account.
}
diff --git a/internal/gtsmodel/messages.go b/internal/gtsmodel/messages.go
index 910c74898..62beb0adc 100644
--- a/internal/gtsmodel/messages.go
+++ b/internal/gtsmodel/messages.go
@@ -1,11 +1,22 @@
-package gtsmodel
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
-// // ToClientAPI wraps a message that travels from the processor into the client API
-// type ToClientAPI struct {
-// APObjectType ActivityStreamsObject
-// APActivityType ActivityStreamsActivity
-// Activity interface{}
-// }
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package gtsmodel
// FromClientAPI wraps a message that travels from client API into the processor
type FromClientAPI struct {
@@ -16,13 +27,6 @@ type FromClientAPI struct {
TargetAccount *Account
}
-// // ToFederator wraps a message that travels from the processor into the federator
-// type ToFederator struct {
-// APObjectType ActivityStreamsObject
-// APActivityType ActivityStreamsActivity
-// GTSModel interface{}
-// }
-
// FromFederator wraps a message that travels from the federator into the processor
type FromFederator struct {
APObjectType string
diff --git a/internal/gtsmodel/notification.go b/internal/gtsmodel/notification.go
index b85bc969e..14ab90802 100644
--- a/internal/gtsmodel/notification.go
+++ b/internal/gtsmodel/notification.go
@@ -23,20 +23,20 @@ import "time"
// Notification models an alert/notification sent to an account about something like a reblog, like, new follow request, etc.
type Notification struct {
// ID of this notification in the database
- ID string `pg:"type:CHAR(26),pk,notnull"`
+ ID string `bun:"type:CHAR(26),pk,notnull"`
// Type of this notification
- NotificationType NotificationType `pg:",notnull"`
+ NotificationType NotificationType `bun:",notnull"`
// Creation time of this notification
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// Which account does this notification target (ie., who will receive the notification?)
- TargetAccountID string `pg:"type:CHAR(26),notnull"`
- TargetAccount *Account `pg:"rel:has-one"`
+ TargetAccountID string `bun:"type:CHAR(26),notnull"`
+ TargetAccount *Account `bun:"rel:belongs-to"`
// Which account performed the action that created this notification?
- OriginAccountID string `pg:"type:CHAR(26),notnull"`
- OriginAccount *Account `pg:"rel:has-one"`
+ OriginAccountID string `bun:"type:CHAR(26),notnull"`
+ OriginAccount *Account `bun:"rel:belongs-to"`
// If the notification pertains to a status, what is the database ID of that status?
- StatusID string `pg:"type:CHAR(26)"`
- Status *Status `pg:"rel:has-one"`
+ StatusID string `bun:"type:CHAR(26),nullzero"`
+ Status *Status `bun:"rel:belongs-to"`
// Has this notification been read already?
Read bool
}
diff --git a/internal/gtsmodel/routersession.go b/internal/gtsmodel/routersession.go
index c0f8e1f4d..7f3bd85c3 100644
--- a/internal/gtsmodel/routersession.go
+++ b/internal/gtsmodel/routersession.go
@@ -20,7 +20,7 @@ package gtsmodel
// RouterSession is used to store and retrieve settings for a router session.
type RouterSession struct {
- ID string `pg:"type:CHAR(26),pk,notnull"`
- Auth []byte `pg:",notnull"`
- Crypt []byte `pg:",notnull"`
+ ID string `bun:"type:CHAR(26),pk,notnull"`
+ Auth []byte `bun:"type:bytea,notnull"`
+ Crypt []byte `bun:"type:bytea,notnull"`
}
diff --git a/internal/gtsmodel/status.go b/internal/gtsmodel/status.go
index 354f37e04..fd33bd788 100644
--- a/internal/gtsmodel/status.go
+++ b/internal/gtsmodel/status.go
@@ -25,61 +25,61 @@ import (
// Status represents a user-created 'post' or 'status' in the database, either remote or local
type Status struct {
// id of the status in the database
- ID string `pg:"type:CHAR(26),pk,notnull"`
+ ID string `bun:"type:CHAR(26),pk,notnull"`
// uri at which this status is reachable
- URI string `pg:",unique"`
+ URI string `bun:",unique,nullzero"`
// web url for viewing this status
- URL string `pg:",unique"`
+ URL string `bun:",unique,nullzero"`
// the html-formatted content of this status
Content string
// Database IDs of any media attachments associated with this status
- AttachmentIDs []string `pg:"attachments,array"`
- Attachments []*MediaAttachment `pg:"attached_media,rel:has-many"`
+ AttachmentIDs []string `bun:"attachments,array"`
+ Attachments []*MediaAttachment `bun:"attached_media,rel:has-many"`
// Database IDs of any tags used in this status
- TagIDs []string `pg:"tags,array"`
- Tags []*Tag `pg:"attached_tags,many2many:status_to_tags"` // https://pg.uptrace.dev/orm/many-to-many-relation/
+ TagIDs []string `bun:"tags,array"`
+ Tags []*Tag `bun:"attached_tags,m2m:status_to_tags"` // https://bun.uptrace.dev/guide/relations.html#many-to-many-relation
// Database IDs of any mentions in this status
- MentionIDs []string `pg:"mentions,array"`
- Mentions []*Mention `pg:"attached_mentions,rel:has-many"`
+ MentionIDs []string `bun:"mentions,array"`
+ Mentions []*Mention `bun:"attached_mentions,rel:has-many"`
// Database IDs of any emojis used in this status
- EmojiIDs []string `pg:"emojis,array"`
- Emojis []*Emoji `pg:"attached_emojis,many2many:status_to_emojis"` // https://pg.uptrace.dev/orm/many-to-many-relation/
+ EmojiIDs []string `bun:"emojis,array"`
+ Emojis []*Emoji `bun:"attached_emojis,m2m:status_to_emojis"` // https://bun.uptrace.dev/guide/relations.html#many-to-many-relation
// when was this status created?
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ CreatedAt time.Time `bun:",notnull,nullzero,default:current_timestamp"`
// when was this status updated?
- UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ UpdatedAt time.Time `bun:",notnull,nullzero,default:current_timestamp"`
// is this status from a local account?
Local bool
// which account posted this status?
- AccountID string `pg:"type:CHAR(26),notnull"`
- Account *Account `pg:"rel:has-one"`
+ AccountID string `bun:"type:CHAR(26),notnull"`
+ Account *Account `bun:"rel:belongs-to"`
// AP uri of the owner of this status
AccountURI string
// id of the status this status is a reply to
- InReplyToID string `pg:"type:CHAR(26)"`
- InReplyTo *Status `pg:"rel:has-one"`
+ InReplyToID string `bun:"type:CHAR(26),nullzero"`
+ InReplyTo *Status `bun:"-"`
// AP uri of the status this status is a reply to
InReplyToURI string
// id of the account that this status replies to
- InReplyToAccountID string `pg:"type:CHAR(26)"`
- InReplyToAccount *Account `pg:"rel:has-one"`
+ InReplyToAccountID string `bun:"type:CHAR(26),nullzero"`
+ InReplyToAccount *Account `bun:"rel:belongs-to"`
// id of the status this status is a boost of
- BoostOfID string `pg:"type:CHAR(26)"`
- BoostOf *Status `pg:"rel:has-one"`
+ BoostOfID string `bun:"type:CHAR(26),nullzero"`
+ BoostOf *Status `bun:"-"`
// id of the account that owns the boosted status
- BoostOfAccountID string `pg:"type:CHAR(26)"`
- BoostOfAccount *Account `pg:"rel:has-one"`
+ BoostOfAccountID string `bun:"type:CHAR(26),nullzero"`
+ BoostOfAccount *Account `bun:"rel:belongs-to"`
// cw string for this status
ContentWarning string
// visibility entry for this status
- Visibility Visibility `pg:",notnull"`
+ Visibility Visibility `bun:",notnull"`
// mark the status as sensitive?
Sensitive bool
// what language is this status written in?
Language string
// Which application was used to create this status?
- CreatedWithApplicationID string `pg:"type:CHAR(26)"`
- CreatedWithApplication *Application `pg:"rel:has-one"`
+ CreatedWithApplicationID string `bun:"type:CHAR(26),nullzero"`
+ CreatedWithApplication *Application `bun:"rel:belongs-to"`
// advanced visibility for this status
VisibilityAdvanced *VisibilityAdvanced
// What is the activitystreams type of this status? See: https://www.w3.org/TR/activitystreams-vocabulary/#object-types
@@ -93,14 +93,18 @@ type Status struct {
// StatusToTag is an intermediate struct to facilitate the many2many relationship between a status and one or more tags.
type StatusToTag struct {
- StatusID string `pg:"unique:statustag"`
- TagID string `pg:"unique:statustag"`
+ StatusID string `bun:"type:CHAR(26),unique:statustag,nullzero"`
+ Status *Status `bun:"rel:belongs-to"`
+ TagID string `bun:"type:CHAR(26),unique:statustag,nullzero"`
+ Tag *Tag `bun:"rel:belongs-to"`
}
// StatusToEmoji is an intermediate struct to facilitate the many2many relationship between a status and one or more emojis.
type StatusToEmoji struct {
- StatusID string `pg:"unique:statusemoji"`
- EmojiID string `pg:"unique:statusemoji"`
+ StatusID string `bun:"type:CHAR(26),unique:statusemoji,nullzero"`
+ Status *Status `bun:"rel:belongs-to"`
+ EmojiID string `bun:"type:CHAR(26),unique:statusemoji,nullzero"`
+ Emoji *Emoji `bun:"rel:belongs-to"`
}
// Visibility represents the visibility granularity of a status.
@@ -134,11 +138,11 @@ const (
// If DIRECT is selected, boostable will be FALSE, and all other flags will be TRUE.
type VisibilityAdvanced struct {
// This status will be federated beyond the local timeline(s)
- Federated bool `pg:"default:true"`
+ Federated bool `bun:"default:true"`
// This status can be boosted/reblogged
- Boostable bool `pg:"default:true"`
+ Boostable bool `bun:"default:true"`
// This status can be replied to
- Replyable bool `pg:"default:true"`
+ Replyable bool `bun:"default:true"`
// This status can be liked/faved
- Likeable bool `pg:"default:true"`
+ Likeable bool `bun:"default:true"`
}
diff --git a/internal/gtsmodel/statusbookmark.go b/internal/gtsmodel/statusbookmark.go
index 468939bae..26dafa420 100644
--- a/internal/gtsmodel/statusbookmark.go
+++ b/internal/gtsmodel/statusbookmark.go
@@ -23,15 +23,15 @@ import "time"
// StatusBookmark refers to one account having a 'bookmark' of the status of another account
type StatusBookmark struct {
// id of this bookmark in the database
- ID string `pg:"type:CHAR(26),pk,notnull,unique"`
+ ID string `bun:"type:CHAR(26),pk,notnull,unique"`
// when was this bookmark created
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// id of the account that created ('did') the bookmarking
- AccountID string `pg:"type:CHAR(26),notnull"`
- Account *Account `pg:"rel:belongs-to"`
+ AccountID string `bun:"type:CHAR(26),notnull"`
+ Account *Account `bun:"rel:belongs-to"`
// id the account owning the bookmarked status
- TargetAccountID string `pg:"type:CHAR(26),notnull"`
- TargetAccount *Account `pg:"rel:has-one"`
+ TargetAccountID string `bun:"type:CHAR(26),notnull"`
+ TargetAccount *Account `bun:"rel:belongs-to"`
// database id of the status that has been bookmarked
- StatusID string `pg:"type:CHAR(26),notnull"`
+ StatusID string `bun:"type:CHAR(26),notnull"`
}
diff --git a/internal/gtsmodel/statusfave.go b/internal/gtsmodel/statusfave.go
index 17952673a..3b816af56 100644
--- a/internal/gtsmodel/statusfave.go
+++ b/internal/gtsmodel/statusfave.go
@@ -23,18 +23,18 @@ import "time"
// StatusFave refers to a 'fave' or 'like' in the database, from one account, targeting the status of another account
type StatusFave struct {
// id of this fave in the database
- ID string `pg:"type:CHAR(26),pk,notnull,unique"`
+ ID string `bun:"type:CHAR(26),pk,notnull,unique"`
// when was this fave created
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// id of the account that created ('did') the fave
- AccountID string `pg:"type:CHAR(26),notnull"`
- Account *Account `pg:"rel:has-one"`
+ AccountID string `bun:"type:CHAR(26),notnull"`
+ Account *Account `bun:"rel:belongs-to"`
// id the account owning the faved status
- TargetAccountID string `pg:"type:CHAR(26),notnull"`
- TargetAccount *Account `pg:"rel:has-one"`
+ TargetAccountID string `bun:"type:CHAR(26),notnull"`
+ TargetAccount *Account `bun:"rel:belongs-to"`
// database id of the status that has been 'faved'
- StatusID string `pg:"type:CHAR(26),notnull"`
- Status *Status `pg:"rel:has-one"`
+ StatusID string `bun:"type:CHAR(26),notnull"`
+ Status *Status `bun:"rel:belongs-to"`
// ActivityPub URI of this fave
- URI string `pg:",notnull"`
+ URI string `bun:",notnull"`
}
diff --git a/internal/gtsmodel/statusmute.go b/internal/gtsmodel/statusmute.go
index 472a5ec09..56a792ab4 100644
--- a/internal/gtsmodel/statusmute.go
+++ b/internal/gtsmodel/statusmute.go
@@ -23,16 +23,16 @@ import "time"
// StatusMute refers to one account having muted the status of another account or its own
type StatusMute struct {
// id of this mute in the database
- ID string `pg:"type:CHAR(26),pk,notnull,unique"`
+ ID string `bun:"type:CHAR(26),pk,notnull,unique"`
// when was this mute created
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// id of the account that created ('did') the mute
- AccountID string `pg:"type:CHAR(26),notnull"`
- Account *Account `pg:"rel:belongs-to"`
+ AccountID string `bun:"type:CHAR(26),notnull"`
+ Account *Account `bun:"rel:belongs-to"`
// id the account owning the muted status (can be the same as accountID)
- TargetAccountID string `pg:"type:CHAR(26),notnull"`
- TargetAccount *Account `pg:"rel:has-one"`
+ TargetAccountID string `bun:"type:CHAR(26),notnull"`
+ TargetAccount *Account `bun:"rel:belongs-to"`
// database id of the status that has been muted
- StatusID string `pg:"type:CHAR(26),notnull"`
- Status *Status `pg:"rel:has-one"`
+ StatusID string `bun:"type:CHAR(26),notnull"`
+ Status *Status `bun:"rel:belongs-to"`
}
diff --git a/internal/gtsmodel/tag.go b/internal/gtsmodel/tag.go
index 27cce1c8b..5006a36f4 100644
--- a/internal/gtsmodel/tag.go
+++ b/internal/gtsmodel/tag.go
@@ -23,21 +23,21 @@ import "time"
// Tag represents a hashtag for gathering public statuses together
type Tag struct {
// id of this tag in the database
- ID string `pg:",unique,type:CHAR(26),pk,notnull"`
+ ID string `bun:",unique,type:CHAR(26),pk,notnull"`
// Href of this tag, eg https://example.org/tags/somehashtag
URL string
// name of this tag -- the tag without the hash part
- Name string `pg:",unique,notnull"`
+ Name string `bun:",unique,notnull"`
// Which account ID is the first one we saw using this tag?
- FirstSeenFromAccountID string `pg:"type:CHAR(26)"`
+ FirstSeenFromAccountID string `bun:"type:CHAR(26),nullzero"`
// when was this tag created
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// when was this tag last updated
- UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// can our instance users use this tag?
- Useable bool `pg:",notnull,default:true"`
+ Useable bool `bun:",notnull,default:true"`
// can our instance users look up this tag?
- Listable bool `pg:",notnull,default:true"`
+ Listable bool `bun:",notnull,default:true"`
// when was this tag last used?
- LastStatusAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ LastStatusAt time.Time `bun:",nullzero"`
}
diff --git a/internal/gtsmodel/user.go b/internal/gtsmodel/user.go
index fe8ebcabe..f439be439 100644
--- a/internal/gtsmodel/user.go
+++ b/internal/gtsmodel/user.go
@@ -31,37 +31,37 @@ type User struct {
*/
// id of this user in the local database; the end-user will never need to know this, it's strictly internal
- ID string `pg:"type:CHAR(26),pk,notnull,unique"`
+ ID string `bun:"type:CHAR(26),pk,notnull,unique"`
// confirmed email address for this user, this should be unique -- only one email address registered per instance, multiple users per email are not supported
- Email string `pg:"default:null,unique"`
+ Email string `bun:"default:null,unique,nullzero"`
// The id of the local gtsmodel.Account entry for this user, if it exists (unconfirmed users don't have an account yet)
- AccountID string `pg:"type:CHAR(26),unique"`
- Account *Account `pg:"rel:has-one"`
+ AccountID string `bun:"type:CHAR(26),unique,nullzero"`
+ Account *Account `bun:"rel:belongs-to"`
// The encrypted password of this user, generated using https://pkg.go.dev/golang.org/x/crypto/bcrypt#GenerateFromPassword. A salt is included so we're safe against 🌈 tables
- EncryptedPassword string `pg:",notnull"`
+ EncryptedPassword string `bun:",notnull"`
/*
USER METADATA
*/
// When was this user created?
- CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// From what IP was this user created?
- SignUpIP net.IP
+ SignUpIP net.IP `bun:",nullzero"`
// When was this user updated (eg., password changed, email address changed)?
- UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
+ UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// When did this user sign in for their current session?
- CurrentSignInAt time.Time `pg:"type:timestamp"`
+ CurrentSignInAt time.Time `bun:",nullzero"`
// What's the most recent IP of this user
- CurrentSignInIP net.IP
+ CurrentSignInIP net.IP `bun:",nullzero"`
// When did this user last sign in?
- LastSignInAt time.Time `pg:"type:timestamp"`
+ LastSignInAt time.Time `bun:",nullzero"`
// What's the previous IP of this user?
- LastSignInIP net.IP
+ LastSignInIP net.IP `bun:",nullzero"`
// How many times has this user signed in?
SignInCount int
// id of the user who invited this user (who let this guy in?)
- InviteID string `pg:"type:CHAR(26)"`
+ InviteID string `bun:"type:CHAR(26),nullzero"`
// What languages does this user want to see?
ChosenLanguages []string
// What languages does this user not want to see?
@@ -69,10 +69,10 @@ type User struct {
// In what timezone/locale is this user located?
Locale string
// Which application id created this user? See gtsmodel.Application
- CreatedByApplicationID string `pg:"type:CHAR(26)"`
- CreatedByApplication *Application `pg:"rel:has-one"`
+ CreatedByApplicationID string `bun:"type:CHAR(26),nullzero"`
+ CreatedByApplication *Application `bun:"rel:belongs-to"`
// When did we last contact this user
- LastEmailedAt time.Time `pg:"type:timestamp"`
+ LastEmailedAt time.Time `bun:",nullzero"`
/*
USER CONFIRMATION
@@ -81,9 +81,9 @@ type User struct {
// What confirmation token did we send this user/what are we expecting back?
ConfirmationToken string
// When did the user confirm their email address
- ConfirmedAt time.Time `pg:"type:timestamp"`
+ ConfirmedAt time.Time `bun:",nullzero"`
// When did we send email confirmation to this user?
- ConfirmationSentAt time.Time `pg:"type:timestamp"`
+ ConfirmationSentAt time.Time `bun:",nullzero"`
// Email address that hasn't yet been confirmed
UnconfirmedEmail string
@@ -107,7 +107,7 @@ type User struct {
// The generated token that the user can use to reset their password
ResetPasswordToken string
// When did we email the user their reset-password email?
- ResetPasswordSentAt time.Time `pg:"type:timestamp"`
+ ResetPasswordSentAt time.Time `bun:",nullzero"`
EncryptedOTPSecret string
EncryptedOTPSecretIv string
@@ -117,6 +117,6 @@ type User struct {
ConsumedTimestamp int
RememberToken string
SignInToken string
- SignInTokenSentAt time.Time `pg:"type:timestamp"`
+ SignInTokenSentAt time.Time `bun:",nullzero"`
WebauthnID string
}
diff --git a/internal/media/handler.go b/internal/media/handler.go
index c383a922e..1150f7e87 100644
--- a/internal/media/handler.go
+++ b/internal/media/handler.go
@@ -67,27 +67,27 @@ type Handler interface {
// ProcessHeaderOrAvatar takes a new header image for an account, checks it out, removes exif data from it,
// puts it in whatever storage backend we're using, sets the relevant fields in the database for the new image,
// and then returns information to the caller about the new header.
- ProcessHeaderOrAvatar(attachment []byte, accountID string, mediaType Type, remoteURL string) (*gtsmodel.MediaAttachment, error)
+ ProcessHeaderOrAvatar(ctx context.Context, attachment []byte, accountID string, mediaType Type, remoteURL string) (*gtsmodel.MediaAttachment, error)
// ProcessLocalAttachment takes a new attachment and the requesting account, checks it out, removes exif data from it,
// puts it in whatever storage backend we're using, sets the relevant fields in the database for the new media,
// and then returns information to the caller about the attachment. It's the caller's responsibility to put the returned struct
// in the database.
- ProcessAttachment(attachment []byte, accountID string, remoteURL string) (*gtsmodel.MediaAttachment, error)
+ ProcessAttachment(ctx context.Context, attachment []byte, accountID string, remoteURL string) (*gtsmodel.MediaAttachment, error)
// ProcessLocalEmoji takes a new emoji and a shortcode, cleans it up, puts it in storage, and creates a new
// *gts.Emoji for it, then returns it to the caller. It's the caller's responsibility to put the returned struct
// in the database.
- ProcessLocalEmoji(emojiBytes []byte, shortcode string) (*gtsmodel.Emoji, error)
+ ProcessLocalEmoji(ctx context.Context, emojiBytes []byte, shortcode string) (*gtsmodel.Emoji, error)
// ProcessRemoteAttachment takes a transport, a bare-bones current attachment, and an accountID that the attachment belongs to.
// It then dereferences the attachment (ie., fetches the attachment bytes from the remote server), ensuring that the bytes are
// the correct content type. It stores the attachment in whatever storage backend the Handler has been initalized with, and returns
// information to the caller about the new attachment. It's the caller's responsibility to put the returned struct
// in the database.
- ProcessRemoteAttachment(t transport.Transport, currentAttachment *gtsmodel.MediaAttachment, accountID string) (*gtsmodel.MediaAttachment, error)
+ ProcessRemoteAttachment(ctx context.Context, t transport.Transport, currentAttachment *gtsmodel.MediaAttachment, accountID string) (*gtsmodel.MediaAttachment, error)
- ProcessRemoteHeaderOrAvatar(t transport.Transport, currentAttachment *gtsmodel.MediaAttachment, accountID string) (*gtsmodel.MediaAttachment, error)
+ ProcessRemoteHeaderOrAvatar(ctx context.Context, t transport.Transport, currentAttachment *gtsmodel.MediaAttachment, accountID string) (*gtsmodel.MediaAttachment, error)
}
type mediaHandler struct {
@@ -114,7 +114,7 @@ func New(config *config.Config, database db.DB, storage blob.Storage, log *logru
// ProcessHeaderOrAvatar takes a new header image for an account, checks it out, removes exif data from it,
// puts it in whatever storage backend we're using, sets the relevant fields in the database for the new image,
// and then returns information to the caller about the new header.
-func (mh *mediaHandler) ProcessHeaderOrAvatar(attachment []byte, accountID string, mediaType Type, remoteURL string) (*gtsmodel.MediaAttachment, error) {
+func (mh *mediaHandler) ProcessHeaderOrAvatar(ctx context.Context, attachment []byte, accountID string, mediaType Type, remoteURL string) (*gtsmodel.MediaAttachment, error) {
l := mh.log.WithField("func", "SetHeaderForAccountID")
if mediaType != Header && mediaType != Avatar {
@@ -142,7 +142,7 @@ func (mh *mediaHandler) ProcessHeaderOrAvatar(attachment []byte, accountID strin
}
// set it in the database
- if err := mh.db.SetAccountHeaderOrAvatar(ma, accountID); err != nil {
+ if err := mh.db.SetAccountHeaderOrAvatar(ctx, ma, accountID); err != nil {
return nil, fmt.Errorf("error putting %s in database: %s", mediaType, err)
}
@@ -152,7 +152,7 @@ func (mh *mediaHandler) ProcessHeaderOrAvatar(attachment []byte, accountID strin
// ProcessAttachment takes a new attachment and the owning account, checks it out, removes exif data from it,
// puts it in whatever storage backend we're using, sets the relevant fields in the database for the new media,
// and then returns information to the caller about the attachment.
-func (mh *mediaHandler) ProcessAttachment(attachment []byte, accountID string, remoteURL string) (*gtsmodel.MediaAttachment, error) {
+func (mh *mediaHandler) ProcessAttachment(ctx context.Context, attachment []byte, accountID string, remoteURL string) (*gtsmodel.MediaAttachment, error) {
contentType, err := parseContentType(attachment)
if err != nil {
return nil, err
@@ -184,7 +184,7 @@ func (mh *mediaHandler) ProcessAttachment(attachment []byte, accountID string, r
// ProcessLocalEmoji takes a new emoji and a shortcode, cleans it up, puts it in storage, and creates a new
// *gts.Emoji for it, then returns it to the caller. It's the caller's responsibility to put the returned struct
// in the database.
-func (mh *mediaHandler) ProcessLocalEmoji(emojiBytes []byte, shortcode string) (*gtsmodel.Emoji, error) {
+func (mh *mediaHandler) ProcessLocalEmoji(ctx context.Context, emojiBytes []byte, shortcode string) (*gtsmodel.Emoji, error) {
var clean []byte
var err error
var original *imageAndMeta
@@ -231,7 +231,7 @@ func (mh *mediaHandler) ProcessLocalEmoji(emojiBytes []byte, shortcode string) (
// since emoji aren't 'owned' by an account, but we still want to use the same pattern for serving them through the filserver,
// (ie., fileserver/ACCOUNT_ID/etc etc) we need to fetch the INSTANCE ACCOUNT from the database. That is, the account that's created
// with the same username as the instance hostname, which doesn't belong to any particular user.
- instanceAccount, err := mh.db.GetInstanceAccount("")
+ instanceAccount, err := mh.db.GetInstanceAccount(ctx, "")
if err != nil {
return nil, fmt.Errorf("error fetching instance account: %s", err)
}
@@ -296,7 +296,7 @@ func (mh *mediaHandler) ProcessLocalEmoji(emojiBytes []byte, shortcode string) (
return e, nil
}
-func (mh *mediaHandler) ProcessRemoteAttachment(t transport.Transport, currentAttachment *gtsmodel.MediaAttachment, accountID string) (*gtsmodel.MediaAttachment, error) {
+func (mh *mediaHandler) ProcessRemoteAttachment(ctx context.Context, t transport.Transport, currentAttachment *gtsmodel.MediaAttachment, accountID string) (*gtsmodel.MediaAttachment, error) {
if currentAttachment.RemoteURL == "" {
return nil, errors.New("no remote URL on media attachment to dereference")
}
@@ -317,10 +317,10 @@ func (mh *mediaHandler) ProcessRemoteAttachment(t transport.Transport, currentAt
return nil, fmt.Errorf("dereferencing remote media with url %s: %s", remoteIRI.String(), err)
}
- return mh.ProcessAttachment(attachmentBytes, accountID, currentAttachment.RemoteURL)
+ return mh.ProcessAttachment(ctx, attachmentBytes, accountID, currentAttachment.RemoteURL)
}
-func (mh *mediaHandler) ProcessRemoteHeaderOrAvatar(t transport.Transport, currentAttachment *gtsmodel.MediaAttachment, accountID string) (*gtsmodel.MediaAttachment, error) {
+func (mh *mediaHandler) ProcessRemoteHeaderOrAvatar(ctx context.Context, t transport.Transport, currentAttachment *gtsmodel.MediaAttachment, accountID string) (*gtsmodel.MediaAttachment, error) {
if !currentAttachment.Header && !currentAttachment.Avatar {
return nil, errors.New("provided attachment was set to neither header nor avatar")
@@ -357,5 +357,5 @@ func (mh *mediaHandler) ProcessRemoteHeaderOrAvatar(t transport.Transport, curre
return nil, fmt.Errorf("dereferencing remote media with url %s: %s", remoteIRI.String(), err)
}
- return mh.ProcessHeaderOrAvatar(attachmentBytes, accountID, headerOrAvi, currentAttachment.RemoteURL)
+ return mh.ProcessHeaderOrAvatar(ctx, attachmentBytes, accountID, headerOrAvi, currentAttachment.RemoteURL)
}
diff --git a/internal/oauth/clientstore.go b/internal/oauth/clientstore.go
index 2e7e0ae88..a642f6cfa 100644
--- a/internal/oauth/clientstore.go
+++ b/internal/oauth/clientstore.go
@@ -39,10 +39,8 @@ func NewClientStore(db db.Basic) oauth2.ClientStore {
}
func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
- poc := &Client{
- ID: clientID,
- }
- if err := cs.db.GetByID(clientID, poc); err != nil {
+ poc := &Client{}
+ if err := cs.db.GetByID(ctx, clientID, poc); err != nil {
return nil, err
}
return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil
@@ -55,19 +53,19 @@ func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo
Domain: cli.GetDomain(),
UserID: cli.GetUserID(),
}
- return cs.db.UpdateByID(id, poc)
+ return cs.db.Put(ctx, poc)
}
func (cs *clientStore) Delete(ctx context.Context, id string) error {
poc := &Client{
ID: id,
}
- return cs.db.DeleteByID(id, poc)
+ return cs.db.DeleteByID(ctx, id, poc)
}
// Client is a handy little wrapper for typical oauth client details
type Client struct {
- ID string `pg:"type:CHAR(26),pk,notnull"`
+ ID string `bun:"type:CHAR(26),pk,notnull"`
Secret string
Domain string
UserID string
diff --git a/internal/oauth/tokenstore.go b/internal/oauth/tokenstore.go
index 4fd3183fc..264678ff5 100644
--- a/internal/oauth/tokenstore.go
+++ b/internal/oauth/tokenstore.go
@@ -43,13 +43,13 @@ type tokenStore struct {
// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
// the tokens in the DB once per minute and deletes any that have expired.
func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2.TokenStore {
- pts := &tokenStore{
+ ts := &tokenStore{
db: db,
log: log,
}
// set the token store to clean out expired tokens once per minute, or return if we're done
- go func(ctx context.Context, pts *tokenStore, log *logrus.Logger) {
+ go func(ctx context.Context, ts *tokenStore, log *logrus.Logger) {
cleanloop:
for {
select {
@@ -58,32 +58,32 @@ func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2.
break cleanloop
case <-time.After(1 * time.Minute):
log.Trace("sweeping out old oauth entries broom broom")
- if err := pts.sweep(); err != nil {
+ if err := ts.sweep(ctx); err != nil {
log.Errorf("error while sweeping oauth entries: %s", err)
}
}
}
- }(ctx, pts, log)
- return pts
+ }(ctx, ts, log)
+ return ts
}
// sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so.
-func (pts *tokenStore) sweep() error {
+func (ts *tokenStore) sweep(ctx context.Context) error {
// select *all* tokens from the db
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
tokens := new([]*Token)
- if err := pts.db.GetAll(tokens); err != nil {
+ if err := ts.db.GetAll(ctx, tokens); err != nil {
return err
}
// iterate through and remove expired tokens
now := time.Now()
- for _, pgt := range *tokens {
+ for _, dbt := range *tokens {
// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So:
// we only want to check if a token expired before now if the expiry time is *not zero*;
// ie., if it's been explicity set.
- if !pgt.CodeExpiresAt.IsZero() && pgt.CodeExpiresAt.Before(now) || !pgt.RefreshExpiresAt.IsZero() && pgt.RefreshExpiresAt.Before(now) || !pgt.AccessExpiresAt.IsZero() && pgt.AccessExpiresAt.Before(now) {
- if err := pts.db.DeleteByID(pgt.ID, pgt); err != nil {
+ if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) {
+ if err := ts.db.DeleteByID(ctx, dbt.ID, dbt); err != nil {
return err
}
}
@@ -94,92 +94,92 @@ func (pts *tokenStore) sweep() error {
// Create creates and store the new token information.
// For the original implementation, see https://github.com/superseriousbusiness/oauth2/blob/master/store/token.go#L34
-func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
+func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
t, ok := info.(*models.Token)
if !ok {
return errors.New("info param was not a models.Token")
}
- pgt := TokenToPGToken(t)
- if pgt.ID == "" {
- pgtID, err := id.NewRandomULID()
+ dbt := TokenToDBToken(t)
+ if dbt.ID == "" {
+ dbtID, err := id.NewRandomULID()
if err != nil {
return err
}
- pgt.ID = pgtID
+ dbt.ID = dbtID
}
- if err := pts.db.Put(pgt); err != nil {
+ if err := ts.db.Put(ctx, dbt); err != nil {
return fmt.Errorf("error in tokenstore create: %s", err)
}
return nil
}
// RemoveByCode deletes a token from the DB based on the Code field
-func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
- return pts.db.DeleteWhere([]db.Where{{Key: "code", Value: code}}, &Token{})
+func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
+ return ts.db.DeleteWhere(ctx, []db.Where{{Key: "code", Value: code}}, &Token{})
}
// RemoveByAccess deletes a token from the DB based on the Access field
-func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
- return pts.db.DeleteWhere([]db.Where{{Key: "access", Value: access}}, &Token{})
+func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
+ return ts.db.DeleteWhere(ctx, []db.Where{{Key: "access", Value: access}}, &Token{})
}
// RemoveByRefresh deletes a token from the DB based on the Refresh field
-func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
- return pts.db.DeleteWhere([]db.Where{{Key: "refresh", Value: refresh}}, &Token{})
+func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
+ return ts.db.DeleteWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, &Token{})
}
// GetByCode selects a token from the DB based on the Code field
-func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
+func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
if code == "" {
return nil, nil
}
- pgt := &Token{
+ dbt := &Token{
Code: code,
}
- if err := pts.db.GetWhere([]db.Where{{Key: "code", Value: code}}, pgt); err != nil {
+ if err := ts.db.GetWhere(ctx, []db.Where{{Key: "code", Value: code}}, dbt); err != nil {
return nil, err
}
- return TokenToOauthToken(pgt), nil
+ return DBTokenToToken(dbt), nil
}
// GetByAccess selects a token from the DB based on the Access field
-func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
+func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
if access == "" {
return nil, nil
}
- pgt := &Token{
+ dbt := &Token{
Access: access,
}
- if err := pts.db.GetWhere([]db.Where{{Key: "access", Value: access}}, pgt); err != nil {
+ if err := ts.db.GetWhere(ctx, []db.Where{{Key: "access", Value: access}}, dbt); err != nil {
return nil, err
}
- return TokenToOauthToken(pgt), nil
+ return DBTokenToToken(dbt), nil
}
// GetByRefresh selects a token from the DB based on the Refresh field
-func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
+func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
if refresh == "" {
return nil, nil
}
- pgt := &Token{
+ dbt := &Token{
Refresh: refresh,
}
- if err := pts.db.GetWhere([]db.Where{{Key: "refresh", Value: refresh}}, pgt); err != nil {
+ if err := ts.db.GetWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, dbt); err != nil {
return nil, err
}
- return TokenToOauthToken(pgt), nil
+ return DBTokenToToken(dbt), nil
}
/*
- The following models are basically helpers for the postgres token store implementation, they should only be used internally.
+ The following models are basically helpers for the token store implementation, they should only be used internally.
*/
// Token is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt.
//
// Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined,
-// and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and
+// and tokens with expired TTLs are automatically removed. Since some databases don't have that feature, it's easier to set an expiry time and
// then periodically sweep out tokens when that time has passed.
//
// Note that this struct does *not* satisfy the token interface shown here: https://github.com/superseriousbusiness/oauth2/blob/master/model.go#L22
@@ -187,26 +187,26 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2
// As such, manual translation is always required between Token and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
// and pgTokenToOauthToken can be used for that.
type Token struct {
- ID string `pg:"type:CHAR(26),pk,notnull"`
+ ID string `bun:"type:CHAR(26),pk,notnull"`
ClientID string
UserID string
RedirectURI string
Scope string
- Code string `pg:"default:'',pk"`
+ Code string `bun:"default:'',pk"`
CodeChallenge string
CodeChallengeMethod string
- CodeCreateAt time.Time `pg:"type:timestamp"`
- CodeExpiresAt time.Time `pg:"type:timestamp"`
- Access string `pg:"default:'',pk"`
- AccessCreateAt time.Time `pg:"type:timestamp"`
- AccessExpiresAt time.Time `pg:"type:timestamp"`
- Refresh string `pg:"default:'',pk"`
- RefreshCreateAt time.Time `pg:"type:timestamp"`
- RefreshExpiresAt time.Time `pg:"type:timestamp"`
+ CodeCreateAt time.Time `bun:",nullzero"`
+ CodeExpiresAt time.Time `bun:",nullzero"`
+ Access string `bun:"default:'',pk"`
+ AccessCreateAt time.Time `bun:",nullzero"`
+ AccessExpiresAt time.Time `bun:",nullzero"`
+ Refresh string `bun:"default:'',pk"`
+ RefreshCreateAt time.Time `bun:",nullzero"`
+ RefreshExpiresAt time.Time `bun:",nullzero"`
}
-// TokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres
-func TokenToPGToken(tkn *models.Token) *Token {
+// TokenToDBToken is a lil util function that takes a gotosocial token and gives back a token for inserting into a database.
+func TokenToDBToken(tkn *models.Token) *Token {
now := time.Now()
// For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's
@@ -247,40 +247,40 @@ func TokenToPGToken(tkn *models.Token) *Token {
}
}
-// TokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token
-func TokenToOauthToken(pgt *Token) *models.Token {
+// DBTokenToToken is a lil util function that takes a database token and gives back a gotosocial token
+func DBTokenToToken(dbt *Token) *models.Token {
now := time.Now()
var codeExpiresIn time.Duration
- if !pgt.CodeExpiresAt.IsZero() {
- codeExpiresIn = pgt.CodeExpiresAt.Sub(now)
+ if !dbt.CodeExpiresAt.IsZero() {
+ codeExpiresIn = dbt.CodeExpiresAt.Sub(now)
}
var accessExpiresIn time.Duration
- if !pgt.AccessExpiresAt.IsZero() {
- accessExpiresIn = pgt.AccessExpiresAt.Sub(now)
+ if !dbt.AccessExpiresAt.IsZero() {
+ accessExpiresIn = dbt.AccessExpiresAt.Sub(now)
}
var refreshExpiresIn time.Duration
- if !pgt.RefreshExpiresAt.IsZero() {
- refreshExpiresIn = pgt.RefreshExpiresAt.Sub(now)
+ if !dbt.RefreshExpiresAt.IsZero() {
+ refreshExpiresIn = dbt.RefreshExpiresAt.Sub(now)
}
return &models.Token{
- ClientID: pgt.ClientID,
- UserID: pgt.UserID,
- RedirectURI: pgt.RedirectURI,
- Scope: pgt.Scope,
- Code: pgt.Code,
- CodeChallenge: pgt.CodeChallenge,
- CodeChallengeMethod: pgt.CodeChallengeMethod,
- CodeCreateAt: pgt.CodeCreateAt,
+ ClientID: dbt.ClientID,
+ UserID: dbt.UserID,
+ RedirectURI: dbt.RedirectURI,
+ Scope: dbt.Scope,
+ Code: dbt.Code,
+ CodeChallenge: dbt.CodeChallenge,
+ CodeChallengeMethod: dbt.CodeChallengeMethod,
+ CodeCreateAt: dbt.CodeCreateAt,
CodeExpiresIn: codeExpiresIn,
- Access: pgt.Access,
- AccessCreateAt: pgt.AccessCreateAt,
+ Access: dbt.Access,
+ AccessCreateAt: dbt.AccessCreateAt,
AccessExpiresIn: accessExpiresIn,
- Refresh: pgt.Refresh,
- RefreshCreateAt: pgt.RefreshCreateAt,
+ Refresh: dbt.Refresh,
+ RefreshCreateAt: dbt.RefreshCreateAt,
RefreshExpiresIn: refreshExpiresIn,
}
}
diff --git a/internal/processing/account.go b/internal/processing/account.go
index f722c88eb..94ba596ac 100644
--- a/internal/processing/account.go
+++ b/internal/processing/account.go
@@ -19,51 +19,53 @@
package processing
import (
+ "context"
+
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
-func (p *processor) AccountCreate(authed *oauth.Auth, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) {
- return p.accountProcessor.Create(authed.Token, authed.Application, form)
+func (p *processor) AccountCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) {
+ return p.accountProcessor.Create(ctx, authed.Token, authed.Application, form)
}
-func (p *processor) AccountGet(authed *oauth.Auth, targetAccountID string) (*apimodel.Account, error) {
- return p.accountProcessor.Get(authed.Account, targetAccountID)
+func (p *processor) AccountGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Account, error) {
+ return p.accountProcessor.Get(ctx, authed.Account, targetAccountID)
}
-func (p *processor) AccountUpdate(authed *oauth.Auth, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) {
- return p.accountProcessor.Update(authed.Account, form)
+func (p *processor) AccountUpdate(ctx context.Context, authed *oauth.Auth, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) {
+ return p.accountProcessor.Update(ctx, authed.Account, form)
}
-func (p *processor) AccountStatusesGet(authed *oauth.Auth, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) {
- return p.accountProcessor.StatusesGet(authed.Account, targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
+func (p *processor) AccountStatusesGet(ctx context.Context, authed *oauth.Auth, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) {
+ return p.accountProcessor.StatusesGet(ctx, authed.Account, targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
}
-func (p *processor) AccountFollowersGet(authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
- return p.accountProcessor.FollowersGet(authed.Account, targetAccountID)
+func (p *processor) AccountFollowersGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
+ return p.accountProcessor.FollowersGet(ctx, authed.Account, targetAccountID)
}
-func (p *processor) AccountFollowingGet(authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
- return p.accountProcessor.FollowingGet(authed.Account, targetAccountID)
+func (p *processor) AccountFollowingGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
+ return p.accountProcessor.FollowingGet(ctx, authed.Account, targetAccountID)
}
-func (p *processor) AccountRelationshipGet(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
- return p.accountProcessor.RelationshipGet(authed.Account, targetAccountID)
+func (p *processor) AccountRelationshipGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+ return p.accountProcessor.RelationshipGet(ctx, authed.Account, targetAccountID)
}
-func (p *processor) AccountFollowCreate(authed *oauth.Auth, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) {
- return p.accountProcessor.FollowCreate(authed.Account, form)
+func (p *processor) AccountFollowCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) {
+ return p.accountProcessor.FollowCreate(ctx, authed.Account, form)
}
-func (p *processor) AccountFollowRemove(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
- return p.accountProcessor.FollowRemove(authed.Account, targetAccountID)
+func (p *processor) AccountFollowRemove(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+ return p.accountProcessor.FollowRemove(ctx, authed.Account, targetAccountID)
}
-func (p *processor) AccountBlockCreate(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
- return p.accountProcessor.BlockCreate(authed.Account, targetAccountID)
+func (p *processor) AccountBlockCreate(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+ return p.accountProcessor.BlockCreate(ctx, authed.Account, targetAccountID)
}
-func (p *processor) AccountBlockRemove(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
- return p.accountProcessor.BlockRemove(authed.Account, targetAccountID)
+func (p *processor) AccountBlockRemove(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+ return p.accountProcessor.BlockRemove(ctx, authed.Account, targetAccountID)
}
diff --git a/internal/processing/account/account.go b/internal/processing/account/account.go
index 7b8910149..81701fd7c 100644
--- a/internal/processing/account/account.go
+++ b/internal/processing/account/account.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"mime/multipart"
"github.com/sirupsen/logrus"
@@ -38,40 +39,40 @@ import (
// Processor wraps a bunch of functions for processing account actions.
type Processor interface {
// Create processes the given form for creating a new account, returning an oauth token for that account if successful.
- Create(applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, error)
+ Create(ctx context.Context, applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, error)
// Delete deletes an account, and all of that account's statuses, media, follows, notifications, etc etc etc.
// The origin passed here should be either the ID of the account doing the delete (can be itself), or the ID of a domain block.
- Delete(account *gtsmodel.Account, origin string) error
+ Delete(ctx context.Context, account *gtsmodel.Account, origin string) error
// Get processes the given request for account information.
- Get(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error)
+ Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error)
// Update processes the update of an account with the given form
- Update(account *gtsmodel.Account, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error)
+ Update(ctx context.Context, account *gtsmodel.Account, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error)
// StatusesGet fetches a number of statuses (in time descending order) from the given account, filtered by visibility for
// the account given in authed.
- StatusesGet(requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode)
+ StatusesGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode)
// FollowersGet fetches a list of the target account's followers.
- FollowersGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode)
+ FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode)
// FollowingGet fetches a list of the accounts that target account is following.
- FollowingGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode)
+ FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode)
// RelationshipGet returns a relationship model describing the relationship of the targetAccount to the Authed account.
- RelationshipGet(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
+ RelationshipGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
// FollowCreate handles a follow request to an account, either remote or local.
- FollowCreate(requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode)
+ FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode)
// FollowRemove handles the removal of a follow/follow request to an account, either remote or local.
- FollowRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
+ FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
// BlockCreate handles the creation of a block from requestingAccount to targetAccountID, either remote or local.
- BlockCreate(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
+ BlockCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
// BlockRemove handles the removal of a block from requestingAccount to targetAccountID, either remote or local.
- BlockRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
+ BlockRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
// UpdateHeader does the dirty work of checking the header part of an account update form,
// parsing and checking the image, and doing the necessary updates in the database for this to become
// the account's new header image.
- UpdateAvatar(avatar *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error)
+ UpdateAvatar(ctx context.Context, avatar *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error)
// UpdateAvatar does the dirty work of checking the avatar part of an account update form,
// parsing and checking the image, and doing the necessary updates in the database for this to become
// the account's new avatar image.
- UpdateHeader(header *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error)
+ UpdateHeader(ctx context.Context, header *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error)
}
type processor struct {
diff --git a/internal/processing/account/create.go b/internal/processing/account/create.go
index 83e76973d..37c742b45 100644
--- a/internal/processing/account/create.go
+++ b/internal/processing/account/create.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -27,16 +28,24 @@ import (
"github.com/superseriousbusiness/oauth2/v4"
)
-func (p *processor) Create(applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) {
+func (p *processor) Create(ctx context.Context, applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, error) {
l := p.log.WithField("func", "accountCreate")
- if err := p.db.IsEmailAvailable(form.Email); err != nil {
+ emailAvailable, err := p.db.IsEmailAvailable(ctx, form.Email)
+ if err != nil {
return nil, err
}
+ if !emailAvailable {
+ return nil, fmt.Errorf("email address %s in use", form.Email)
+ }
- if err := p.db.IsUsernameAvailable(form.Username); err != nil {
+ usernameAvailable, err := p.db.IsUsernameAvailable(ctx, form.Username)
+ if err != nil {
return nil, err
}
+ if !usernameAvailable {
+ return nil, fmt.Errorf("username %s in use", form.Username)
+ }
// don't store a reason if we don't require one
reason := form.Reason
@@ -45,7 +54,7 @@ func (p *processor) Create(applicationToken oauth2.TokenInfo, application *gtsmo
}
l.Trace("creating new username and account")
- user, err := p.db.NewSignup(form.Username, text.RemoveHTML(reason), p.config.AccountsConfig.RequireApproval, form.Email, form.Password, form.IP, form.Locale, application.ID, false, false)
+ user, err := p.db.NewSignup(ctx, form.Username, text.RemoveHTML(reason), p.config.AccountsConfig.RequireApproval, form.Email, form.Password, form.IP, form.Locale, application.ID, false, false)
if err != nil {
return nil, fmt.Errorf("error creating new signup in the database: %s", err)
}
diff --git a/internal/processing/account/createblock.go b/internal/processing/account/createblock.go
index f10a2efa3..06f82b37d 100644
--- a/internal/processing/account/createblock.go
+++ b/internal/processing/account/createblock.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -29,18 +30,18 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/util"
)
-func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+func (p *processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
// make sure the target account actually exists in our db
- targetAccount, err := p.db.GetAccountByID(targetAccountID)
+ targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err))
}
// if requestingAccount already blocks target account, we don't need to do anything
- if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, false); err != nil {
+ if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, false); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error checking existence of block: %s", err))
} else if blocked {
- return p.RelationshipGet(requestingAccount, targetAccountID)
+ return p.RelationshipGet(ctx, requestingAccount, targetAccountID)
}
// make the block
@@ -57,18 +58,18 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
block.URI = util.GenerateURIForBlock(requestingAccount.Username, p.config.Protocol, p.config.Host, newBlockID)
// whack it in the database
- if err := p.db.Put(block); err != nil {
+ if err := p.db.Put(ctx, block); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error creating block in db: %s", err))
}
// clear any follows or follow requests from the blocked account to the target account -- this is a simple delete
- if err := p.db.DeleteWhere([]db.Where{
+ if err := p.db.DeleteWhere(ctx, []db.Where{
{Key: "account_id", Value: targetAccountID},
{Key: "target_account_id", Value: requestingAccount.ID},
}, &gtsmodel.Follow{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow in db: %s", err))
}
- if err := p.db.DeleteWhere([]db.Where{
+ if err := p.db.DeleteWhere(ctx, []db.Where{
{Key: "account_id", Value: targetAccountID},
{Key: "target_account_id", Value: requestingAccount.ID},
}, &gtsmodel.FollowRequest{}); err != nil {
@@ -82,12 +83,12 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
var frChanged bool
var frURI string
fr := &gtsmodel.FollowRequest{}
- if err := p.db.GetWhere([]db.Where{
+ if err := p.db.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID},
}, fr); err == nil {
frURI = fr.URI
- if err := p.db.DeleteByID(fr.ID, fr); err != nil {
+ if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow request from db: %s", err))
}
frChanged = true
@@ -97,12 +98,12 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
var fChanged bool
var fURI string
f := &gtsmodel.Follow{}
- if err := p.db.GetWhere([]db.Where{
+ if err := p.db.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID},
}, f); err == nil {
fURI = f.URI
- if err := p.db.DeleteByID(f.ID, f); err != nil {
+ if err := p.db.DeleteByID(ctx, f.ID, f); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow from db: %s", err))
}
fChanged = true
@@ -147,5 +148,5 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
TargetAccount: targetAccount,
}
- return p.RelationshipGet(requestingAccount, targetAccountID)
+ return p.RelationshipGet(ctx, requestingAccount, targetAccountID)
}
diff --git a/internal/processing/account/createfollow.go b/internal/processing/account/createfollow.go
index 8c856a50e..a7767afea 100644
--- a/internal/processing/account/createfollow.go
+++ b/internal/processing/account/createfollow.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -29,16 +30,16 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/util"
)
-func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) {
+func (p *processor) FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) {
// if there's a block between the accounts we shouldn't create the request ofc
- if blocked, err := p.db.IsBlocked(requestingAccount.ID, form.ID, true); err != nil {
+ if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, form.ID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
}
// make sure the target account actually exists in our db
- targetAcct, err := p.db.GetAccountByID(form.ID)
+ targetAcct, err := p.db.GetAccountByID(ctx, form.ID)
if err != nil {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: account %s not found in the db: %s", form.ID, err))
@@ -47,19 +48,19 @@ func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apim
}
// check if a follow exists already
- if follows, err := p.db.IsFollowing(requestingAccount, targetAcct); err != nil {
+ if follows, err := p.db.IsFollowing(ctx, requestingAccount, targetAcct); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow in db: %s", err))
} else if follows {
// already follows so just return the relationship
- return p.RelationshipGet(requestingAccount, form.ID)
+ return p.RelationshipGet(ctx, requestingAccount, form.ID)
}
// check if a follow request exists already
- if followRequested, err := p.db.IsFollowRequested(requestingAccount, targetAcct); err != nil {
+ if followRequested, err := p.db.IsFollowRequested(ctx, requestingAccount, targetAcct); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow request in db: %s", err))
} else if followRequested {
// already follow requested so just return the relationship
- return p.RelationshipGet(requestingAccount, form.ID)
+ return p.RelationshipGet(ctx, requestingAccount, form.ID)
}
// make the follow request
@@ -84,17 +85,17 @@ func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apim
}
// whack it in the database
- if err := p.db.Put(fr); err != nil {
+ if err := p.db.Put(ctx, fr); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error creating follow request in db: %s", err))
}
// if it's a local account that's not locked we can just straight up accept the follow request
if !targetAcct.Locked && targetAcct.Domain == "" {
- if _, err := p.db.AcceptFollowRequest(requestingAccount.ID, form.ID); err != nil {
+ if _, err := p.db.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error accepting folow request for local unlocked account: %s", err))
}
// return the new relationship
- return p.RelationshipGet(requestingAccount, form.ID)
+ return p.RelationshipGet(ctx, requestingAccount, form.ID)
}
// otherwise we leave the follow request as it is and we handle the rest of the process asynchronously
@@ -107,5 +108,5 @@ func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apim
}
// return whatever relationship results from this
- return p.RelationshipGet(requestingAccount, form.ID)
+ return p.RelationshipGet(ctx, requestingAccount, form.ID)
}
diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go
index e8840abae..d97af4d2e 100644
--- a/internal/processing/account/delete.go
+++ b/internal/processing/account/delete.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"time"
"github.com/sirupsen/logrus"
@@ -48,7 +49,7 @@ import (
// 16. Delete account's user
// 17. Delete account's timeline
// 18. Delete account itself
-func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
+func (p *processor) Delete(ctx context.Context, account *gtsmodel.Account, origin string) error {
l := p.log.WithFields(logrus.Fields{
"func": "Delete",
"username": account.Username,
@@ -61,22 +62,22 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
if account.Domain == "" {
// see if we can get a user for this account
u := &gtsmodel.User{}
- if err := p.db.GetWhere([]db.Where{{Key: "account_id", Value: account.ID}}, u); err == nil {
+ if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, u); err == nil {
// we got one! select all tokens with the user's ID
tokens := []*oauth.Token{}
- if err := p.db.GetWhere([]db.Where{{Key: "user_id", Value: u.ID}}, &tokens); err == nil {
+ if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: u.ID}}, &tokens); err == nil {
// we have some tokens to delete
for _, t := range tokens {
// delete client(s) associated with this token
- if err := p.db.DeleteByID(t.ClientID, &oauth.Client{}); err != nil {
+ if err := p.db.DeleteByID(ctx, t.ClientID, &oauth.Client{}); err != nil {
l.Errorf("error deleting oauth client: %s", err)
}
// delete application(s) associated with this token
- if err := p.db.DeleteWhere([]db.Where{{Key: "client_id", Value: t.ClientID}}, &gtsmodel.Application{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, &gtsmodel.Application{}); err != nil {
l.Errorf("error deleting application: %s", err)
}
// delete the token itself
- if err := p.db.DeleteByID(t.ID, t); err != nil {
+ if err := p.db.DeleteByID(ctx, t.ID, t); err != nil {
l.Errorf("error deleting oauth token: %s", err)
}
}
@@ -87,12 +88,12 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
// 2. Delete account's blocks
l.Debug("deleting account blocks")
// first delete any blocks that this account created
- if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Block{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Block{}); err != nil {
l.Errorf("error deleting blocks created by account: %s", err)
}
// now delete any blocks that target this account
- if err := p.db.DeleteWhere([]db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Block{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Block{}); err != nil {
l.Errorf("error deleting blocks targeting account: %s", err)
}
@@ -103,12 +104,12 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
// TODO: federate these if necessary
l.Debug("deleting account follow requests")
// first delete any follow requests that this account created
- if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
l.Errorf("error deleting follow requests created by account: %s", err)
}
// now delete any follow requests that target this account
- if err := p.db.DeleteWhere([]db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
l.Errorf("error deleting follow requests targeting account: %s", err)
}
@@ -116,12 +117,12 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
// TODO: federate these if necessary
l.Debug("deleting account follows")
// first delete any follows that this account created
- if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
l.Errorf("error deleting follows created by account: %s", err)
}
// now delete any follows that target this account
- if err := p.db.DeleteWhere([]db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
l.Errorf("error deleting follows targeting account: %s", err)
}
@@ -133,7 +134,7 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
var maxID string
selectStatusesLoop:
for {
- statuses, err := p.db.GetAccountStatuses(account.ID, 20, false, maxID, false, false)
+ statuses, err := p.db.GetAccountStatuses(ctx, account.ID, 20, false, maxID, false, false)
if err != nil {
if err == db.ErrNoEntries {
// no statuses left for this instance so we're done
@@ -157,7 +158,7 @@ selectStatusesLoop:
TargetAccount: account,
}
- if err := p.db.DeleteByID(s.ID, s); err != nil {
+ if err := p.db.DeleteByID(ctx, s.ID, s); err != nil {
if err != db.ErrNoEntries {
// actual error has occurred
l.Errorf("Delete: db error status %s for account %s: %s", s.ID, account.Username, err)
@@ -167,7 +168,7 @@ selectStatusesLoop:
// if there are any boosts of this status, delete them as well
boosts := []*gtsmodel.Status{}
- if err := p.db.GetWhere([]db.Where{{Key: "boost_of_id", Value: s.ID}}, &boosts); err != nil {
+ if err := p.db.GetWhere(ctx, []db.Where{{Key: "boost_of_id", Value: s.ID}}, &boosts); err != nil {
if err != db.ErrNoEntries {
// an actual error has occurred
l.Errorf("Delete: db error selecting boosts of status %s for account %s: %s", s.ID, account.Username, err)
@@ -176,20 +177,24 @@ selectStatusesLoop:
}
for _, b := range boosts {
- oa := &gtsmodel.Account{}
- if err := p.db.GetByID(b.AccountID, oa); err == nil {
-
- l.Debug("putting boost undo in the client api channel")
- p.fromClientAPI <- gtsmodel.FromClientAPI{
- APObjectType: gtsmodel.ActivityStreamsAnnounce,
- APActivityType: gtsmodel.ActivityStreamsUndo,
- GTSModel: s,
- OriginAccount: oa,
- TargetAccount: account,
+ if b.Account == nil {
+ bAccount, err := p.db.GetAccountByID(ctx, b.AccountID)
+ if err != nil {
+ continue
}
+ b.Account = bAccount
}
- if err := p.db.DeleteByID(b.ID, b); err != nil {
+ l.Debug("putting boost undo in the client api channel")
+ p.fromClientAPI <- gtsmodel.FromClientAPI{
+ APObjectType: gtsmodel.ActivityStreamsAnnounce,
+ APActivityType: gtsmodel.ActivityStreamsUndo,
+ GTSModel: s,
+ OriginAccount: b.Account,
+ TargetAccount: account,
+ }
+
+ if err := p.db.DeleteByID(ctx, b.ID, b); err != nil {
if err != db.ErrNoEntries {
// actual error has occurred
l.Errorf("Delete: db error deleting boost with id %s: %s", b.ID, err)
@@ -208,26 +213,26 @@ selectStatusesLoop:
// 10. Delete account's notifications
l.Debug("deleting account notifications")
- if err := p.db.DeleteWhere([]db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil {
l.Errorf("error deleting notifications created by account: %s", err)
}
// 11. Delete account's bookmarks
l.Debug("deleting account bookmarks")
- if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil {
l.Errorf("error deleting bookmarks created by account: %s", err)
}
// 12. Delete account's faves
// TODO: federate these if necessary
l.Debug("deleting account faves")
- if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil {
l.Errorf("error deleting faves created by account: %s", err)
}
// 13. Delete account's mutes
l.Debug("deleting account mutes")
- if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil {
l.Errorf("error deleting status mutes created by account: %s", err)
}
@@ -239,7 +244,7 @@ selectStatusesLoop:
// 16. Delete account's user
l.Debug("deleting account user")
- if err := p.db.DeleteWhere([]db.Where{{Key: "account_id", Value: account.ID}}, &gtsmodel.User{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &gtsmodel.User{}); err != nil {
return err
}
@@ -266,7 +271,8 @@ selectStatusesLoop:
account.SuspendedAt = time.Now()
account.SuspensionOrigin = origin
- if err := p.db.UpdateByID(account.ID, account); err != nil {
+ account, err := p.db.UpdateAccount(ctx, account)
+ if err != nil {
return err
}
diff --git a/internal/processing/account/get.go b/internal/processing/account/get.go
index 3dfc54b51..5f039127c 100644
--- a/internal/processing/account/get.go
+++ b/internal/processing/account/get.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"errors"
"fmt"
@@ -27,9 +28,9 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) {
- targetAccount := &gtsmodel.Account{}
- if err := p.db.GetByID(targetAccountID, targetAccount); err != nil {
+func (p *processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) {
+ targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID)
+ if err != nil {
if err == db.ErrNoEntries {
return nil, errors.New("account not found")
}
@@ -37,9 +38,8 @@ func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID str
}
var blocked bool
- var err error
if requestingAccount != nil {
- blocked, err = p.db.IsBlocked(requestingAccount.ID, targetAccountID, true)
+ blocked, err = p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true)
if err != nil {
return nil, fmt.Errorf("error checking account block: %s", err)
}
@@ -47,7 +47,7 @@ func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID str
var mastoAccount *apimodel.Account
if blocked {
- mastoAccount, err = p.tc.AccountToMastoBlocked(targetAccount)
+ mastoAccount, err = p.tc.AccountToMastoBlocked(ctx, targetAccount)
if err != nil {
return nil, fmt.Errorf("error converting account: %s", err)
}
@@ -56,16 +56,16 @@ func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID str
// last-minute check to make sure we have remote account header/avi cached
if targetAccount.Domain != "" {
- a, err := p.federator.EnrichRemoteAccount(requestingAccount.Username, targetAccount)
+ a, err := p.federator.EnrichRemoteAccount(ctx, requestingAccount.Username, targetAccount)
if err == nil {
targetAccount = a
}
}
if requestingAccount != nil && targetAccount.ID == requestingAccount.ID {
- mastoAccount, err = p.tc.AccountToMastoSensitive(targetAccount)
+ mastoAccount, err = p.tc.AccountToMastoSensitive(ctx, targetAccount)
} else {
- mastoAccount, err = p.tc.AccountToMastoPublic(targetAccount)
+ mastoAccount, err = p.tc.AccountToMastoPublic(ctx, targetAccount)
}
if err != nil {
return nil, fmt.Errorf("error converting account: %s", err)
diff --git a/internal/processing/account/getfollowers.go b/internal/processing/account/getfollowers.go
index 4f66b40ee..517467085 100644
--- a/internal/processing/account/getfollowers.go
+++ b/internal/processing/account/getfollowers.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -27,15 +28,15 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
- if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil {
+func (p *processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
+ if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
}
accounts := []apimodel.Account{}
- follows, err := p.db.GetAccountFollowedBy(targetAccountID, false)
+ follows, err := p.db.GetAccountFollowedBy(ctx, targetAccountID, false)
if err != nil {
if err == db.ErrNoEntries {
return accounts, nil
@@ -44,7 +45,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco
}
for _, f := range follows {
- blocked, err := p.db.IsBlocked(requestingAccount.ID, f.AccountID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -53,7 +54,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco
}
if f.Account == nil {
- a, err := p.db.GetAccountByID(f.AccountID)
+ a, err := p.db.GetAccountByID(ctx, f.AccountID)
if err != nil {
if err == db.ErrNoEntries {
continue
@@ -63,7 +64,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco
f.Account = a
}
- account, err := p.tc.AccountToMastoPublic(f.Account)
+ account, err := p.tc.AccountToMastoPublic(ctx, f.Account)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/account/getfollowing.go b/internal/processing/account/getfollowing.go
index c7fb426f9..543213f90 100644
--- a/internal/processing/account/getfollowing.go
+++ b/internal/processing/account/getfollowing.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -27,15 +28,15 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
- if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil {
+func (p *processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
+ if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
}
accounts := []apimodel.Account{}
- follows, err := p.db.GetAccountFollows(targetAccountID)
+ follows, err := p.db.GetAccountFollows(ctx, targetAccountID)
if err != nil {
if err == db.ErrNoEntries {
return accounts, nil
@@ -44,7 +45,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco
}
for _, f := range follows {
- blocked, err := p.db.IsBlocked(requestingAccount.ID, f.AccountID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -53,7 +54,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco
}
if f.TargetAccount == nil {
- a, err := p.db.GetAccountByID(f.TargetAccountID)
+ a, err := p.db.GetAccountByID(ctx, f.TargetAccountID)
if err != nil {
if err == db.ErrNoEntries {
continue
@@ -63,7 +64,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco
f.TargetAccount = a
}
- account, err := p.tc.AccountToMastoPublic(f.TargetAccount)
+ account, err := p.tc.AccountToMastoPublic(ctx, f.TargetAccount)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/account/getrelationship.go b/internal/processing/account/getrelationship.go
index a0a93a4c2..ebfd9b479 100644
--- a/internal/processing/account/getrelationship.go
+++ b/internal/processing/account/getrelationship.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"errors"
"fmt"
@@ -27,17 +28,17 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) RelationshipGet(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+func (p *processor) RelationshipGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
if requestingAccount == nil {
return nil, gtserror.NewErrorForbidden(errors.New("not authed"))
}
- gtsR, err := p.db.GetRelationship(requestingAccount.ID, targetAccountID)
+ gtsR, err := p.db.GetRelationship(ctx, requestingAccount.ID, targetAccountID)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting relationship: %s", err))
}
- r, err := p.tc.RelationshipToMasto(gtsR)
+ r, err := p.tc.RelationshipToMasto(ctx, gtsR)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting relationship: %s", err))
}
diff --git a/internal/processing/account/getstatuses.go b/internal/processing/account/getstatuses.go
index dc21e7006..dc157e43c 100644
--- a/internal/processing/account/getstatuses.go
+++ b/internal/processing/account/getstatuses.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -27,8 +28,8 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) {
- if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil {
+func (p *processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) {
+ if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
@@ -36,7 +37,7 @@ func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccou
apiStatuses := []apimodel.Status{}
- statuses, err := p.db.GetAccountStatuses(targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
+ statuses, err := p.db.GetAccountStatuses(ctx, targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
if err != nil {
if err == db.ErrNoEntries {
return apiStatuses, nil
@@ -45,12 +46,12 @@ func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccou
}
for _, s := range statuses {
- visible, err := p.filter.StatusVisible(s, requestingAccount)
+ visible, err := p.filter.StatusVisible(ctx, s, requestingAccount)
if err != nil || !visible {
continue
}
- apiStatus, err := p.tc.StatusToMasto(s, requestingAccount)
+ apiStatus, err := p.tc.StatusToMasto(ctx, s, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status to masto: %s", err))
}
diff --git a/internal/processing/account/removeblock.go b/internal/processing/account/removeblock.go
index 7c1f2bc17..7e3d78076 100644
--- a/internal/processing/account/removeblock.go
+++ b/internal/processing/account/removeblock.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -27,9 +28,9 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+func (p *processor) BlockRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
// make sure the target account actually exists in our db
- targetAccount, err := p.db.GetAccountByID(targetAccountID)
+ targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err))
}
@@ -37,13 +38,13 @@ func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccou
// check if a block exists, and remove it if it does (storing the URI for later)
var blockChanged bool
block := &gtsmodel.Block{}
- if err := p.db.GetWhere([]db.Where{
+ if err := p.db.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID},
}, block); err == nil {
block.Account = requestingAccount
block.TargetAccount = targetAccount
- if err := p.db.DeleteByID(block.ID, &gtsmodel.Block{}); err != nil {
+ if err := p.db.DeleteByID(ctx, block.ID, &gtsmodel.Block{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockRemove: error removing block from db: %s", err))
}
blockChanged = true
@@ -61,5 +62,5 @@ func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccou
}
// return whatever relationship results from all this
- return p.RelationshipGet(requestingAccount, targetAccountID)
+ return p.RelationshipGet(ctx, requestingAccount, targetAccountID)
}
diff --git a/internal/processing/account/removefollow.go b/internal/processing/account/removefollow.go
index 6646d694e..6186c550f 100644
--- a/internal/processing/account/removefollow.go
+++ b/internal/processing/account/removefollow.go
@@ -19,6 +19,7 @@
package account
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -27,9 +28,9 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
+func (p *processor) FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
// if there's a block between the accounts we shouldn't do anything
- blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -38,8 +39,8 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco
}
// make sure the target account actually exists in our db
- targetAcct := &gtsmodel.Account{}
- if err := p.db.GetByID(targetAccountID, targetAcct); err != nil {
+ targetAcct, err := p.db.GetAccountByID(ctx, targetAccountID)
+ if err != nil {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("AccountFollowRemove: account %s not found in the db: %s", targetAccountID, err))
}
@@ -49,12 +50,12 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco
var frChanged bool
var frURI string
fr := &gtsmodel.FollowRequest{}
- if err := p.db.GetWhere([]db.Where{
+ if err := p.db.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID},
}, fr); err == nil {
frURI = fr.URI
- if err := p.db.DeleteByID(fr.ID, fr); err != nil {
+ if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow request from db: %s", err))
}
frChanged = true
@@ -64,12 +65,12 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco
var fChanged bool
var fURI string
f := &gtsmodel.Follow{}
- if err := p.db.GetWhere([]db.Where{
+ if err := p.db.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID},
}, f); err == nil {
fURI = f.URI
- if err := p.db.DeleteByID(f.ID, f); err != nil {
+ if err := p.db.DeleteByID(ctx, f.ID, f); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow from db: %s", err))
}
fChanged = true
@@ -106,5 +107,5 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco
}
// return whatever relationship results from all this
- return p.RelationshipGet(requestingAccount, targetAccountID)
+ return p.RelationshipGet(ctx, requestingAccount, targetAccountID)
}
diff --git a/internal/processing/account/update.go b/internal/processing/account/update.go
index df842bacd..99ccbf5a0 100644
--- a/internal/processing/account/update.go
+++ b/internal/processing/account/update.go
@@ -20,6 +20,7 @@ package account
import (
"bytes"
+ "context"
"errors"
"fmt"
"io"
@@ -32,17 +33,17 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/util"
)
-func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) {
+func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error) {
l := p.log.WithField("func", "AccountUpdate")
if form.Discoverable != nil {
- if err := p.db.UpdateOneByID(account.ID, "discoverable", *form.Discoverable, &gtsmodel.Account{}); err != nil {
+ if err := p.db.UpdateOneByID(ctx, account.ID, "discoverable", *form.Discoverable, &gtsmodel.Account{}); err != nil {
return nil, fmt.Errorf("error updating discoverable: %s", err)
}
}
if form.Bot != nil {
- if err := p.db.UpdateOneByID(account.ID, "bot", *form.Bot, &gtsmodel.Account{}); err != nil {
+ if err := p.db.UpdateOneByID(ctx, account.ID, "bot", *form.Bot, &gtsmodel.Account{}); err != nil {
return nil, fmt.Errorf("error updating bot: %s", err)
}
}
@@ -52,7 +53,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede
return nil, err
}
displayName := text.RemoveHTML(*form.DisplayName) // no html allowed in display name
- if err := p.db.UpdateOneByID(account.ID, "display_name", displayName, &gtsmodel.Account{}); err != nil {
+ if err := p.db.UpdateOneByID(ctx, account.ID, "display_name", displayName, &gtsmodel.Account{}); err != nil {
return nil, err
}
}
@@ -62,13 +63,13 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede
return nil, err
}
note := text.SanitizeHTML(*form.Note) // html OK in note but sanitize it
- if err := p.db.UpdateOneByID(account.ID, "note", note, &gtsmodel.Account{}); err != nil {
+ if err := p.db.UpdateOneByID(ctx, account.ID, "note", note, &gtsmodel.Account{}); err != nil {
return nil, err
}
}
if form.Avatar != nil && form.Avatar.Size != 0 {
- avatarInfo, err := p.UpdateAvatar(form.Avatar, account.ID)
+ avatarInfo, err := p.UpdateAvatar(ctx, form.Avatar, account.ID)
if err != nil {
return nil, err
}
@@ -76,7 +77,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede
}
if form.Header != nil && form.Header.Size != 0 {
- headerInfo, err := p.UpdateHeader(form.Header, account.ID)
+ headerInfo, err := p.UpdateHeader(ctx, form.Header, account.ID)
if err != nil {
return nil, err
}
@@ -84,7 +85,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede
}
if form.Locked != nil {
- if err := p.db.UpdateOneByID(account.ID, "locked", *form.Locked, &gtsmodel.Account{}); err != nil {
+ if err := p.db.UpdateOneByID(ctx, account.ID, "locked", *form.Locked, &gtsmodel.Account{}); err != nil {
return nil, err
}
}
@@ -94,13 +95,13 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede
if err := util.ValidateLanguage(*form.Source.Language); err != nil {
return nil, err
}
- if err := p.db.UpdateOneByID(account.ID, "language", *form.Source.Language, &gtsmodel.Account{}); err != nil {
+ if err := p.db.UpdateOneByID(ctx, account.ID, "language", *form.Source.Language, &gtsmodel.Account{}); err != nil {
return nil, err
}
}
if form.Source.Sensitive != nil {
- if err := p.db.UpdateOneByID(account.ID, "locked", *form.Locked, &gtsmodel.Account{}); err != nil {
+ if err := p.db.UpdateOneByID(ctx, account.ID, "locked", *form.Locked, &gtsmodel.Account{}); err != nil {
return nil, err
}
}
@@ -109,15 +110,15 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede
if err := util.ValidatePrivacy(*form.Source.Privacy); err != nil {
return nil, err
}
- if err := p.db.UpdateOneByID(account.ID, "privacy", *form.Source.Privacy, &gtsmodel.Account{}); err != nil {
+ if err := p.db.UpdateOneByID(ctx, account.ID, "privacy", *form.Source.Privacy, &gtsmodel.Account{}); err != nil {
return nil, err
}
}
}
// fetch the account with all updated values set
- updatedAccount := &gtsmodel.Account{}
- if err := p.db.GetByID(account.ID, updatedAccount); err != nil {
+ updatedAccount, err := p.db.GetAccountByID(ctx, account.ID)
+ if err != nil {
return nil, fmt.Errorf("could not fetch updated account %s: %s", account.ID, err)
}
@@ -128,7 +129,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede
OriginAccount: updatedAccount,
}
- acctSensitive, err := p.tc.AccountToMastoSensitive(updatedAccount)
+ acctSensitive, err := p.tc.AccountToMastoSensitive(ctx, updatedAccount)
if err != nil {
return nil, fmt.Errorf("could not convert account into mastosensitive account: %s", err)
}
@@ -138,7 +139,7 @@ func (p *processor) Update(account *gtsmodel.Account, form *apimodel.UpdateCrede
// UpdateAvatar does the dirty work of checking the avatar part of an account update form,
// parsing and checking the image, and doing the necessary updates in the database for this to become
// the account's new avatar image.
-func (p *processor) UpdateAvatar(avatar *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) {
+func (p *processor) UpdateAvatar(ctx context.Context, avatar *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) {
var err error
if int(avatar.Size) > p.config.MediaConfig.MaxImageSize {
err = fmt.Errorf("avatar with size %d exceeded max image size of %d bytes", avatar.Size, p.config.MediaConfig.MaxImageSize)
@@ -160,7 +161,7 @@ func (p *processor) UpdateAvatar(avatar *multipart.FileHeader, accountID string)
}
// do the setting
- avatarInfo, err := p.mediaHandler.ProcessHeaderOrAvatar(buf.Bytes(), accountID, media.Avatar, "")
+ avatarInfo, err := p.mediaHandler.ProcessHeaderOrAvatar(ctx, buf.Bytes(), accountID, media.Avatar, "")
if err != nil {
return nil, fmt.Errorf("error processing avatar: %s", err)
}
@@ -171,7 +172,7 @@ func (p *processor) UpdateAvatar(avatar *multipart.FileHeader, accountID string)
// UpdateHeader does the dirty work of checking the header part of an account update form,
// parsing and checking the image, and doing the necessary updates in the database for this to become
// the account's new header image.
-func (p *processor) UpdateHeader(header *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) {
+func (p *processor) UpdateHeader(ctx context.Context, header *multipart.FileHeader, accountID string) (*gtsmodel.MediaAttachment, error) {
var err error
if int(header.Size) > p.config.MediaConfig.MaxImageSize {
err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", header.Size, p.config.MediaConfig.MaxImageSize)
@@ -193,7 +194,7 @@ func (p *processor) UpdateHeader(header *multipart.FileHeader, accountID string)
}
// do the setting
- headerInfo, err := p.mediaHandler.ProcessHeaderOrAvatar(buf.Bytes(), accountID, media.Header, "")
+ headerInfo, err := p.mediaHandler.ProcessHeaderOrAvatar(ctx, buf.Bytes(), accountID, media.Header, "")
if err != nil {
return nil, fmt.Errorf("error processing header: %s", err)
}
diff --git a/internal/processing/admin.go b/internal/processing/admin.go
index 9a38f5ec1..48faee986 100644
--- a/internal/processing/admin.go
+++ b/internal/processing/admin.go
@@ -19,31 +19,33 @@
package processing
import (
+ "context"
+
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
-func (p *processor) AdminEmojiCreate(authed *oauth.Auth, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) {
- return p.adminProcessor.EmojiCreate(authed.Account, authed.User, form)
+func (p *processor) AdminEmojiCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) {
+ return p.adminProcessor.EmojiCreate(ctx, authed.Account, authed.User, form)
}
-func (p *processor) AdminDomainBlockCreate(authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) (*apimodel.DomainBlock, gtserror.WithCode) {
- return p.adminProcessor.DomainBlockCreate(authed.Account, form.Domain, form.Obfuscate, form.PublicComment, form.PrivateComment, "")
+func (p *processor) AdminDomainBlockCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) (*apimodel.DomainBlock, gtserror.WithCode) {
+ return p.adminProcessor.DomainBlockCreate(ctx, authed.Account, form.Domain, form.Obfuscate, form.PublicComment, form.PrivateComment, "")
}
-func (p *processor) AdminDomainBlocksImport(authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) ([]*apimodel.DomainBlock, gtserror.WithCode) {
- return p.adminProcessor.DomainBlocksImport(authed.Account, form.Domains)
+func (p *processor) AdminDomainBlocksImport(ctx context.Context, authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) ([]*apimodel.DomainBlock, gtserror.WithCode) {
+ return p.adminProcessor.DomainBlocksImport(ctx, authed.Account, form.Domains)
}
-func (p *processor) AdminDomainBlocksGet(authed *oauth.Auth, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) {
- return p.adminProcessor.DomainBlocksGet(authed.Account, export)
+func (p *processor) AdminDomainBlocksGet(ctx context.Context, authed *oauth.Auth, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) {
+ return p.adminProcessor.DomainBlocksGet(ctx, authed.Account, export)
}
-func (p *processor) AdminDomainBlockGet(authed *oauth.Auth, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) {
- return p.adminProcessor.DomainBlockGet(authed.Account, id, export)
+func (p *processor) AdminDomainBlockGet(ctx context.Context, authed *oauth.Auth, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) {
+ return p.adminProcessor.DomainBlockGet(ctx, authed.Account, id, export)
}
-func (p *processor) AdminDomainBlockDelete(authed *oauth.Auth, id string) (*apimodel.DomainBlock, gtserror.WithCode) {
- return p.adminProcessor.DomainBlockDelete(authed.Account, id)
+func (p *processor) AdminDomainBlockDelete(ctx context.Context, authed *oauth.Auth, id string) (*apimodel.DomainBlock, gtserror.WithCode) {
+ return p.adminProcessor.DomainBlockDelete(ctx, authed.Account, id)
}
diff --git a/internal/processing/admin/admin.go b/internal/processing/admin/admin.go
index fd63d8a10..de288811b 100644
--- a/internal/processing/admin/admin.go
+++ b/internal/processing/admin/admin.go
@@ -19,6 +19,7 @@
package admin
import (
+ "context"
"mime/multipart"
"github.com/sirupsen/logrus"
@@ -33,12 +34,12 @@ import (
// Processor wraps a bunch of functions for processing admin actions.
type Processor interface {
- DomainBlockCreate(account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode)
- DomainBlocksImport(account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode)
- DomainBlocksGet(account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode)
- DomainBlockGet(account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode)
- DomainBlockDelete(account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode)
- EmojiCreate(account *gtsmodel.Account, user *gtsmodel.User, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error)
+ DomainBlockCreate(ctx context.Context, account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode)
+ DomainBlocksImport(ctx context.Context, account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode)
+ DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode)
+ DomainBlockGet(ctx context.Context, account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode)
+ DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode)
+ EmojiCreate(ctx context.Context, account *gtsmodel.Account, user *gtsmodel.User, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error)
}
type processor struct {
diff --git a/internal/processing/admin/createdomainblock.go b/internal/processing/admin/createdomainblock.go
index 624f632dc..a34c03a44 100644
--- a/internal/processing/admin/createdomainblock.go
+++ b/internal/processing/admin/createdomainblock.go
@@ -19,6 +19,7 @@
package admin
import (
+ "context"
"fmt"
"time"
@@ -31,10 +32,10 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/text"
)
-func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode) {
+func (p *processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode) {
// first check if we already have a block -- if err == nil we already had a block so we can skip a whole lot of work
domainBlock := &gtsmodel.DomainBlock{}
- err := p.db.GetWhere([]db.Where{{Key: "domain", Value: domain, CaseInsensitive: true}}, domainBlock)
+ err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: domain, CaseInsensitive: true}}, domainBlock)
if err != nil {
if err != db.ErrNoEntries {
// something went wrong in the DB
@@ -59,7 +60,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string,
}
// put the new block in the database
- if err := p.db.Put(domainBlock); err != nil {
+ if err := p.db.Put(ctx, domainBlock); err != nil {
if err != db.ErrNoEntries {
// there's a real error creating the block
return nil, gtserror.NewErrorInternalError(fmt.Errorf("DomainBlockCreate: db error putting new domain block %s: %s", domain, err))
@@ -67,10 +68,10 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string,
}
// process the side effects of the domain block asynchronously since it might take a while
- go p.initiateDomainBlockSideEffects(account, domainBlock) // TODO: add this to a queuing system so it can retry/resume
+ go p.initiateDomainBlockSideEffects(ctx, account, domainBlock) // TODO: add this to a queuing system so it can retry/resume
}
- mastoDomainBlock, err := p.tc.DomainBlockToMasto(domainBlock, false)
+ mastoDomainBlock, err := p.tc.DomainBlockToMasto(ctx, domainBlock, false)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("DomainBlockCreate: error converting domain block to frontend/masto representation %s: %s", domain, err))
}
@@ -83,7 +84,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string,
// 1. Strip most info away from the instance entry for the domain.
// 2. Delete the instance account for that instance if it exists.
// 3. Select all accounts from this instance and pass them through the delete functionality of the processor.
-func (p *processor) initiateDomainBlockSideEffects(account *gtsmodel.Account, block *gtsmodel.DomainBlock) {
+func (p *processor) initiateDomainBlockSideEffects(ctx context.Context, account *gtsmodel.Account, block *gtsmodel.DomainBlock) {
l := p.log.WithFields(logrus.Fields{
"func": "domainBlockProcessSideEffects",
"domain": block.Domain,
@@ -93,7 +94,7 @@ func (p *processor) initiateDomainBlockSideEffects(account *gtsmodel.Account, bl
// if we have an instance entry for this domain, update it with the new block ID and clear all fields
instance := &gtsmodel.Instance{}
- if err := p.db.GetWhere([]db.Where{{Key: "domain", Value: block.Domain, CaseInsensitive: true}}, instance); err == nil {
+ if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain, CaseInsensitive: true}}, instance); err == nil {
instance.Title = ""
instance.UpdatedAt = time.Now()
instance.SuspendedAt = time.Now()
@@ -105,14 +106,14 @@ func (p *processor) initiateDomainBlockSideEffects(account *gtsmodel.Account, bl
instance.ContactAccountUsername = ""
instance.ContactAccountID = ""
instance.Version = ""
- if err := p.db.UpdateByID(instance.ID, instance); err != nil {
+ if err := p.db.UpdateByID(ctx, instance.ID, instance); err != nil {
l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err)
}
l.Debug("domainBlockProcessSideEffects: instance entry updated")
}
// if we have an instance account for this instance, delete it
- if err := p.db.DeleteWhere([]db.Where{{Key: "username", Value: block.Domain, CaseInsensitive: true}}, &gtsmodel.Account{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "username", Value: block.Domain, CaseInsensitive: true}}, &gtsmodel.Account{}); err != nil {
l.Errorf("domainBlockProcessSideEffects: db error removing instance account: %s", err)
}
@@ -123,7 +124,7 @@ func (p *processor) initiateDomainBlockSideEffects(account *gtsmodel.Account, bl
selectAccountsLoop:
for {
- accounts, err := p.db.GetInstanceAccounts(block.Domain, maxID, limit)
+ accounts, err := p.db.GetInstanceAccounts(ctx, block.Domain, maxID, limit)
if err != nil {
if err == db.ErrNoEntries {
// no accounts left for this instance so we're done
diff --git a/internal/processing/admin/deletedomainblock.go b/internal/processing/admin/deletedomainblock.go
index edb0a58f9..2563b557d 100644
--- a/internal/processing/admin/deletedomainblock.go
+++ b/internal/processing/admin/deletedomainblock.go
@@ -19,6 +19,7 @@
package admin
import (
+ "context"
"fmt"
"time"
@@ -28,10 +29,10 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) DomainBlockDelete(account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) {
+func (p *processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) {
domainBlock := &gtsmodel.DomainBlock{}
- if err := p.db.GetByID(id, domainBlock); err != nil {
+ if err := p.db.GetByID(ctx, id, domainBlock); err != nil {
if err != db.ErrNoEntries {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
@@ -41,39 +42,39 @@ func (p *processor) DomainBlockDelete(account *gtsmodel.Account, id string) (*ap
}
// prepare the domain block to return
- mastoDomainBlock, err := p.tc.DomainBlockToMasto(domainBlock, false)
+ mastoDomainBlock, err := p.tc.DomainBlockToMasto(ctx, domainBlock, false)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
// delete the domain block
- if err := p.db.DeleteByID(id, domainBlock); err != nil {
+ if err := p.db.DeleteByID(ctx, id, domainBlock); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
// remove the domain block reference from the instance, if we have an entry for it
i := &gtsmodel.Instance{}
- if err := p.db.GetWhere([]db.Where{
+ if err := p.db.GetWhere(ctx, []db.Where{
{Key: "domain", Value: domainBlock.Domain, CaseInsensitive: true},
{Key: "domain_block_id", Value: id},
}, i); err == nil {
i.SuspendedAt = time.Time{}
i.DomainBlockID = ""
- if err := p.db.UpdateByID(i.ID, i); err != nil {
+ if err := p.db.UpdateByID(ctx, i.ID, i); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err))
}
}
// unsuspend all accounts whose suspension origin was this domain block
// 1. remove the 'suspended_at' entry from their accounts
- if err := p.db.UpdateWhere([]db.Where{
+ if err := p.db.UpdateWhere(ctx, []db.Where{
{Key: "suspension_origin", Value: domainBlock.ID},
}, "suspended_at", nil, &[]*gtsmodel.Account{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspended_at from accounts: %s", err))
}
// 2. remove the 'suspension_origin' entry from their accounts
- if err := p.db.UpdateWhere([]db.Where{
+ if err := p.db.UpdateWhere(ctx, []db.Where{
{Key: "suspension_origin", Value: domainBlock.ID},
}, "suspension_origin", nil, &[]*gtsmodel.Account{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspension_origin from accounts: %s", err))
diff --git a/internal/processing/admin/emoji.go b/internal/processing/admin/emoji.go
index f19e173b5..f56bde8e0 100644
--- a/internal/processing/admin/emoji.go
+++ b/internal/processing/admin/emoji.go
@@ -20,6 +20,7 @@ package admin
import (
"bytes"
+ "context"
"errors"
"fmt"
"io"
@@ -29,7 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/id"
)
-func (p *processor) EmojiCreate(account *gtsmodel.Account, user *gtsmodel.User, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) {
+func (p *processor) EmojiCreate(ctx context.Context, account *gtsmodel.Account, user *gtsmodel.User, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error) {
if user.Admin {
return nil, fmt.Errorf("user %s not an admin", user.ID)
}
@@ -49,7 +50,7 @@ func (p *processor) EmojiCreate(account *gtsmodel.Account, user *gtsmodel.User,
}
// allow the mediaHandler to work its magic of processing the emoji bytes, and putting them in whatever storage backend we're using
- emoji, err := p.mediaHandler.ProcessLocalEmoji(buf.Bytes(), form.Shortcode)
+ emoji, err := p.mediaHandler.ProcessLocalEmoji(ctx, buf.Bytes(), form.Shortcode)
if err != nil {
return nil, fmt.Errorf("error reading emoji: %s", err)
}
@@ -60,12 +61,12 @@ func (p *processor) EmojiCreate(account *gtsmodel.Account, user *gtsmodel.User,
}
emoji.ID = emojiID
- mastoEmoji, err := p.tc.EmojiToMasto(emoji)
+ mastoEmoji, err := p.tc.EmojiToMasto(ctx, emoji)
if err != nil {
return nil, fmt.Errorf("error converting emoji to mastotype: %s", err)
}
- if err := p.db.Put(emoji); err != nil {
+ if err := p.db.Put(ctx, emoji); err != nil {
return nil, fmt.Errorf("database error while processing emoji: %s", err)
}
diff --git a/internal/processing/admin/getdomainblock.go b/internal/processing/admin/getdomainblock.go
index f74010627..19bc9fe09 100644
--- a/internal/processing/admin/getdomainblock.go
+++ b/internal/processing/admin/getdomainblock.go
@@ -19,6 +19,7 @@
package admin
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -27,10 +28,10 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) DomainBlockGet(account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) {
+func (p *processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) {
domainBlock := &gtsmodel.DomainBlock{}
- if err := p.db.GetByID(id, domainBlock); err != nil {
+ if err := p.db.GetByID(ctx, id, domainBlock); err != nil {
if err != db.ErrNoEntries {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
@@ -39,7 +40,7 @@ func (p *processor) DomainBlockGet(account *gtsmodel.Account, id string, export
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no entry for ID %s", id))
}
- mastoDomainBlock, err := p.tc.DomainBlockToMasto(domainBlock, export)
+ mastoDomainBlock, err := p.tc.DomainBlockToMasto(ctx, domainBlock, export)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/admin/getdomainblocks.go b/internal/processing/admin/getdomainblocks.go
index f827d03fc..0ec33cfff 100644
--- a/internal/processing/admin/getdomainblocks.go
+++ b/internal/processing/admin/getdomainblocks.go
@@ -19,16 +19,18 @@
package admin
import (
+ "context"
+
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) DomainBlocksGet(account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) {
+func (p *processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) {
domainBlocks := []*gtsmodel.DomainBlock{}
- if err := p.db.GetAll(&domainBlocks); err != nil {
+ if err := p.db.GetAll(ctx, &domainBlocks); err != nil {
if err != db.ErrNoEntries {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
@@ -37,7 +39,7 @@ func (p *processor) DomainBlocksGet(account *gtsmodel.Account, export bool) ([]*
mastoDomainBlocks := []*apimodel.DomainBlock{}
for _, b := range domainBlocks {
- mastoDomainBlock, err := p.tc.DomainBlockToMasto(b, export)
+ mastoDomainBlock, err := p.tc.DomainBlockToMasto(ctx, b, export)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/admin/importdomainblocks.go b/internal/processing/admin/importdomainblocks.go
index ab171b712..66326bd62 100644
--- a/internal/processing/admin/importdomainblocks.go
+++ b/internal/processing/admin/importdomainblocks.go
@@ -20,6 +20,7 @@ package admin
import (
"bytes"
+ "context"
"encoding/json"
"errors"
"fmt"
@@ -32,7 +33,7 @@ import (
)
// DomainBlocksImport handles the import of a bunch of domain blocks at once, by calling the DomainBlockCreate function for each domain in the provided file.
-func (p *processor) DomainBlocksImport(account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode) {
+func (p *processor) DomainBlocksImport(ctx context.Context, account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode) {
f, err := domains.Open()
if err != nil {
@@ -54,7 +55,7 @@ func (p *processor) DomainBlocksImport(account *gtsmodel.Account, domains *multi
blocks := []*apimodel.DomainBlock{}
for _, d := range d {
- block, err := p.DomainBlockCreate(account, d.Domain, false, d.PublicComment, "", "")
+ block, err := p.DomainBlockCreate(ctx, account, d.Domain, false, d.PublicComment, "", "")
if err != nil {
return nil, err
diff --git a/internal/processing/app.go b/internal/processing/app.go
index 7da5344ac..4f805572b 100644
--- a/internal/processing/app.go
+++ b/internal/processing/app.go
@@ -19,6 +19,8 @@
package processing
import (
+ "context"
+
"github.com/google/uuid"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -26,7 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
-func (p *processor) AppCreate(authed *oauth.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, error) {
+func (p *processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, error) {
// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/
var scopes string
if form.Scopes == "" {
@@ -61,7 +63,7 @@ func (p *processor) AppCreate(authed *oauth.Auth, form *apimodel.ApplicationCrea
}
// chuck it in the db
- if err := p.db.Put(app); err != nil {
+ if err := p.db.Put(ctx, app); err != nil {
return nil, err
}
@@ -74,11 +76,11 @@ func (p *processor) AppCreate(authed *oauth.Auth, form *apimodel.ApplicationCrea
}
// chuck it in the db
- if err := p.db.Put(oc); err != nil {
+ if err := p.db.Put(ctx, oc); err != nil {
return nil, err
}
- mastoApp, err := p.tc.AppToMastoSensitive(app)
+ mastoApp, err := p.tc.AppToMastoSensitive(ctx, app)
if err != nil {
return nil, err
}
diff --git a/internal/processing/blocks.go b/internal/processing/blocks.go
index 809cbde8e..7c8371989 100644
--- a/internal/processing/blocks.go
+++ b/internal/processing/blocks.go
@@ -19,6 +19,7 @@
package processing
import (
+ "context"
"fmt"
"net/url"
@@ -28,8 +29,8 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
-func (p *processor) BlocksGet(authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) {
- accounts, nextMaxID, prevMinID, err := p.db.GetAccountBlocks(authed.Account.ID, maxID, sinceID, limit)
+func (p *processor) BlocksGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) {
+ accounts, nextMaxID, prevMinID, err := p.db.GetAccountBlocks(ctx, authed.Account.ID, maxID, sinceID, limit)
if err != nil {
if err == db.ErrNoEntries {
// there are just no entries
@@ -43,7 +44,7 @@ func (p *processor) BlocksGet(authed *oauth.Auth, maxID string, sinceID string,
apiAccounts := []*apimodel.Account{}
for _, a := range accounts {
- apiAccount, err := p.tc.AccountToMastoBlocked(a)
+ apiAccount, err := p.tc.AccountToMastoBlocked(ctx, a)
if err != nil {
continue
}
diff --git a/internal/processing/federation.go b/internal/processing/federation.go
index cea14b4de..352a6ddc2 100644
--- a/internal/processing/federation.go
+++ b/internal/processing/federation.go
@@ -36,7 +36,7 @@ import (
func (p *processor) GetFediUser(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) {
// get the account the request is referring to
- requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername)
+ requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
}
@@ -44,7 +44,7 @@ func (p *processor) GetFediUser(ctx context.Context, requestedUsername string, r
var requestedPerson vocab.ActivityStreamsPerson
if util.IsPublicKeyPath(requestURL) {
// if it's a public key path, we don't need to authenticate but we'll only serve the bare minimum user profile needed for the public key
- requestedPerson, err = p.tc.AccountToASMinimal(requestedAccount)
+ requestedPerson, err = p.tc.AccountToASMinimal(ctx, requestedAccount)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -56,13 +56,13 @@ func (p *processor) GetFediUser(ctx context.Context, requestedUsername string, r
}
// if we're not already handshaking/dereferencing a remote account, dereference it now
- if !p.federator.Handshaking(requestedUsername, requestingAccountURI) {
- requestingAccount, _, err := p.federator.GetRemoteAccount(requestedUsername, requestingAccountURI, false)
+ if !p.federator.Handshaking(ctx, requestedUsername, requestingAccountURI) {
+ requestingAccount, _, err := p.federator.GetRemoteAccount(ctx, requestedUsername, requestingAccountURI, false)
if err != nil {
return nil, gtserror.NewErrorNotAuthorized(err)
}
- blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -72,7 +72,7 @@ func (p *processor) GetFediUser(ctx context.Context, requestedUsername string, r
}
}
- requestedPerson, err = p.tc.AccountToAS(requestedAccount)
+ requestedPerson, err = p.tc.AccountToAS(ctx, requestedAccount)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -90,7 +90,7 @@ func (p *processor) GetFediUser(ctx context.Context, requestedUsername string, r
func (p *processor) GetFediFollowers(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) {
// get the account the request is referring to
- requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername)
+ requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
}
@@ -101,12 +101,12 @@ func (p *processor) GetFediFollowers(ctx context.Context, requestedUsername stri
return nil, gtserror.NewErrorNotAuthorized(errors.New("not authorized"), "not authorized")
}
- requestingAccount, _, err := p.federator.GetRemoteAccount(requestedUsername, requestingAccountURI, false)
+ requestingAccount, _, err := p.federator.GetRemoteAccount(ctx, requestedUsername, requestingAccountURI, false)
if err != nil {
return nil, gtserror.NewErrorNotAuthorized(err)
}
- blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -135,7 +135,7 @@ func (p *processor) GetFediFollowers(ctx context.Context, requestedUsername stri
func (p *processor) GetFediFollowing(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) {
// get the account the request is referring to
- requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername)
+ requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
}
@@ -146,12 +146,12 @@ func (p *processor) GetFediFollowing(ctx context.Context, requestedUsername stri
return nil, gtserror.NewErrorNotAuthorized(errors.New("not authorized"), "not authorized")
}
- requestingAccount, _, err := p.federator.GetRemoteAccount(requestedUsername, requestingAccountURI, false)
+ requestingAccount, _, err := p.federator.GetRemoteAccount(ctx, requestedUsername, requestingAccountURI, false)
if err != nil {
return nil, gtserror.NewErrorNotAuthorized(err)
}
- blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -180,7 +180,7 @@ func (p *processor) GetFediFollowing(ctx context.Context, requestedUsername stri
func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string, requestedStatusID string, requestURL *url.URL) (interface{}, gtserror.WithCode) {
// get the account the request is referring to
- requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername)
+ requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
}
@@ -191,14 +191,14 @@ func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string,
return nil, gtserror.NewErrorNotAuthorized(errors.New("not authorized"), "not authorized")
}
- requestingAccount, _, err := p.federator.GetRemoteAccount(requestedUsername, requestingAccountURI, false)
+ requestingAccount, _, err := p.federator.GetRemoteAccount(ctx, requestedUsername, requestingAccountURI, false)
if err != nil {
return nil, gtserror.NewErrorNotAuthorized(err)
}
// authorize the request:
// 1. check if a block exists between the requester and the requestee
- blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -209,14 +209,14 @@ func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string,
// get the status out of the database here
s := &gtsmodel.Status{}
- if err := p.db.GetWhere([]db.Where{
+ if err := p.db.GetWhere(ctx, []db.Where{
{Key: "id", Value: requestedStatusID},
{Key: "account_id", Value: requestedAccount.ID},
}, s); err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting status with id %s and account id %s: %s", requestedStatusID, requestedAccount.ID, err))
}
- visible, err := p.filter.StatusVisible(s, requestingAccount)
+ visible, err := p.filter.StatusVisible(ctx, s, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -225,7 +225,7 @@ func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string,
}
// requester is authorized to view the status, so convert it to AP representation and serialize it
- asStatus, err := p.tc.StatusToAS(s)
+ asStatus, err := p.tc.StatusToAS(ctx, s)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -240,7 +240,7 @@ func (p *processor) GetFediStatus(ctx context.Context, requestedUsername string,
func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername string, requestedStatusID string, page bool, onlyOtherAccounts bool, minID string, requestURL *url.URL) (interface{}, gtserror.WithCode) {
// get the account the request is referring to
- requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername)
+ requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
}
@@ -251,14 +251,14 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername
return nil, gtserror.NewErrorNotAuthorized(errors.New("not authorized"), "not authorized")
}
- requestingAccount, _, err := p.federator.GetRemoteAccount(requestedUsername, requestingAccountURI, false)
+ requestingAccount, _, err := p.federator.GetRemoteAccount(ctx, requestedUsername, requestingAccountURI, false)
if err != nil {
return nil, gtserror.NewErrorNotAuthorized(err)
}
// authorize the request:
// 1. check if a block exists between the requester and the requestee
- blocked, err := p.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -269,14 +269,14 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername
// get the status out of the database here
s := &gtsmodel.Status{}
- if err := p.db.GetWhere([]db.Where{
+ if err := p.db.GetWhere(ctx, []db.Where{
{Key: "id", Value: requestedStatusID},
{Key: "account_id", Value: requestedAccount.ID},
}, s); err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting status with id %s and account id %s: %s", requestedStatusID, requestedAccount.ID, err))
}
- visible, err := p.filter.StatusVisible(s, requestingAccount)
+ visible, err := p.filter.StatusVisible(ctx, s, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -295,7 +295,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername
// scenario 1
// get the collection
- collection, err := p.tc.StatusToASRepliesCollection(s, onlyOtherAccounts)
+ collection, err := p.tc.StatusToASRepliesCollection(ctx, s, onlyOtherAccounts)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -308,7 +308,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername
// scenario 2
// get the collection
- collection, err := p.tc.StatusToASRepliesCollection(s, onlyOtherAccounts)
+ collection, err := p.tc.StatusToASRepliesCollection(ctx, s, onlyOtherAccounts)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -320,7 +320,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername
} else {
// scenario 3
// get immediate children
- replies, err := p.db.GetStatusChildren(s, true, minID)
+ replies, err := p.db.GetStatusChildren(ctx, s, true, minID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -339,13 +339,13 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername
}
// only show replies that the status owner can see
- visibleToStatusOwner, err := p.filter.StatusVisible(r, requestedAccount)
+ visibleToStatusOwner, err := p.filter.StatusVisible(ctx, r, requestedAccount)
if err != nil || !visibleToStatusOwner {
continue
}
// only show replies that the requester can see
- visibleToRequester, err := p.filter.StatusVisible(r, requestingAccount)
+ visibleToRequester, err := p.filter.StatusVisible(ctx, r, requestingAccount)
if err != nil || !visibleToRequester {
continue
}
@@ -358,7 +358,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername
replyURIs[r.ID] = rURI
}
- repliesPage, err := p.tc.StatusURIsToASRepliesPage(s, onlyOtherAccounts, minID, replyURIs)
+ repliesPage, err := p.tc.StatusURIsToASRepliesPage(ctx, s, onlyOtherAccounts, minID, replyURIs)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -373,7 +373,7 @@ func (p *processor) GetFediStatusReplies(ctx context.Context, requestedUsername
func (p *processor) GetWebfingerAccount(ctx context.Context, requestedUsername string, requestURL *url.URL) (*apimodel.WellKnownResponse, gtserror.WithCode) {
// get the account the request is referring to
- requestedAccount, err := p.db.GetLocalAccountByUsername(requestedUsername)
+ requestedAccount, err := p.db.GetLocalAccountByUsername(ctx, requestedUsername)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
}
@@ -400,7 +400,7 @@ func (p *processor) GetWebfingerAccount(ctx context.Context, requestedUsername s
}, nil
}
-func (p *processor) GetNodeInfoRel(request *http.Request) (*apimodel.WellKnownResponse, gtserror.WithCode) {
+func (p *processor) GetNodeInfoRel(ctx context.Context, request *http.Request) (*apimodel.WellKnownResponse, gtserror.WithCode) {
return &apimodel.WellKnownResponse{
Links: []apimodel.Link{
{
@@ -411,7 +411,7 @@ func (p *processor) GetNodeInfoRel(request *http.Request) (*apimodel.WellKnownRe
}, nil
}
-func (p *processor) GetNodeInfo(request *http.Request) (*apimodel.Nodeinfo, gtserror.WithCode) {
+func (p *processor) GetNodeInfo(ctx context.Context, request *http.Request) (*apimodel.Nodeinfo, gtserror.WithCode) {
return &apimodel.Nodeinfo{
Version: "2.0",
Software: apimodel.NodeInfoSoftware{
diff --git a/internal/processing/followrequest.go b/internal/processing/followrequest.go
index 867725023..3dd6432e2 100644
--- a/internal/processing/followrequest.go
+++ b/internal/processing/followrequest.go
@@ -19,6 +19,8 @@
package processing
import (
+ "context"
+
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
@@ -26,8 +28,8 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
-func (p *processor) FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) {
- frs, err := p.db.GetAccountFollowRequests(auth.Account.ID)
+func (p *processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) {
+ frs, err := p.db.GetAccountFollowRequests(ctx, auth.Account.ID)
if err != nil {
if err != db.ErrNoEntries {
return nil, gtserror.NewErrorInternalError(err)
@@ -36,11 +38,15 @@ func (p *processor) FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gts
accts := []apimodel.Account{}
for _, fr := range frs {
- acct := &gtsmodel.Account{}
- if err := p.db.GetByID(fr.AccountID, acct); err != nil {
- return nil, gtserror.NewErrorInternalError(err)
+ if fr.Account == nil {
+ frAcct, err := p.db.GetAccountByID(ctx, fr.AccountID)
+ if err != nil {
+ return nil, gtserror.NewErrorInternalError(err)
+ }
+ fr.Account = frAcct
}
- mastoAcct, err := p.tc.AccountToMastoPublic(acct)
+
+ mastoAcct, err := p.tc.AccountToMastoPublic(ctx, fr.Account)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -49,36 +55,42 @@ func (p *processor) FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gts
return accts, nil
}
-func (p *processor) FollowRequestAccept(auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) {
- follow, err := p.db.AcceptFollowRequest(accountID, auth.Account.ID)
+func (p *processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) {
+ follow, err := p.db.AcceptFollowRequest(ctx, accountID, auth.Account.ID)
if err != nil {
return nil, gtserror.NewErrorNotFound(err)
}
- originAccount := &gtsmodel.Account{}
- if err := p.db.GetByID(follow.AccountID, originAccount); err != nil {
- return nil, gtserror.NewErrorInternalError(err)
+ if follow.Account == nil {
+ followAccount, err := p.db.GetAccountByID(ctx, follow.AccountID)
+ if err != nil {
+ return nil, gtserror.NewErrorInternalError(err)
+ }
+ follow.Account = followAccount
}
- targetAccount := &gtsmodel.Account{}
- if err := p.db.GetByID(follow.TargetAccountID, targetAccount); err != nil {
- return nil, gtserror.NewErrorInternalError(err)
+ if follow.TargetAccount == nil {
+ followTargetAccount, err := p.db.GetAccountByID(ctx, follow.TargetAccountID)
+ if err != nil {
+ return nil, gtserror.NewErrorInternalError(err)
+ }
+ follow.TargetAccount = followTargetAccount
}
p.fromClientAPI <- gtsmodel.FromClientAPI{
APObjectType: gtsmodel.ActivityStreamsFollow,
APActivityType: gtsmodel.ActivityStreamsAccept,
GTSModel: follow,
- OriginAccount: originAccount,
- TargetAccount: targetAccount,
+ OriginAccount: follow.Account,
+ TargetAccount: follow.TargetAccount,
}
- gtsR, err := p.db.GetRelationship(auth.Account.ID, accountID)
+ gtsR, err := p.db.GetRelationship(ctx, auth.Account.ID, accountID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
- r, err := p.tc.RelationshipToMasto(gtsR)
+ r, err := p.tc.RelationshipToMasto(ctx, gtsR)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -86,6 +98,6 @@ func (p *processor) FollowRequestAccept(auth *oauth.Auth, accountID string) (*ap
return r, nil
}
-func (p *processor) FollowRequestDeny(auth *oauth.Auth) gtserror.WithCode {
+func (p *processor) FollowRequestDeny(ctx context.Context, auth *oauth.Auth) gtserror.WithCode {
return nil
}
diff --git a/internal/processing/fromclientapi.go b/internal/processing/fromclientapi.go
index beed283c1..a6ea0068b 100644
--- a/internal/processing/fromclientapi.go
+++ b/internal/processing/fromclientapi.go
@@ -29,7 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error {
+func (p *processor) processFromClientAPI(ctx context.Context, clientMsg gtsmodel.FromClientAPI) error {
switch clientMsg.APActivityType {
case gtsmodel.ActivityStreamsCreate:
// CREATE
@@ -41,16 +41,16 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
return errors.New("note was not parseable as *gtsmodel.Status")
}
- if err := p.timelineStatus(status); err != nil {
+ if err := p.timelineStatus(ctx, status); err != nil {
return err
}
- if err := p.notifyStatus(status); err != nil {
+ if err := p.notifyStatus(ctx, status); err != nil {
return err
}
if status.VisibilityAdvanced != nil && status.VisibilityAdvanced.Federated {
- return p.federateStatus(status)
+ return p.federateStatus(ctx, status)
}
case gtsmodel.ActivityStreamsFollow:
// CREATE FOLLOW REQUEST
@@ -59,11 +59,11 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
return errors.New("followrequest was not parseable as *gtsmodel.FollowRequest")
}
- if err := p.notifyFollowRequest(followRequest, clientMsg.TargetAccount); err != nil {
+ if err := p.notifyFollowRequest(ctx, followRequest, clientMsg.TargetAccount); err != nil {
return err
}
- return p.federateFollow(followRequest, clientMsg.OriginAccount, clientMsg.TargetAccount)
+ return p.federateFollow(ctx, followRequest, clientMsg.OriginAccount, clientMsg.TargetAccount)
case gtsmodel.ActivityStreamsLike:
// CREATE LIKE/FAVE
fave, ok := clientMsg.GTSModel.(*gtsmodel.StatusFave)
@@ -71,11 +71,11 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
return errors.New("fave was not parseable as *gtsmodel.StatusFave")
}
- if err := p.notifyFave(fave, clientMsg.TargetAccount); err != nil {
+ if err := p.notifyFave(ctx, fave, clientMsg.TargetAccount); err != nil {
return err
}
- return p.federateFave(fave, clientMsg.OriginAccount, clientMsg.TargetAccount)
+ return p.federateFave(ctx, fave, clientMsg.OriginAccount, clientMsg.TargetAccount)
case gtsmodel.ActivityStreamsAnnounce:
// CREATE BOOST/ANNOUNCE
boostWrapperStatus, ok := clientMsg.GTSModel.(*gtsmodel.Status)
@@ -83,15 +83,15 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
return errors.New("boost was not parseable as *gtsmodel.Status")
}
- if err := p.timelineStatus(boostWrapperStatus); err != nil {
+ if err := p.timelineStatus(ctx, boostWrapperStatus); err != nil {
return err
}
- if err := p.notifyAnnounce(boostWrapperStatus); err != nil {
+ if err := p.notifyAnnounce(ctx, boostWrapperStatus); err != nil {
return err
}
- return p.federateAnnounce(boostWrapperStatus, clientMsg.OriginAccount, clientMsg.TargetAccount)
+ return p.federateAnnounce(ctx, boostWrapperStatus, clientMsg.OriginAccount, clientMsg.TargetAccount)
case gtsmodel.ActivityStreamsBlock:
// CREATE BLOCK
block, ok := clientMsg.GTSModel.(*gtsmodel.Block)
@@ -100,17 +100,17 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
}
// remove any of the blocking account's statuses from the blocked account's timeline, and vice versa
- if err := p.timelineManager.WipeStatusesFromAccountID(block.AccountID, block.TargetAccountID); err != nil {
+ if err := p.timelineManager.WipeStatusesFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil {
return err
}
- if err := p.timelineManager.WipeStatusesFromAccountID(block.TargetAccountID, block.AccountID); err != nil {
+ if err := p.timelineManager.WipeStatusesFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil {
return err
}
// TODO: same with notifications
// TODO: same with bookmarks
- return p.federateBlock(block)
+ return p.federateBlock(ctx, block)
}
case gtsmodel.ActivityStreamsUpdate:
// UPDATE
@@ -122,7 +122,7 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
return errors.New("account was not parseable as *gtsmodel.Account")
}
- return p.federateAccountUpdate(account, clientMsg.OriginAccount)
+ return p.federateAccountUpdate(ctx, account, clientMsg.OriginAccount)
}
case gtsmodel.ActivityStreamsAccept:
// ACCEPT
@@ -134,11 +134,11 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
return errors.New("accept was not parseable as *gtsmodel.Follow")
}
- if err := p.notifyFollow(follow, clientMsg.TargetAccount); err != nil {
+ if err := p.notifyFollow(ctx, follow, clientMsg.TargetAccount); err != nil {
return err
}
- return p.federateAcceptFollowRequest(follow, clientMsg.OriginAccount, clientMsg.TargetAccount)
+ return p.federateAcceptFollowRequest(ctx, follow, clientMsg.OriginAccount, clientMsg.TargetAccount)
}
case gtsmodel.ActivityStreamsUndo:
// UNDO
@@ -149,21 +149,21 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
if !ok {
return errors.New("undo was not parseable as *gtsmodel.Follow")
}
- return p.federateUnfollow(follow, clientMsg.OriginAccount, clientMsg.TargetAccount)
+ return p.federateUnfollow(ctx, follow, clientMsg.OriginAccount, clientMsg.TargetAccount)
case gtsmodel.ActivityStreamsBlock:
// UNDO BLOCK
block, ok := clientMsg.GTSModel.(*gtsmodel.Block)
if !ok {
return errors.New("undo was not parseable as *gtsmodel.Block")
}
- return p.federateUnblock(block)
+ return p.federateUnblock(ctx, block)
case gtsmodel.ActivityStreamsLike:
// UNDO LIKE/FAVE
fave, ok := clientMsg.GTSModel.(*gtsmodel.StatusFave)
if !ok {
return errors.New("undo was not parseable as *gtsmodel.StatusFave")
}
- return p.federateUnfave(fave, clientMsg.OriginAccount, clientMsg.TargetAccount)
+ return p.federateUnfave(ctx, fave, clientMsg.OriginAccount, clientMsg.TargetAccount)
case gtsmodel.ActivityStreamsAnnounce:
// UNDO ANNOUNCE/BOOST
boost, ok := clientMsg.GTSModel.(*gtsmodel.Status)
@@ -171,11 +171,11 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
return errors.New("undo was not parseable as *gtsmodel.Status")
}
- if err := p.deleteStatusFromTimelines(boost); err != nil {
+ if err := p.deleteStatusFromTimelines(ctx, boost); err != nil {
return err
}
- return p.federateUnannounce(boost, clientMsg.OriginAccount, clientMsg.TargetAccount)
+ return p.federateUnannounce(ctx, boost, clientMsg.OriginAccount, clientMsg.TargetAccount)
}
case gtsmodel.ActivityStreamsDelete:
// DELETE
@@ -193,29 +193,29 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
// delete all attachments for this status
for _, a := range statusToDelete.AttachmentIDs {
- if err := p.mediaProcessor.Delete(a); err != nil {
+ if err := p.mediaProcessor.Delete(ctx, a); err != nil {
return err
}
}
// delete all mentions for this status
for _, m := range statusToDelete.MentionIDs {
- if err := p.db.DeleteByID(m, &gtsmodel.Mention{}); err != nil {
+ if err := p.db.DeleteByID(ctx, m, &gtsmodel.Mention{}); err != nil {
return err
}
}
// delete all notifications for this status
- if err := p.db.DeleteWhere([]db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil {
return err
}
// delete this status from any and all timelines
- if err := p.deleteStatusFromTimelines(statusToDelete); err != nil {
+ if err := p.deleteStatusFromTimelines(ctx, statusToDelete); err != nil {
return err
}
- return p.federateStatusDelete(statusToDelete)
+ return p.federateStatusDelete(ctx, statusToDelete)
case gtsmodel.ActivityStreamsProfile, gtsmodel.ActivityStreamsPerson:
// DELETE ACCOUNT/PROFILE
@@ -228,7 +228,7 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
// origin is whichever account caused this message
origin = clientMsg.OriginAccount.ID
}
- return p.accountProcessor.Delete(clientMsg.TargetAccount, origin)
+ return p.accountProcessor.Delete(ctx, clientMsg.TargetAccount, origin)
}
}
return nil
@@ -236,13 +236,13 @@ func (p *processor) processFromClientAPI(clientMsg gtsmodel.FromClientAPI) error
// TODO: move all the below functions into federation.Federator
-func (p *processor) federateStatus(status *gtsmodel.Status) error {
+func (p *processor) federateStatus(ctx context.Context, status *gtsmodel.Status) error {
if status.Account == nil {
- a := &gtsmodel.Account{}
- if err := p.db.GetByID(status.AccountID, a); err != nil {
+ statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID)
+ if err != nil {
return fmt.Errorf("federateStatus: error fetching status author account: %s", err)
}
- status.Account = a
+ status.Account = statusAccount
}
// do nothing if this isn't our status
@@ -250,7 +250,7 @@ func (p *processor) federateStatus(status *gtsmodel.Status) error {
return nil
}
- asStatus, err := p.tc.StatusToAS(status)
+ asStatus, err := p.tc.StatusToAS(ctx, status)
if err != nil {
return fmt.Errorf("federateStatus: error converting status to as format: %s", err)
}
@@ -260,17 +260,17 @@ func (p *processor) federateStatus(status *gtsmodel.Status) error {
return fmt.Errorf("federateStatus: error parsing outboxURI %s: %s", status.Account.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, asStatus)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, asStatus)
return err
}
-func (p *processor) federateStatusDelete(status *gtsmodel.Status) error {
+func (p *processor) federateStatusDelete(ctx context.Context, status *gtsmodel.Status) error {
if status.Account == nil {
- a := &gtsmodel.Account{}
- if err := p.db.GetByID(status.AccountID, a); err != nil {
- return fmt.Errorf("federateStatus: error fetching status author account: %s", err)
+ statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID)
+ if err != nil {
+ return fmt.Errorf("federateStatusDelete: error fetching status author account: %s", err)
}
- status.Account = a
+ status.Account = statusAccount
}
// do nothing if this isn't our status
@@ -278,7 +278,7 @@ func (p *processor) federateStatusDelete(status *gtsmodel.Status) error {
return nil
}
- asStatus, err := p.tc.StatusToAS(status)
+ asStatus, err := p.tc.StatusToAS(ctx, status)
if err != nil {
return fmt.Errorf("federateStatusDelete: error converting status to as format: %s", err)
}
@@ -310,19 +310,19 @@ func (p *processor) federateStatusDelete(status *gtsmodel.Status) error {
delete.SetActivityStreamsTo(asStatus.GetActivityStreamsTo())
delete.SetActivityStreamsCc(asStatus.GetActivityStreamsCc())
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, delete)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, delete)
return err
}
-func (p *processor) federateFollow(followRequest *gtsmodel.FollowRequest, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
+func (p *processor) federateFollow(ctx context.Context, followRequest *gtsmodel.FollowRequest, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
// if both accounts are local there's nothing to do here
if originAccount.Domain == "" && targetAccount.Domain == "" {
return nil
}
- follow := p.tc.FollowRequestToFollow(followRequest)
+ follow := p.tc.FollowRequestToFollow(ctx, followRequest)
- asFollow, err := p.tc.FollowToAS(follow, originAccount, targetAccount)
+ asFollow, err := p.tc.FollowToAS(ctx, follow, originAccount, targetAccount)
if err != nil {
return fmt.Errorf("federateFollow: error converting follow to as format: %s", err)
}
@@ -332,18 +332,18 @@ func (p *processor) federateFollow(followRequest *gtsmodel.FollowRequest, origin
return fmt.Errorf("federateFollow: error parsing outboxURI %s: %s", originAccount.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, asFollow)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, asFollow)
return err
}
-func (p *processor) federateUnfollow(follow *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
+func (p *processor) federateUnfollow(ctx context.Context, follow *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
// if both accounts are local there's nothing to do here
if originAccount.Domain == "" && targetAccount.Domain == "" {
return nil
}
// recreate the follow
- asFollow, err := p.tc.FollowToAS(follow, originAccount, targetAccount)
+ asFollow, err := p.tc.FollowToAS(ctx, follow, originAccount, targetAccount)
if err != nil {
return fmt.Errorf("federateUnfollow: error converting follow to as format: %s", err)
}
@@ -373,18 +373,18 @@ func (p *processor) federateUnfollow(follow *gtsmodel.Follow, originAccount *gts
}
// send off the Undo
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, undo)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, undo)
return err
}
-func (p *processor) federateUnfave(fave *gtsmodel.StatusFave, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
+func (p *processor) federateUnfave(ctx context.Context, fave *gtsmodel.StatusFave, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
// if both accounts are local there's nothing to do here
if originAccount.Domain == "" && targetAccount.Domain == "" {
return nil
}
// create the AS fave
- asFave, err := p.tc.FaveToAS(fave)
+ asFave, err := p.tc.FaveToAS(ctx, fave)
if err != nil {
return fmt.Errorf("federateFave: error converting fave to as format: %s", err)
}
@@ -412,17 +412,17 @@ func (p *processor) federateUnfave(fave *gtsmodel.StatusFave, originAccount *gts
if err != nil {
return fmt.Errorf("federateFave: error parsing outboxURI %s: %s", originAccount.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, undo)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, undo)
return err
}
-func (p *processor) federateUnannounce(boost *gtsmodel.Status, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
+func (p *processor) federateUnannounce(ctx context.Context, boost *gtsmodel.Status, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
if originAccount.Domain != "" {
// nothing to do here
return nil
}
- asAnnounce, err := p.tc.BoostToAS(boost, originAccount, targetAccount)
+ asAnnounce, err := p.tc.BoostToAS(ctx, boost, originAccount, targetAccount)
if err != nil {
return fmt.Errorf("federateUnannounce: error converting status to announce: %s", err)
}
@@ -447,18 +447,18 @@ func (p *processor) federateUnannounce(boost *gtsmodel.Status, originAccount *gt
return fmt.Errorf("federateUnannounce: error parsing outboxURI %s: %s", originAccount.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, undo)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, undo)
return err
}
-func (p *processor) federateAcceptFollowRequest(follow *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
+func (p *processor) federateAcceptFollowRequest(ctx context.Context, follow *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
// if both accounts are local there's nothing to do here
if originAccount.Domain == "" && targetAccount.Domain == "" {
return nil
}
// recreate the AS follow
- asFollow, err := p.tc.FollowToAS(follow, originAccount, targetAccount)
+ asFollow, err := p.tc.FollowToAS(ctx, follow, originAccount, targetAccount)
if err != nil {
return fmt.Errorf("federateUnfollow: error converting follow to as format: %s", err)
}
@@ -497,18 +497,18 @@ func (p *processor) federateAcceptFollowRequest(follow *gtsmodel.Follow, originA
}
// send off the accept using the accepter's outbox
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, accept)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, accept)
return err
}
-func (p *processor) federateFave(fave *gtsmodel.StatusFave, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
+func (p *processor) federateFave(ctx context.Context, fave *gtsmodel.StatusFave, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) error {
// if both accounts are local there's nothing to do here
if originAccount.Domain == "" && targetAccount.Domain == "" {
return nil
}
// create the AS fave
- asFave, err := p.tc.FaveToAS(fave)
+ asFave, err := p.tc.FaveToAS(ctx, fave)
if err != nil {
return fmt.Errorf("federateFave: error converting fave to as format: %s", err)
}
@@ -517,12 +517,12 @@ func (p *processor) federateFave(fave *gtsmodel.StatusFave, originAccount *gtsmo
if err != nil {
return fmt.Errorf("federateFave: error parsing outboxURI %s: %s", originAccount.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, asFave)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, asFave)
return err
}
-func (p *processor) federateAnnounce(boostWrapperStatus *gtsmodel.Status, boostingAccount *gtsmodel.Account, boostedAccount *gtsmodel.Account) error {
- announce, err := p.tc.BoostToAS(boostWrapperStatus, boostingAccount, boostedAccount)
+func (p *processor) federateAnnounce(ctx context.Context, boostWrapperStatus *gtsmodel.Status, boostingAccount *gtsmodel.Account, boostedAccount *gtsmodel.Account) error {
+ announce, err := p.tc.BoostToAS(ctx, boostWrapperStatus, boostingAccount, boostedAccount)
if err != nil {
return fmt.Errorf("federateAnnounce: error converting status to announce: %s", err)
}
@@ -532,12 +532,12 @@ func (p *processor) federateAnnounce(boostWrapperStatus *gtsmodel.Status, boosti
return fmt.Errorf("federateAnnounce: error parsing outboxURI %s: %s", boostingAccount.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, announce)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, announce)
return err
}
-func (p *processor) federateAccountUpdate(updatedAccount *gtsmodel.Account, originAccount *gtsmodel.Account) error {
- person, err := p.tc.AccountToAS(updatedAccount)
+func (p *processor) federateAccountUpdate(ctx context.Context, updatedAccount *gtsmodel.Account, originAccount *gtsmodel.Account) error {
+ person, err := p.tc.AccountToAS(ctx, updatedAccount)
if err != nil {
return fmt.Errorf("federateAccountUpdate: error converting account to person: %s", err)
}
@@ -552,25 +552,25 @@ func (p *processor) federateAccountUpdate(updatedAccount *gtsmodel.Account, orig
return fmt.Errorf("federateAnnounce: error parsing outboxURI %s: %s", originAccount.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, update)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, update)
return err
}
-func (p *processor) federateBlock(block *gtsmodel.Block) error {
+func (p *processor) federateBlock(ctx context.Context, block *gtsmodel.Block) error {
if block.Account == nil {
- a := &gtsmodel.Account{}
- if err := p.db.GetByID(block.AccountID, a); err != nil {
+ blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID)
+ if err != nil {
return fmt.Errorf("federateBlock: error getting block account from database: %s", err)
}
- block.Account = a
+ block.Account = blockAccount
}
if block.TargetAccount == nil {
- a := &gtsmodel.Account{}
- if err := p.db.GetByID(block.TargetAccountID, a); err != nil {
+ blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID)
+ if err != nil {
return fmt.Errorf("federateBlock: error getting block target account from database: %s", err)
}
- block.TargetAccount = a
+ block.TargetAccount = blockTargetAccount
}
// if both accounts are local there's nothing to do here
@@ -578,7 +578,7 @@ func (p *processor) federateBlock(block *gtsmodel.Block) error {
return nil
}
- asBlock, err := p.tc.BlockToAS(block)
+ asBlock, err := p.tc.BlockToAS(ctx, block)
if err != nil {
return fmt.Errorf("federateBlock: error converting block to AS format: %s", err)
}
@@ -588,25 +588,25 @@ func (p *processor) federateBlock(block *gtsmodel.Block) error {
return fmt.Errorf("federateBlock: error parsing outboxURI %s: %s", block.Account.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, asBlock)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, asBlock)
return err
}
-func (p *processor) federateUnblock(block *gtsmodel.Block) error {
+func (p *processor) federateUnblock(ctx context.Context, block *gtsmodel.Block) error {
if block.Account == nil {
- a := &gtsmodel.Account{}
- if err := p.db.GetByID(block.AccountID, a); err != nil {
+ blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID)
+ if err != nil {
return fmt.Errorf("federateUnblock: error getting block account from database: %s", err)
}
- block.Account = a
+ block.Account = blockAccount
}
if block.TargetAccount == nil {
- a := &gtsmodel.Account{}
- if err := p.db.GetByID(block.TargetAccountID, a); err != nil {
+ blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID)
+ if err != nil {
return fmt.Errorf("federateUnblock: error getting block target account from database: %s", err)
}
- block.TargetAccount = a
+ block.TargetAccount = blockTargetAccount
}
// if both accounts are local there's nothing to do here
@@ -614,7 +614,7 @@ func (p *processor) federateUnblock(block *gtsmodel.Block) error {
return nil
}
- asBlock, err := p.tc.BlockToAS(block)
+ asBlock, err := p.tc.BlockToAS(ctx, block)
if err != nil {
return fmt.Errorf("federateUnblock: error converting block to AS format: %s", err)
}
@@ -642,6 +642,6 @@ func (p *processor) federateUnblock(block *gtsmodel.Block) error {
if err != nil {
return fmt.Errorf("federateUnblock: error parsing outboxURI %s: %s", block.Account.OutboxURI, err)
}
- _, err = p.federator.FederatingActor().Send(context.Background(), outboxIRI, undo)
+ _, err = p.federator.FederatingActor().Send(ctx, outboxIRI, undo)
return err
}
diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go
index 2c2635175..b7a6defc3 100644
--- a/internal/processing/fromcommon.go
+++ b/internal/processing/fromcommon.go
@@ -19,6 +19,7 @@
package processing
import (
+ "context"
"fmt"
"strings"
"sync"
@@ -28,7 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/id"
)
-func (p *processor) notifyStatus(status *gtsmodel.Status) error {
+func (p *processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) error {
// if there are no mentions in this status then just bail
if len(status.MentionIDs) == 0 {
return nil
@@ -36,7 +37,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error {
if status.Mentions == nil {
// there are mentions but they're not fully populated on the status yet so do this
- menchies, err := p.db.GetMentions(status.MentionIDs)
+ menchies, err := p.db.GetMentions(ctx, status.MentionIDs)
if err != nil {
return fmt.Errorf("notifyStatus: error getting mentions for status %s from the db: %s", status.ID, err)
}
@@ -47,7 +48,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error {
for _, m := range status.Mentions {
// make sure this is a local account, otherwise we don't need to create a notification for it
if m.TargetAccount == nil {
- a, err := p.db.GetAccountByID(m.TargetAccountID)
+ a, err := p.db.GetAccountByID(ctx, m.TargetAccountID)
if err != nil {
// we don't have the account or there's been an error
return fmt.Errorf("notifyStatus: error getting account with id %s from the db: %s", m.TargetAccountID, err)
@@ -60,7 +61,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error {
}
// make sure a notif doesn't already exist for this mention
- err := p.db.GetWhere([]db.Where{
+ err := p.db.GetWhere(ctx, []db.Where{
{Key: "notification_type", Value: gtsmodel.NotificationMention},
{Key: "target_account_id", Value: m.TargetAccountID},
{Key: "origin_account_id", Value: status.AccountID},
@@ -92,12 +93,12 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error {
Status: status,
}
- if err := p.db.Put(notif); err != nil {
+ if err := p.db.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyStatus: error putting notification in database: %s", err)
}
// now stream the notification to the user
- mastoNotif, err := p.tc.NotificationToMasto(notif)
+ mastoNotif, err := p.tc.NotificationToMasto(ctx, notif)
if err != nil {
return fmt.Errorf("notifyStatus: error converting notification to masto representation: %s", err)
}
@@ -110,7 +111,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error {
return nil
}
-func (p *processor) notifyFollowRequest(followRequest *gtsmodel.FollowRequest, receivingAccount *gtsmodel.Account) error {
+func (p *processor) notifyFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest, receivingAccount *gtsmodel.Account) error {
// return if this isn't a local account
if receivingAccount.Domain != "" {
return nil
@@ -128,12 +129,12 @@ func (p *processor) notifyFollowRequest(followRequest *gtsmodel.FollowRequest, r
OriginAccountID: followRequest.AccountID,
}
- if err := p.db.Put(notif); err != nil {
+ if err := p.db.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyFollowRequest: error putting notification in database: %s", err)
}
// now stream the notification to the user
- mastoNotif, err := p.tc.NotificationToMasto(notif)
+ mastoNotif, err := p.tc.NotificationToMasto(ctx, notif)
if err != nil {
return fmt.Errorf("notifyStatus: error converting notification to masto representation: %s", err)
}
@@ -145,14 +146,14 @@ func (p *processor) notifyFollowRequest(followRequest *gtsmodel.FollowRequest, r
return nil
}
-func (p *processor) notifyFollow(follow *gtsmodel.Follow, targetAccount *gtsmodel.Account) error {
+func (p *processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, targetAccount *gtsmodel.Account) error {
// return if this isn't a local account
if targetAccount.Domain != "" {
return nil
}
// first remove the follow request notification
- if err := p.db.DeleteWhere([]db.Where{
+ if err := p.db.DeleteWhere(ctx, []db.Where{
{Key: "notification_type", Value: gtsmodel.NotificationFollowRequest},
{Key: "target_account_id", Value: follow.TargetAccountID},
{Key: "origin_account_id", Value: follow.AccountID},
@@ -174,12 +175,12 @@ func (p *processor) notifyFollow(follow *gtsmodel.Follow, targetAccount *gtsmode
OriginAccountID: follow.AccountID,
OriginAccount: follow.Account,
}
- if err := p.db.Put(notif); err != nil {
+ if err := p.db.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyFollow: error putting notification in database: %s", err)
}
// now stream the notification to the user
- mastoNotif, err := p.tc.NotificationToMasto(notif)
+ mastoNotif, err := p.tc.NotificationToMasto(ctx, notif)
if err != nil {
return fmt.Errorf("notifyStatus: error converting notification to masto representation: %s", err)
}
@@ -191,7 +192,7 @@ func (p *processor) notifyFollow(follow *gtsmodel.Follow, targetAccount *gtsmode
return nil
}
-func (p *processor) notifyFave(fave *gtsmodel.StatusFave, targetAccount *gtsmodel.Account) error {
+func (p *processor) notifyFave(ctx context.Context, fave *gtsmodel.StatusFave, targetAccount *gtsmodel.Account) error {
// return if this isn't a local account
if targetAccount.Domain != "" {
return nil
@@ -213,12 +214,12 @@ func (p *processor) notifyFave(fave *gtsmodel.StatusFave, targetAccount *gtsmode
Status: fave.Status,
}
- if err := p.db.Put(notif); err != nil {
+ if err := p.db.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyFave: error putting notification in database: %s", err)
}
// now stream the notification to the user
- mastoNotif, err := p.tc.NotificationToMasto(notif)
+ mastoNotif, err := p.tc.NotificationToMasto(ctx, notif)
if err != nil {
return fmt.Errorf("notifyStatus: error converting notification to masto representation: %s", err)
}
@@ -230,14 +231,14 @@ func (p *processor) notifyFave(fave *gtsmodel.StatusFave, targetAccount *gtsmode
return nil
}
-func (p *processor) notifyAnnounce(status *gtsmodel.Status) error {
+func (p *processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status) error {
if status.BoostOfID == "" {
// not a boost, nothing to do
return nil
}
if status.BoostOf == nil {
- boostedStatus, err := p.db.GetStatusByID(status.BoostOfID)
+ boostedStatus, err := p.db.GetStatusByID(ctx, status.BoostOfID)
if err != nil {
return fmt.Errorf("notifyAnnounce: error getting status with id %s: %s", status.BoostOfID, err)
}
@@ -245,7 +246,7 @@ func (p *processor) notifyAnnounce(status *gtsmodel.Status) error {
}
if status.BoostOfAccount == nil {
- boostedAcct, err := p.db.GetAccountByID(status.BoostOfAccountID)
+ boostedAcct, err := p.db.GetAccountByID(ctx, status.BoostOfAccountID)
if err != nil {
return fmt.Errorf("notifyAnnounce: error getting account with id %s: %s", status.BoostOfAccountID, err)
}
@@ -264,7 +265,7 @@ func (p *processor) notifyAnnounce(status *gtsmodel.Status) error {
}
// make sure a notif doesn't already exist for this announce
- err := p.db.GetWhere([]db.Where{
+ err := p.db.GetWhere(ctx, []db.Where{
{Key: "notification_type", Value: gtsmodel.NotificationReblog},
{Key: "target_account_id", Value: status.BoostOfAccountID},
{Key: "origin_account_id", Value: status.AccountID},
@@ -292,12 +293,12 @@ func (p *processor) notifyAnnounce(status *gtsmodel.Status) error {
Status: status,
}
- if err := p.db.Put(notif); err != nil {
+ if err := p.db.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyAnnounce: error putting notification in database: %s", err)
}
// now stream the notification to the user
- mastoNotif, err := p.tc.NotificationToMasto(notif)
+ mastoNotif, err := p.tc.NotificationToMasto(ctx, notif)
if err != nil {
return fmt.Errorf("notifyStatus: error converting notification to masto representation: %s", err)
}
@@ -309,10 +310,10 @@ func (p *processor) notifyAnnounce(status *gtsmodel.Status) error {
return nil
}
-func (p *processor) timelineStatus(status *gtsmodel.Status) error {
+func (p *processor) timelineStatus(ctx context.Context, status *gtsmodel.Status) error {
// make sure the author account is pinned onto the status
if status.Account == nil {
- a, err := p.db.GetAccountByID(status.AccountID)
+ a, err := p.db.GetAccountByID(ctx, status.AccountID)
if err != nil {
return fmt.Errorf("timelineStatus: error getting author account with id %s: %s", status.AccountID, err)
}
@@ -320,7 +321,7 @@ func (p *processor) timelineStatus(status *gtsmodel.Status) error {
}
// get local followers of the account that posted the status
- follows, err := p.db.GetAccountFollowedBy(status.AccountID, true)
+ follows, err := p.db.GetAccountFollowedBy(ctx, status.AccountID, true)
if err != nil {
return fmt.Errorf("timelineStatus: error getting followers for account id %s: %s", status.AccountID, err)
}
@@ -338,7 +339,7 @@ func (p *processor) timelineStatus(status *gtsmodel.Status) error {
errors := make(chan error, len(follows))
for _, f := range follows {
- go p.timelineStatusForAccount(status, f.AccountID, errors, &wg)
+ go p.timelineStatusForAccount(ctx, status, f.AccountID, errors, &wg)
}
// read any errors that come in from the async functions
@@ -365,18 +366,18 @@ func (p *processor) timelineStatus(status *gtsmodel.Status) error {
return nil
}
-func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID string, errors chan error, wg *sync.WaitGroup) {
+func (p *processor) timelineStatusForAccount(ctx context.Context, status *gtsmodel.Status, accountID string, errors chan error, wg *sync.WaitGroup) {
defer wg.Done()
// get the timeline owner account
- timelineAccount, err := p.db.GetAccountByID(accountID)
+ timelineAccount, err := p.db.GetAccountByID(ctx, accountID)
if err != nil {
errors <- fmt.Errorf("timelineStatusForAccount: error getting account for timeline with id %s: %s", accountID, err)
return
}
// make sure the status is timelineable
- timelineable, err := p.filter.StatusHometimelineable(status, timelineAccount)
+ timelineable, err := p.filter.StatusHometimelineable(ctx, status, timelineAccount)
if err != nil {
errors <- fmt.Errorf("timelineStatusForAccount: error getting timelineability for status for timeline with id %s: %s", accountID, err)
return
@@ -387,7 +388,7 @@ func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID
}
// stick the status in the timeline for the account and then immediately prepare it so they can see it right away
- inserted, err := p.timelineManager.IngestAndPrepare(status, timelineAccount.ID)
+ inserted, err := p.timelineManager.IngestAndPrepare(ctx, status, timelineAccount.ID)
if err != nil {
errors <- fmt.Errorf("timelineStatusForAccount: error ingesting status %s: %s", status.ID, err)
return
@@ -395,7 +396,7 @@ func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID
// the status was inserted to stream it to the user
if inserted {
- mastoStatus, err := p.tc.StatusToMasto(status, timelineAccount)
+ mastoStatus, err := p.tc.StatusToMasto(ctx, status, timelineAccount)
if err != nil {
errors <- fmt.Errorf("timelineStatusForAccount: error converting status %s to frontend representation: %s", status.ID, err)
} else {
@@ -405,7 +406,7 @@ func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID
}
}
- mastoStatus, err := p.tc.StatusToMasto(status, timelineAccount)
+ mastoStatus, err := p.tc.StatusToMasto(ctx, status, timelineAccount)
if err != nil {
errors <- fmt.Errorf("timelineStatusForAccount: error converting status %s to frontend representation: %s", status.ID, err)
} else {
@@ -415,8 +416,8 @@ func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID
}
}
-func (p *processor) deleteStatusFromTimelines(status *gtsmodel.Status) error {
- if err := p.timelineManager.WipeStatusFromAllTimelines(status.ID); err != nil {
+func (p *processor) deleteStatusFromTimelines(ctx context.Context, status *gtsmodel.Status) error {
+ if err := p.timelineManager.WipeStatusFromAllTimelines(ctx, status.ID); err != nil {
return err
}
diff --git a/internal/processing/fromfederator.go b/internal/processing/fromfederator.go
index c95c27778..2bb74db34 100644
--- a/internal/processing/fromfederator.go
+++ b/internal/processing/fromfederator.go
@@ -19,6 +19,7 @@
package processing
import (
+ "context"
"errors"
"fmt"
"net/url"
@@ -29,7 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/id"
)
-func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) error {
+func (p *processor) processFromFederator(ctx context.Context, federatorMsg gtsmodel.FromFederator) error {
l := p.log.WithFields(logrus.Fields{
"func": "processFromFederator",
"federatorMsg": fmt.Sprintf("%+v", federatorMsg),
@@ -48,16 +49,16 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
return errors.New("note was not parseable as *gtsmodel.Status")
}
- status, err := p.federator.EnrichRemoteStatus(federatorMsg.ReceivingAccount.Username, incomingStatus)
+ status, err := p.federator.EnrichRemoteStatus(ctx, federatorMsg.ReceivingAccount.Username, incomingStatus)
if err != nil {
return err
}
- if err := p.timelineStatus(status); err != nil {
+ if err := p.timelineStatus(ctx, status); err != nil {
return err
}
- if err := p.notifyStatus(status); err != nil {
+ if err := p.notifyStatus(ctx, status); err != nil {
return err
}
case gtsmodel.ActivityStreamsProfile:
@@ -70,7 +71,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
return errors.New("like was not parseable as *gtsmodel.StatusFave")
}
- if err := p.notifyFave(incomingFave, federatorMsg.ReceivingAccount); err != nil {
+ if err := p.notifyFave(ctx, incomingFave, federatorMsg.ReceivingAccount); err != nil {
return err
}
case gtsmodel.ActivityStreamsFollow:
@@ -80,7 +81,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
return errors.New("incomingFollowRequest was not parseable as *gtsmodel.FollowRequest")
}
- if err := p.notifyFollowRequest(incomingFollowRequest, federatorMsg.ReceivingAccount); err != nil {
+ if err := p.notifyFollowRequest(ctx, incomingFollowRequest, federatorMsg.ReceivingAccount); err != nil {
return err
}
case gtsmodel.ActivityStreamsAnnounce:
@@ -90,7 +91,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
return errors.New("announce was not parseable as *gtsmodel.Status")
}
- if err := p.federator.DereferenceAnnounce(incomingAnnounce, federatorMsg.ReceivingAccount.Username); err != nil {
+ if err := p.federator.DereferenceAnnounce(ctx, incomingAnnounce, federatorMsg.ReceivingAccount.Username); err != nil {
return fmt.Errorf("error dereferencing announce from federator: %s", err)
}
@@ -100,17 +101,17 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
}
incomingAnnounce.ID = incomingAnnounceID
- if err := p.db.PutStatus(incomingAnnounce); err != nil {
+ if err := p.db.PutStatus(ctx, incomingAnnounce); err != nil {
if err != db.ErrNoEntries {
return fmt.Errorf("error adding dereferenced announce to the db: %s", err)
}
}
- if err := p.timelineStatus(incomingAnnounce); err != nil {
+ if err := p.timelineStatus(ctx, incomingAnnounce); err != nil {
return err
}
- if err := p.notifyAnnounce(incomingAnnounce); err != nil {
+ if err := p.notifyAnnounce(ctx, incomingAnnounce); err != nil {
return err
}
case gtsmodel.ActivityStreamsBlock:
@@ -121,10 +122,10 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
}
// remove any of the blocking account's statuses from the blocked account's timeline, and vice versa
- if err := p.timelineManager.WipeStatusesFromAccountID(block.AccountID, block.TargetAccountID); err != nil {
+ if err := p.timelineManager.WipeStatusesFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil {
return err
}
- if err := p.timelineManager.WipeStatusesFromAccountID(block.TargetAccountID, block.AccountID); err != nil {
+ if err := p.timelineManager.WipeStatusesFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil {
return err
}
// TODO: same with notifications
@@ -145,7 +146,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
return err
}
- if _, _, err := p.federator.GetRemoteAccount(federatorMsg.ReceivingAccount.Username, incomingAccountURI, true); err != nil {
+ if _, _, err := p.federator.GetRemoteAccount(ctx, federatorMsg.ReceivingAccount.Username, incomingAccountURI, true); err != nil {
return fmt.Errorf("error dereferencing account from federator: %s", err)
}
}
@@ -165,25 +166,25 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
// delete all attachments for this status
for _, a := range statusToDelete.AttachmentIDs {
- if err := p.mediaProcessor.Delete(a); err != nil {
+ if err := p.mediaProcessor.Delete(ctx, a); err != nil {
return err
}
}
// delete all mentions for this status
for _, m := range statusToDelete.MentionIDs {
- if err := p.db.DeleteByID(m, &gtsmodel.Mention{}); err != nil {
+ if err := p.db.DeleteByID(ctx, m, &gtsmodel.Mention{}); err != nil {
return err
}
}
// delete all notifications for this status
- if err := p.db.DeleteWhere([]db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil {
return err
}
// remove this status from any and all timelines
- return p.deleteStatusFromTimelines(statusToDelete)
+ return p.deleteStatusFromTimelines(ctx, statusToDelete)
case gtsmodel.ActivityStreamsProfile:
// DELETE A PROFILE/ACCOUNT
// TODO: handle side effects of account deletion here: delete all objects, statuses, media etc associated with account
@@ -198,7 +199,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
return errors.New("follow was not parseable as *gtsmodel.Follow")
}
- if err := p.notifyFollow(follow, federatorMsg.ReceivingAccount); err != nil {
+ if err := p.notifyFollow(ctx, follow, federatorMsg.ReceivingAccount); err != nil {
return err
}
}
diff --git a/internal/processing/instance.go b/internal/processing/instance.go
index b151744ef..ced798c2e 100644
--- a/internal/processing/instance.go
+++ b/internal/processing/instance.go
@@ -19,6 +19,7 @@
package processing
import (
+ "context"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -29,13 +30,13 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/util"
)
-func (p *processor) InstanceGet(domain string) (*apimodel.Instance, gtserror.WithCode) {
+func (p *processor) InstanceGet(ctx context.Context, domain string) (*apimodel.Instance, gtserror.WithCode) {
i := &gtsmodel.Instance{}
- if err := p.db.GetWhere([]db.Where{{Key: "domain", Value: domain}}, i); err != nil {
+ if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: domain}}, i); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", p.config.Host, err))
}
- ai, err := p.tc.InstanceToMasto(i)
+ ai, err := p.tc.InstanceToMasto(ctx, i)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting instance to api representation: %s", err))
}
@@ -43,15 +44,15 @@ func (p *processor) InstanceGet(domain string) (*apimodel.Instance, gtserror.Wit
return ai, nil
}
-func (p *processor) InstancePatch(form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode) {
+func (p *processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode) {
// fetch the instance entry from the db for processing
i := &gtsmodel.Instance{}
- if err := p.db.GetWhere([]db.Where{{Key: "domain", Value: p.config.Host}}, i); err != nil {
+ if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: p.config.Host}}, i); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", p.config.Host, err))
}
// fetch the instance account from the db for processing
- ia, err := p.db.GetInstanceAccount("")
+ ia, err := p.db.GetInstanceAccount(ctx, "")
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance account %s: %s", p.config.Host, err))
}
@@ -67,13 +68,13 @@ func (p *processor) InstancePatch(form *apimodel.InstanceSettingsUpdateRequest)
// validate & update site contact account if it's set on the form
if form.ContactUsername != nil {
// make sure the account with the given username exists in the db
- contactAccount, err := p.db.GetLocalAccountByUsername(*form.ContactUsername)
+ contactAccount, err := p.db.GetLocalAccountByUsername(ctx, *form.ContactUsername)
if err != nil {
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("account with username %s not retrievable", *form.ContactUsername))
}
// make sure it has a user associated with it
contactUser := &gtsmodel.User{}
- if err := p.db.GetWhere([]db.Where{{Key: "account_id", Value: contactAccount.ID}}, contactUser); err != nil {
+ if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: contactAccount.ID}}, contactUser); err != nil {
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("user for account with username %s not retrievable", *form.ContactUsername))
}
// suspended accounts cannot be contact accounts
@@ -132,7 +133,7 @@ func (p *processor) InstancePatch(form *apimodel.InstanceSettingsUpdateRequest)
// process avatar if provided
if form.Avatar != nil && form.Avatar.Size != 0 {
- _, err := p.accountProcessor.UpdateAvatar(form.Avatar, ia.ID)
+ _, err := p.accountProcessor.UpdateAvatar(ctx, form.Avatar, ia.ID)
if err != nil {
return nil, gtserror.NewErrorBadRequest(err, "error processing avatar")
}
@@ -140,17 +141,17 @@ func (p *processor) InstancePatch(form *apimodel.InstanceSettingsUpdateRequest)
// process header if provided
if form.Header != nil && form.Header.Size != 0 {
- _, err := p.accountProcessor.UpdateHeader(form.Header, ia.ID)
+ _, err := p.accountProcessor.UpdateHeader(ctx, form.Header, ia.ID)
if err != nil {
return nil, gtserror.NewErrorBadRequest(err, "error processing header")
}
}
- if err := p.db.UpdateByID(i.ID, i); err != nil {
+ if err := p.db.UpdateByID(ctx, i.ID, i); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", p.config.Host, err))
}
- ai, err := p.tc.InstanceToMasto(i)
+ ai, err := p.tc.InstanceToMasto(ctx, i)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting instance to api representation: %s", err))
}
diff --git a/internal/processing/media.go b/internal/processing/media.go
index 6ca0eda5b..0b2443893 100644
--- a/internal/processing/media.go
+++ b/internal/processing/media.go
@@ -19,23 +19,25 @@
package processing
import (
+ "context"
+
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
-func (p *processor) MediaCreate(authed *oauth.Auth, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) {
- return p.mediaProcessor.Create(authed.Account, form)
+func (p *processor) MediaCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) {
+ return p.mediaProcessor.Create(ctx, authed.Account, form)
}
-func (p *processor) MediaGet(authed *oauth.Auth, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) {
- return p.mediaProcessor.GetMedia(authed.Account, mediaAttachmentID)
+func (p *processor) MediaGet(ctx context.Context, authed *oauth.Auth, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) {
+ return p.mediaProcessor.GetMedia(ctx, authed.Account, mediaAttachmentID)
}
-func (p *processor) MediaUpdate(authed *oauth.Auth, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) {
- return p.mediaProcessor.Update(authed.Account, mediaAttachmentID, form)
+func (p *processor) MediaUpdate(ctx context.Context, authed *oauth.Auth, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) {
+ return p.mediaProcessor.Update(ctx, authed.Account, mediaAttachmentID, form)
}
-func (p *processor) FileGet(authed *oauth.Auth, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) {
- return p.mediaProcessor.GetFile(authed.Account, form)
+func (p *processor) FileGet(ctx context.Context, authed *oauth.Auth, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) {
+ return p.mediaProcessor.GetFile(ctx, authed.Account, form)
}
diff --git a/internal/processing/media/create.go b/internal/processing/media/create.go
index 5b8cdf604..648e4d46a 100644
--- a/internal/processing/media/create.go
+++ b/internal/processing/media/create.go
@@ -20,6 +20,7 @@ package media
import (
"bytes"
+ "context"
"errors"
"fmt"
"io"
@@ -29,7 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/text"
)
-func (p *processor) Create(account *gtsmodel.Account, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) {
+func (p *processor) Create(ctx context.Context, account *gtsmodel.Account, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) {
// open the attachment and extract the bytes from it
f, err := form.File.Open()
if err != nil {
@@ -45,7 +46,7 @@ func (p *processor) Create(account *gtsmodel.Account, form *apimodel.AttachmentR
}
// allow the mediaHandler to work its magic of processing the attachment bytes, and putting them in whatever storage backend we're using
- attachment, err := p.mediaHandler.ProcessAttachment(buf.Bytes(), account.ID, "")
+ attachment, err := p.mediaHandler.ProcessAttachment(ctx, buf.Bytes(), account.ID, "")
if err != nil {
return nil, fmt.Errorf("error reading attachment: %s", err)
}
@@ -66,13 +67,13 @@ func (p *processor) Create(account *gtsmodel.Account, form *apimodel.AttachmentR
// prepare the frontend representation now -- if there are any errors here at least we can bail without
// having already put something in the database and then having to clean it up again (eugh)
- mastoAttachment, err := p.tc.AttachmentToMasto(attachment)
+ mastoAttachment, err := p.tc.AttachmentToMasto(ctx, attachment)
if err != nil {
return nil, fmt.Errorf("error parsing media attachment to frontend type: %s", err)
}
// now we can confidently put the attachment in the database
- if err := p.db.Put(attachment); err != nil {
+ if err := p.db.Put(ctx, attachment); err != nil {
return nil, fmt.Errorf("error storing media attachment in db: %s", err)
}
diff --git a/internal/processing/media/delete.go b/internal/processing/media/delete.go
index b5ea8c806..281ddba03 100644
--- a/internal/processing/media/delete.go
+++ b/internal/processing/media/delete.go
@@ -1,17 +1,17 @@
package media
import (
+ "context"
"fmt"
"strings"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) Delete(mediaAttachmentID string) gtserror.WithCode {
- a := &gtsmodel.MediaAttachment{}
- if err := p.db.GetByID(mediaAttachmentID, a); err != nil {
+func (p *processor) Delete(ctx context.Context, mediaAttachmentID string) gtserror.WithCode {
+ attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID)
+ if err != nil {
if err == db.ErrNoEntries {
// attachment already gone
return nil
@@ -23,21 +23,21 @@ func (p *processor) Delete(mediaAttachmentID string) gtserror.WithCode {
errs := []string{}
// delete the thumbnail from storage
- if a.Thumbnail.Path != "" {
- if err := p.storage.RemoveFileAt(a.Thumbnail.Path); err != nil {
- errs = append(errs, fmt.Sprintf("remove thumbnail at path %s: %s", a.Thumbnail.Path, err))
+ if attachment.Thumbnail.Path != "" {
+ if err := p.storage.RemoveFileAt(attachment.Thumbnail.Path); err != nil {
+ errs = append(errs, fmt.Sprintf("remove thumbnail at path %s: %s", attachment.Thumbnail.Path, err))
}
}
// delete the file from storage
- if a.File.Path != "" {
- if err := p.storage.RemoveFileAt(a.File.Path); err != nil {
- errs = append(errs, fmt.Sprintf("remove file at path %s: %s", a.File.Path, err))
+ if attachment.File.Path != "" {
+ if err := p.storage.RemoveFileAt(attachment.File.Path); err != nil {
+ errs = append(errs, fmt.Sprintf("remove file at path %s: %s", attachment.File.Path, err))
}
}
// delete the attachment
- if err := p.db.DeleteByID(mediaAttachmentID, a); err != nil {
+ if err := p.db.DeleteByID(ctx, mediaAttachmentID, attachment); err != nil {
if err != db.ErrNoEntries {
errs = append(errs, fmt.Sprintf("remove attachment: %s", err))
}
diff --git a/internal/processing/media/getfile.go b/internal/processing/media/getfile.go
index 01288c56d..c9c9b556d 100644
--- a/internal/processing/media/getfile.go
+++ b/internal/processing/media/getfile.go
@@ -19,6 +19,7 @@
package media
import (
+ "context"
"fmt"
"strings"
@@ -28,7 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/media"
)
-func (p *processor) GetFile(account *gtsmodel.Account, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) {
+func (p *processor) GetFile(ctx context.Context, account *gtsmodel.Account, form *apimodel.GetContentRequestForm) (*apimodel.Content, error) {
// parse the form fields
mediaSize, err := media.ParseMediaSize(form.MediaSize)
if err != nil {
@@ -47,8 +48,8 @@ func (p *processor) GetFile(account *gtsmodel.Account, form *apimodel.GetContent
wantedMediaID := spl[0]
// get the account that owns the media and make sure it's not suspended
- acct := &gtsmodel.Account{}
- if err := p.db.GetByID(form.AccountID, acct); err != nil {
+ acct, err := p.db.GetAccountByID(ctx, form.AccountID)
+ if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("account with id %s could not be selected from the db: %s", form.AccountID, err))
}
if !acct.SuspendedAt.IsZero() {
@@ -57,7 +58,7 @@ func (p *processor) GetFile(account *gtsmodel.Account, form *apimodel.GetContent
// make sure the requesting account and the media account don't block each other
if account != nil {
- blocked, err := p.db.IsBlocked(account.ID, form.AccountID, true)
+ blocked, err := p.db.IsBlocked(ctx, account.ID, form.AccountID, true)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block status could not be established between accounts %s and %s: %s", form.AccountID, account.ID, err))
}
@@ -73,7 +74,7 @@ func (p *processor) GetFile(account *gtsmodel.Account, form *apimodel.GetContent
switch mediaType {
case media.Emoji:
e := &gtsmodel.Emoji{}
- if err := p.db.GetByID(wantedMediaID, e); err != nil {
+ if err := p.db.GetByID(ctx, wantedMediaID, e); err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("emoji %s could not be taken from the db: %s", wantedMediaID, err))
}
if e.Disabled {
@@ -90,8 +91,8 @@ func (p *processor) GetFile(account *gtsmodel.Account, form *apimodel.GetContent
return nil, gtserror.NewErrorNotFound(fmt.Errorf("media size %s not recognized for emoji", mediaSize))
}
case media.Attachment, media.Header, media.Avatar:
- a := &gtsmodel.MediaAttachment{}
- if err := p.db.GetByID(wantedMediaID, a); err != nil {
+ a, err := p.db.GetAttachmentByID(ctx, wantedMediaID)
+ if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("attachment %s could not be taken from the db: %s", wantedMediaID, err))
}
if a.AccountID != form.AccountID {
diff --git a/internal/processing/media/getmedia.go b/internal/processing/media/getmedia.go
index 380a54cc2..91608e90d 100644
--- a/internal/processing/media/getmedia.go
+++ b/internal/processing/media/getmedia.go
@@ -19,6 +19,7 @@
package media
import (
+ "context"
"errors"
"fmt"
@@ -28,9 +29,9 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) GetMedia(account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) {
- attachment := &gtsmodel.MediaAttachment{}
- if err := p.db.GetByID(mediaAttachmentID, attachment); err != nil {
+func (p *processor) GetMedia(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) {
+ attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID)
+ if err != nil {
if err == db.ErrNoEntries {
// attachment doesn't exist
return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db"))
@@ -42,7 +43,7 @@ func (p *processor) GetMedia(account *gtsmodel.Account, mediaAttachmentID string
return nil, gtserror.NewErrorNotFound(errors.New("attachment not owned by requesting account"))
}
- a, err := p.tc.AttachmentToMasto(attachment)
+ a, err := p.tc.AttachmentToMasto(ctx, attachment)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error converting attachment: %s", err))
}
diff --git a/internal/processing/media/media.go b/internal/processing/media/media.go
index 79c9a7e18..6b88143e2 100644
--- a/internal/processing/media/media.go
+++ b/internal/processing/media/media.go
@@ -19,6 +19,8 @@
package media
import (
+ "context"
+
"github.com/sirupsen/logrus"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/blob"
@@ -33,12 +35,12 @@ import (
// Processor wraps a bunch of functions for processing media actions.
type Processor interface {
// Create creates a new media attachment belonging to the given account, using the request form.
- Create(account *gtsmodel.Account, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error)
+ Create(ctx context.Context, account *gtsmodel.Account, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error)
// Delete deletes the media attachment with the given ID, including all files pertaining to that attachment.
- Delete(mediaAttachmentID string) gtserror.WithCode
- GetFile(account *gtsmodel.Account, form *apimodel.GetContentRequestForm) (*apimodel.Content, error)
- GetMedia(account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode)
- Update(account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode)
+ Delete(ctx context.Context, mediaAttachmentID string) gtserror.WithCode
+ GetFile(ctx context.Context, account *gtsmodel.Account, form *apimodel.GetContentRequestForm) (*apimodel.Content, error)
+ GetMedia(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode)
+ Update(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode)
}
type processor struct {
diff --git a/internal/processing/media/update.go b/internal/processing/media/update.go
index 89ed08ac1..6f15f2ace 100644
--- a/internal/processing/media/update.go
+++ b/internal/processing/media/update.go
@@ -19,6 +19,7 @@
package media
import (
+ "context"
"errors"
"fmt"
@@ -29,9 +30,9 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/text"
)
-func (p *processor) Update(account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) {
- attachment := &gtsmodel.MediaAttachment{}
- if err := p.db.GetByID(mediaAttachmentID, attachment); err != nil {
+func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) {
+ attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID)
+ if err != nil {
if err == db.ErrNoEntries {
// attachment doesn't exist
return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db"))
@@ -45,7 +46,7 @@ func (p *processor) Update(account *gtsmodel.Account, mediaAttachmentID string,
if form.Description != nil {
attachment.Description = text.RemoveHTML(*form.Description)
- if err := p.db.UpdateByID(mediaAttachmentID, attachment); err != nil {
+ if err := p.db.UpdateByID(ctx, mediaAttachmentID, attachment); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error updating description: %s", err))
}
}
@@ -57,12 +58,12 @@ func (p *processor) Update(account *gtsmodel.Account, mediaAttachmentID string,
}
attachment.FileMeta.Focus.X = focusx
attachment.FileMeta.Focus.Y = focusy
- if err := p.db.UpdateByID(mediaAttachmentID, attachment); err != nil {
+ if err := p.db.UpdateByID(ctx, mediaAttachmentID, attachment); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error updating focus: %s", err))
}
}
- a, err := p.tc.AttachmentToMasto(attachment)
+ a, err := p.tc.AttachmentToMasto(ctx, attachment)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error converting attachment: %s", err))
}
diff --git a/internal/processing/notification.go b/internal/processing/notification.go
index 7af74b04f..f91d2f2cd 100644
--- a/internal/processing/notification.go
+++ b/internal/processing/notification.go
@@ -19,22 +19,24 @@
package processing
import (
+ "context"
+
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
-func (p *processor) NotificationsGet(authed *oauth.Auth, limit int, maxID string, sinceID string) ([]*apimodel.Notification, gtserror.WithCode) {
+func (p *processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, limit int, maxID string, sinceID string) ([]*apimodel.Notification, gtserror.WithCode) {
l := p.log.WithField("func", "NotificationsGet")
- notifs, err := p.db.GetNotifications(authed.Account.ID, limit, maxID, sinceID)
+ notifs, err := p.db.GetNotifications(ctx, authed.Account.ID, limit, maxID, sinceID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
mastoNotifs := []*apimodel.Notification{}
for _, n := range notifs {
- mastoNotif, err := p.tc.NotificationToMasto(n)
+ mastoNotif, err := p.tc.NotificationToMasto(ctx, n)
if err != nil {
l.Debugf("got an error converting a notification to masto, will skip it: %s", err)
continue
diff --git a/internal/processing/processor.go b/internal/processing/processor.go
index 48ed2a35f..8df464ce0 100644
--- a/internal/processing/processor.go
+++ b/internal/processing/processor.go
@@ -51,7 +51,7 @@ import (
// for clean distribution of messages without slowing down the client API and harming the user experience.
type Processor interface {
// Start starts the Processor, reading from its channels and passing messages back and forth.
- Start() error
+ Start(ctx context.Context) error
// Stop stops the processor cleanly, finishing handling any remaining messages before closing down.
Stop() error
@@ -64,108 +64,108 @@ type Processor interface {
*/
// AccountCreate processes the given form for creating a new account, returning an oauth token for that account if successful.
- AccountCreate(authed *oauth.Auth, form *apimodel.AccountCreateRequest) (*apimodel.Token, error)
+ AccountCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AccountCreateRequest) (*apimodel.Token, error)
// AccountGet processes the given request for account information.
- AccountGet(authed *oauth.Auth, targetAccountID string) (*apimodel.Account, error)
+ AccountGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Account, error)
// AccountUpdate processes the update of an account with the given form
- AccountUpdate(authed *oauth.Auth, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error)
+ AccountUpdate(ctx context.Context, authed *oauth.Auth, form *apimodel.UpdateCredentialsRequest) (*apimodel.Account, error)
// AccountStatusesGet fetches a number of statuses (in time descending order) from the given account, filtered by visibility for
// the account given in authed.
- AccountStatusesGet(authed *oauth.Auth, targetAccountID string, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode)
+ AccountStatusesGet(ctx context.Context, authed *oauth.Auth, targetAccountID string, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode)
// AccountFollowersGet fetches a list of the target account's followers.
- AccountFollowersGet(authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode)
+ AccountFollowersGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode)
// AccountFollowingGet fetches a list of the accounts that target account is following.
- AccountFollowingGet(authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode)
+ AccountFollowingGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) ([]apimodel.Account, gtserror.WithCode)
// AccountRelationshipGet returns a relationship model describing the relationship of the targetAccount to the Authed account.
- AccountRelationshipGet(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
+ AccountRelationshipGet(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
// AccountFollowCreate handles a follow request to an account, either remote or local.
- AccountFollowCreate(authed *oauth.Auth, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode)
+ AccountFollowCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode)
// AccountFollowRemove handles the removal of a follow/follow request to an account, either remote or local.
- AccountFollowRemove(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
+ AccountFollowRemove(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
// AccountBlockCreate handles the creation of a block from authed account to target account, either remote or local.
- AccountBlockCreate(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
+ AccountBlockCreate(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
// AccountBlockRemove handles the removal of a block from authed account to target account, either remote or local.
- AccountBlockRemove(authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
+ AccountBlockRemove(ctx context.Context, authed *oauth.Auth, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode)
// AdminEmojiCreate handles the creation of a new instance emoji by an admin, using the given form.
- AdminEmojiCreate(authed *oauth.Auth, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error)
+ AdminEmojiCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.EmojiCreateRequest) (*apimodel.Emoji, error)
// AdminDomainBlockCreate handles the creation of a new domain block by an admin, using the given form.
- AdminDomainBlockCreate(authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) (*apimodel.DomainBlock, gtserror.WithCode)
+ AdminDomainBlockCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) (*apimodel.DomainBlock, gtserror.WithCode)
// AdminDomainBlocksImport handles the import of multiple domain blocks by an admin, using the given form.
- AdminDomainBlocksImport(authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) ([]*apimodel.DomainBlock, gtserror.WithCode)
+ AdminDomainBlocksImport(ctx context.Context, authed *oauth.Auth, form *apimodel.DomainBlockCreateRequest) ([]*apimodel.DomainBlock, gtserror.WithCode)
// AdminDomainBlocksGet returns a list of currently blocked domains.
- AdminDomainBlocksGet(authed *oauth.Auth, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode)
+ AdminDomainBlocksGet(ctx context.Context, authed *oauth.Auth, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode)
// AdminDomainBlockGet returns one domain block, specified by ID.
- AdminDomainBlockGet(authed *oauth.Auth, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode)
+ AdminDomainBlockGet(ctx context.Context, authed *oauth.Auth, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode)
// AdminDomainBlockDelete deletes one domain block, specified by ID, returning the deleted domain block.
- AdminDomainBlockDelete(authed *oauth.Auth, id string) (*apimodel.DomainBlock, gtserror.WithCode)
+ AdminDomainBlockDelete(ctx context.Context, authed *oauth.Auth, id string) (*apimodel.DomainBlock, gtserror.WithCode)
// AppCreate processes the creation of a new API application
- AppCreate(authed *oauth.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, error)
+ AppCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.ApplicationCreateRequest) (*apimodel.Application, error)
// BlocksGet returns a list of accounts blocked by the requesting account.
- BlocksGet(authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode)
+ BlocksGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode)
// FileGet handles the fetching of a media attachment file via the fileserver.
- FileGet(authed *oauth.Auth, form *apimodel.GetContentRequestForm) (*apimodel.Content, error)
+ FileGet(ctx context.Context, authed *oauth.Auth, form *apimodel.GetContentRequestForm) (*apimodel.Content, error)
// FollowRequestsGet handles the getting of the authed account's incoming follow requests
- FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode)
+ FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode)
// FollowRequestAccept handles the acceptance of a follow request from the given account ID
- FollowRequestAccept(auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode)
+ FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode)
// InstanceGet retrieves instance information for serving at api/v1/instance
- InstanceGet(domain string) (*apimodel.Instance, gtserror.WithCode)
+ InstanceGet(ctx context.Context, domain string) (*apimodel.Instance, gtserror.WithCode)
// InstancePatch updates this instance according to the given form.
//
// It should already be ascertained that the requesting account is authenticated and an admin.
- InstancePatch(form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode)
+ InstancePatch(ctx context.Context, form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode)
// MediaCreate handles the creation of a media attachment, using the given form.
- MediaCreate(authed *oauth.Auth, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error)
+ MediaCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error)
// MediaGet handles the GET of a media attachment with the given ID
- MediaGet(authed *oauth.Auth, attachmentID string) (*apimodel.Attachment, gtserror.WithCode)
+ MediaGet(ctx context.Context, authed *oauth.Auth, attachmentID string) (*apimodel.Attachment, gtserror.WithCode)
// MediaUpdate handles the PUT of a media attachment with the given ID and form
- MediaUpdate(authed *oauth.Auth, attachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode)
+ MediaUpdate(ctx context.Context, authed *oauth.Auth, attachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode)
// NotificationsGet
- NotificationsGet(authed *oauth.Auth, limit int, maxID string, sinceID string) ([]*apimodel.Notification, gtserror.WithCode)
+ NotificationsGet(ctx context.Context, authed *oauth.Auth, limit int, maxID string, sinceID string) ([]*apimodel.Notification, gtserror.WithCode)
// SearchGet performs a search with the given params, resolving/dereferencing remotely as desired
- SearchGet(authed *oauth.Auth, searchQuery *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode)
+ SearchGet(ctx context.Context, authed *oauth.Auth, searchQuery *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode)
// StatusCreate processes the given form to create a new status, returning the api model representation of that status if it's OK.
- StatusCreate(authed *oauth.Auth, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, error)
+ StatusCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, error)
// StatusDelete processes the delete of a given status, returning the deleted status if the delete goes through.
- StatusDelete(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error)
+ StatusDelete(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error)
// StatusFave processes the faving of a given status, returning the updated status if the fave goes through.
- StatusFave(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error)
+ StatusFave(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error)
// StatusBoost processes the boost/reblog of a given status, returning the newly-created boost if all is well.
- StatusBoost(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
+ StatusBoost(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
// StatusUnboost processes the unboost/unreblog of a given status, returning the status if all is well.
- StatusUnboost(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
+ StatusUnboost(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
// StatusBoostedBy returns a slice of accounts that have boosted the given status, filtered according to privacy settings.
- StatusBoostedBy(authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode)
+ StatusBoostedBy(ctx context.Context, authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode)
// StatusFavedBy returns a slice of accounts that have liked the given status, filtered according to privacy settings.
- StatusFavedBy(authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, error)
+ StatusFavedBy(ctx context.Context, authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, error)
// StatusGet gets the given status, taking account of privacy settings and blocks etc.
- StatusGet(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error)
+ StatusGet(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error)
// StatusUnfave processes the unfaving of a given status, returning the updated status if the fave goes through.
- StatusUnfave(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error)
+ StatusUnfave(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error)
// StatusGetContext returns the context (previous and following posts) from the given status ID
- StatusGetContext(authed *oauth.Auth, targetStatusID string) (*apimodel.Context, gtserror.WithCode)
+ StatusGetContext(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Context, gtserror.WithCode)
// HomeTimelineGet returns statuses from the home timeline, with the given filters/parameters.
- HomeTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode)
+ HomeTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode)
// PublicTimelineGet returns statuses from the public/local timeline, with the given filters/parameters.
- PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode)
+ PublicTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode)
// FavedTimelineGet returns faved statuses, with the given filters/parameters.
- FavedTimelineGet(authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.StatusTimelineResponse, gtserror.WithCode)
+ FavedTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.StatusTimelineResponse, gtserror.WithCode)
// AuthorizeStreamingRequest returns a gotosocial account in exchange for an access token, or an error if the given token is not valid.
- AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error)
+ AuthorizeStreamingRequest(ctx context.Context, accessToken string) (*gtsmodel.Account, error)
// OpenStreamForAccount opens a new stream for the given account, with the given stream type.
- OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode)
+ OpenStreamForAccount(ctx context.Context, account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode)
/*
FEDERATION API-FACING PROCESSING FUNCTIONS
@@ -199,10 +199,10 @@ type Processor interface {
GetWebfingerAccount(ctx context.Context, requestedUsername string, requestURL *url.URL) (*apimodel.WellKnownResponse, gtserror.WithCode)
// GetNodeInfoRel returns a well known response giving the path to node info.
- GetNodeInfoRel(request *http.Request) (*apimodel.WellKnownResponse, gtserror.WithCode)
+ GetNodeInfoRel(ctx context.Context, request *http.Request) (*apimodel.WellKnownResponse, gtserror.WithCode)
// GetNodeInfo returns a node info struct in response to a node info request.
- GetNodeInfo(request *http.Request) (*apimodel.Nodeinfo, gtserror.WithCode)
+ GetNodeInfo(ctx context.Context, request *http.Request) (*apimodel.Nodeinfo, gtserror.WithCode)
// InboxPost handles POST requests to a user's inbox for new activitypub messages.
//
@@ -280,7 +280,7 @@ func NewProcessor(config *config.Config, tc typeutils.TypeConverter, federator f
}
// Start starts the Processor, reading from its channels and passing messages back and forth.
-func (p *processor) Start() error {
+func (p *processor) Start(ctx context.Context) error {
go func() {
DistLoop:
for {
@@ -288,14 +288,14 @@ func (p *processor) Start() error {
case clientMsg := <-p.fromClientAPI:
p.log.Tracef("received message FROM client API: %+v", clientMsg)
go func() {
- if err := p.processFromClientAPI(clientMsg); err != nil {
+ if err := p.processFromClientAPI(ctx, clientMsg); err != nil {
p.log.Error(err)
}
}()
case federatorMsg := <-p.fromFederator:
p.log.Tracef("received message FROM federator: %+v", federatorMsg)
go func() {
- if err := p.processFromFederator(federatorMsg); err != nil {
+ if err := p.processFromFederator(ctx, federatorMsg); err != nil {
p.log.Error(err)
}
}()
diff --git a/internal/processing/search.go b/internal/processing/search.go
index f2ae721ae..768fceacd 100644
--- a/internal/processing/search.go
+++ b/internal/processing/search.go
@@ -19,6 +19,7 @@
package processing
import (
+ "context"
"fmt"
"net/url"
"strings"
@@ -32,7 +33,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/util"
)
-func (p *processor) SearchGet(authed *oauth.Auth, searchQuery *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode) {
+func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, searchQuery *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode) {
l := p.log.WithFields(logrus.Fields{
"func": "SearchGet",
"query": searchQuery.Query,
@@ -54,7 +55,7 @@ func (p *processor) SearchGet(authed *oauth.Auth, searchQuery *apimodel.SearchQu
// check if the query is something like @whatever_username@example.org -- this means it's a remote account
if !foundOne && util.IsMention(searchQuery.Query) {
l.Debug("search term is a mention, looking it up...")
- foundAccount, err := p.searchAccountByMention(authed, searchQuery.Query, searchQuery.Resolve)
+ foundAccount, err := p.searchAccountByMention(ctx, authed, searchQuery.Query, searchQuery.Resolve)
if err == nil && foundAccount != nil {
foundAccounts = append(foundAccounts, foundAccount)
foundOne = true
@@ -65,14 +66,14 @@ func (p *processor) SearchGet(authed *oauth.Auth, searchQuery *apimodel.SearchQu
// check if the query is a URI and just do a lookup for that, straight up
if uri, err := url.Parse(query); err == nil && !foundOne {
// 1. check if it's a status
- if foundStatus, err := p.searchStatusByURI(authed, uri, searchQuery.Resolve); err == nil && foundStatus != nil {
+ if foundStatus, err := p.searchStatusByURI(ctx, authed, uri, searchQuery.Resolve); err == nil && foundStatus != nil {
foundStatuses = append(foundStatuses, foundStatus)
foundOne = true
l.Debug("got a status by searching by URI")
}
// 2. check if it's an account
- if foundAccount, err := p.searchAccountByURI(authed, uri, searchQuery.Resolve); err == nil && foundAccount != nil {
+ if foundAccount, err := p.searchAccountByURI(ctx, authed, uri, searchQuery.Resolve); err == nil && foundAccount != nil {
foundAccounts = append(foundAccounts, foundAccount)
foundOne = true
l.Debug("got an account by searching by URI")
@@ -90,20 +91,20 @@ func (p *processor) SearchGet(authed *oauth.Auth, searchQuery *apimodel.SearchQu
*/
for _, foundAccount := range foundAccounts {
// make sure there's no block in either direction between the account and the requester
- if blocked, err := p.db.IsBlocked(authed.Account.ID, foundAccount.ID, true); err == nil && !blocked {
+ if blocked, err := p.db.IsBlocked(ctx, authed.Account.ID, foundAccount.ID, true); err == nil && !blocked {
// all good, convert it and add it to the results
- if acctMasto, err := p.tc.AccountToMastoPublic(foundAccount); err == nil && acctMasto != nil {
+ if acctMasto, err := p.tc.AccountToMastoPublic(ctx, foundAccount); err == nil && acctMasto != nil {
results.Accounts = append(results.Accounts, *acctMasto)
}
}
}
for _, foundStatus := range foundStatuses {
- if visible, err := p.filter.StatusVisible(foundStatus, authed.Account); !visible || err != nil {
+ if visible, err := p.filter.StatusVisible(ctx, foundStatus, authed.Account); !visible || err != nil {
continue
}
- statusMasto, err := p.tc.StatusToMasto(foundStatus, authed.Account)
+ statusMasto, err := p.tc.StatusToMasto(ctx, foundStatus, authed.Account)
if err != nil {
continue
}
@@ -114,24 +115,24 @@ func (p *processor) SearchGet(authed *oauth.Auth, searchQuery *apimodel.SearchQu
return results, nil
}
-func (p *processor) searchStatusByURI(authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Status, error) {
+func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Status, error) {
l := p.log.WithFields(logrus.Fields{
"func": "searchStatusByURI",
"uri": uri.String(),
"resolve": resolve,
})
- if maybeStatus, err := p.db.GetStatusByURI(uri.String()); err == nil {
+ if maybeStatus, err := p.db.GetStatusByURI(ctx, uri.String()); err == nil {
return maybeStatus, nil
- } else if maybeStatus, err := p.db.GetStatusByURL(uri.String()); err == nil {
+ } else if maybeStatus, err := p.db.GetStatusByURL(ctx, uri.String()); err == nil {
return maybeStatus, nil
}
// we don't have it locally so dereference it if we're allowed to
if resolve {
- status, _, _, err := p.federator.GetRemoteStatus(authed.Account.Username, uri, true)
+ status, _, _, err := p.federator.GetRemoteStatus(ctx, authed.Account.Username, uri, true)
if err == nil {
- if err := p.federator.DereferenceRemoteThread(authed.Account.Username, uri); err != nil {
+ if err := p.federator.DereferenceRemoteThread(ctx, authed.Account.Username, uri); err != nil {
// try to deref the thread while we're here
l.Debugf("searchStatusByURI: error dereferencing remote thread: %s", err)
}
@@ -141,16 +142,16 @@ func (p *processor) searchStatusByURI(authed *oauth.Auth, uri *url.URL, resolve
return nil, nil
}
-func (p *processor) searchAccountByURI(authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Account, error) {
- if maybeAccount, err := p.db.GetAccountByURI(uri.String()); err == nil {
+func (p *processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Account, error) {
+ if maybeAccount, err := p.db.GetAccountByURI(ctx, uri.String()); err == nil {
return maybeAccount, nil
- } else if maybeAccount, err := p.db.GetAccountByURL(uri.String()); err == nil {
+ } else if maybeAccount, err := p.db.GetAccountByURL(ctx, uri.String()); err == nil {
return maybeAccount, nil
}
if resolve {
// we don't have it locally so try and dereference it
- account, _, err := p.federator.GetRemoteAccount(authed.Account.Username, uri, true)
+ account, _, err := p.federator.GetRemoteAccount(ctx, authed.Account.Username, uri, true)
if err != nil {
return nil, fmt.Errorf("searchAccountByURI: error dereferencing account with uri %s: %s", uri.String(), err)
}
@@ -159,7 +160,7 @@ func (p *processor) searchAccountByURI(authed *oauth.Auth, uri *url.URL, resolve
return nil, nil
}
-func (p *processor) searchAccountByMention(authed *oauth.Auth, mention string, resolve bool) (*gtsmodel.Account, error) {
+func (p *processor) searchAccountByMention(ctx context.Context, authed *oauth.Auth, mention string, resolve bool) (*gtsmodel.Account, error) {
// query is for a remote account
username, domain, err := util.ExtractMentionParts(mention)
if err != nil {
@@ -169,7 +170,7 @@ func (p *processor) searchAccountByMention(authed *oauth.Auth, mention string, r
// if it's a local account we can skip a whole bunch of stuff
maybeAcct := &gtsmodel.Account{}
if domain == p.config.Host {
- maybeAcct, err = p.db.GetLocalAccountByUsername(username)
+ maybeAcct, err = p.db.GetLocalAccountByUsername(ctx, username)
if err != nil {
return nil, fmt.Errorf("searchAccountByMention: error getting local account by username: %s", err)
}
@@ -181,7 +182,7 @@ func (p *processor) searchAccountByMention(authed *oauth.Auth, mention string, r
{Key: "username", Value: username, CaseInsensitive: true},
{Key: "domain", Value: domain, CaseInsensitive: true},
}
- err = p.db.GetWhere(where, maybeAcct)
+ err = p.db.GetWhere(ctx, where, maybeAcct)
if err == nil {
// we've got it stored locally already!
return maybeAcct, nil
@@ -197,14 +198,14 @@ func (p *processor) searchAccountByMention(authed *oauth.Auth, mention string, r
// we're allowed to resolve it so let's try
// first we need to webfinger the remote account to convert the username and domain into the activitypub URI for the account
- acctURI, err := p.federator.FingerRemoteAccount(authed.Account.Username, username, domain)
+ acctURI, err := p.federator.FingerRemoteAccount(ctx, authed.Account.Username, username, domain)
if err != nil {
// something went wrong doing the webfinger lookup so we can't process the request
return nil, fmt.Errorf("searchAccountByMention: error fingering remote account with username %s and domain %s: %s", username, domain, err)
}
// we don't have it locally so try and dereference it
- account, _, err := p.federator.GetRemoteAccount(authed.Account.Username, acctURI, true)
+ account, _, err := p.federator.GetRemoteAccount(ctx, authed.Account.Username, acctURI, true)
if err != nil {
return nil, fmt.Errorf("searchAccountByMention: error dereferencing account with uri %s: %s", acctURI.String(), err)
}
diff --git a/internal/processing/status.go b/internal/processing/status.go
index ab3843ded..c31c20628 100644
--- a/internal/processing/status.go
+++ b/internal/processing/status.go
@@ -19,47 +19,49 @@
package processing
import (
+ "context"
+
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
-func (p *processor) StatusCreate(authed *oauth.Auth, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, error) {
- return p.statusProcessor.Create(authed.Account, authed.Application, form)
+func (p *processor) StatusCreate(ctx context.Context, authed *oauth.Auth, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, error) {
+ return p.statusProcessor.Create(ctx, authed.Account, authed.Application, form)
}
-func (p *processor) StatusDelete(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) {
- return p.statusProcessor.Delete(authed.Account, targetStatusID)
+func (p *processor) StatusDelete(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) {
+ return p.statusProcessor.Delete(ctx, authed.Account, targetStatusID)
}
-func (p *processor) StatusFave(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) {
- return p.statusProcessor.Fave(authed.Account, targetStatusID)
+func (p *processor) StatusFave(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) {
+ return p.statusProcessor.Fave(ctx, authed.Account, targetStatusID)
}
-func (p *processor) StatusBoost(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
- return p.statusProcessor.Boost(authed.Account, authed.Application, targetStatusID)
+func (p *processor) StatusBoost(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
+ return p.statusProcessor.Boost(ctx, authed.Account, authed.Application, targetStatusID)
}
-func (p *processor) StatusUnboost(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
- return p.statusProcessor.Unboost(authed.Account, authed.Application, targetStatusID)
+func (p *processor) StatusUnboost(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
+ return p.statusProcessor.Unboost(ctx, authed.Account, authed.Application, targetStatusID)
}
-func (p *processor) StatusBoostedBy(authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) {
- return p.statusProcessor.BoostedBy(authed.Account, targetStatusID)
+func (p *processor) StatusBoostedBy(ctx context.Context, authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) {
+ return p.statusProcessor.BoostedBy(ctx, authed.Account, targetStatusID)
}
-func (p *processor) StatusFavedBy(authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, error) {
- return p.statusProcessor.FavedBy(authed.Account, targetStatusID)
+func (p *processor) StatusFavedBy(ctx context.Context, authed *oauth.Auth, targetStatusID string) ([]*apimodel.Account, error) {
+ return p.statusProcessor.FavedBy(ctx, authed.Account, targetStatusID)
}
-func (p *processor) StatusGet(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) {
- return p.statusProcessor.Get(authed.Account, targetStatusID)
+func (p *processor) StatusGet(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) {
+ return p.statusProcessor.Get(ctx, authed.Account, targetStatusID)
}
-func (p *processor) StatusUnfave(authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) {
- return p.statusProcessor.Unfave(authed.Account, targetStatusID)
+func (p *processor) StatusUnfave(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Status, error) {
+ return p.statusProcessor.Unfave(ctx, authed.Account, targetStatusID)
}
-func (p *processor) StatusGetContext(authed *oauth.Auth, targetStatusID string) (*apimodel.Context, gtserror.WithCode) {
- return p.statusProcessor.Context(authed.Account, targetStatusID)
+func (p *processor) StatusGetContext(ctx context.Context, authed *oauth.Auth, targetStatusID string) (*apimodel.Context, gtserror.WithCode) {
+ return p.statusProcessor.Context(ctx, authed.Account, targetStatusID)
}
diff --git a/internal/processing/status/boost.go b/internal/processing/status/boost.go
index d7a62beb1..948d57a48 100644
--- a/internal/processing/status/boost.go
+++ b/internal/processing/status/boost.go
@@ -1,6 +1,25 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package status
import (
+ "context"
"errors"
"fmt"
@@ -9,8 +28,8 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) Boost(requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(targetStatusID)
+func (p *processor) Boost(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
+ targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -18,7 +37,7 @@ func (p *processor) Boost(requestingAccount *gtsmodel.Account, application *gtsm
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID))
}
- visible, err := p.filter.StatusVisible(targetStatus, requestingAccount)
+ visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err))
}
@@ -32,7 +51,7 @@ func (p *processor) Boost(requestingAccount *gtsmodel.Account, application *gtsm
}
// it's visible! it's boostable! so let's boost the FUCK out of it
- boostWrapperStatus, err := p.tc.StatusToBoost(targetStatus, requestingAccount)
+ boostWrapperStatus, err := p.tc.StatusToBoost(ctx, targetStatus, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -41,7 +60,7 @@ func (p *processor) Boost(requestingAccount *gtsmodel.Account, application *gtsm
boostWrapperStatus.BoostOfAccount = targetStatus.Account
// put the boost in the database
- if err := p.db.PutStatus(boostWrapperStatus); err != nil {
+ if err := p.db.PutStatus(ctx, boostWrapperStatus); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -55,7 +74,7 @@ func (p *processor) Boost(requestingAccount *gtsmodel.Account, application *gtsm
}
// return the frontend representation of the new status to the submitter
- mastoStatus, err := p.tc.StatusToMasto(boostWrapperStatus, requestingAccount)
+ mastoStatus, err := p.tc.StatusToMasto(ctx, boostWrapperStatus, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status %s to frontend representation: %s", targetStatus.ID, err))
}
diff --git a/internal/processing/status/boostedby.go b/internal/processing/status/boostedby.go
index 1bde6b5ae..46f41039f 100644
--- a/internal/processing/status/boostedby.go
+++ b/internal/processing/status/boostedby.go
@@ -1,6 +1,25 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package status
import (
+ "context"
"errors"
"fmt"
@@ -9,8 +28,8 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) BoostedBy(requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(targetStatusID)
+func (p *processor) BoostedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) {
+ targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -18,7 +37,7 @@ func (p *processor) BoostedBy(requestingAccount *gtsmodel.Account, targetStatusI
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID))
}
- visible, err := p.filter.StatusVisible(targetStatus, requestingAccount)
+ visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err))
}
@@ -26,7 +45,7 @@ func (p *processor) BoostedBy(requestingAccount *gtsmodel.Account, targetStatusI
return nil, gtserror.NewErrorNotFound(errors.New("status is not visible"))
}
- statusReblogs, err := p.db.GetStatusReblogs(targetStatus)
+ statusReblogs, err := p.db.GetStatusReblogs(ctx, targetStatus)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("StatusBoostedBy: error seeing who boosted status: %s", err))
}
@@ -34,7 +53,7 @@ func (p *processor) BoostedBy(requestingAccount *gtsmodel.Account, targetStatusI
// filter the list so the user doesn't see accounts they blocked or which blocked them
filteredAccounts := []*gtsmodel.Account{}
for _, s := range statusReblogs {
- blocked, err := p.db.IsBlocked(requestingAccount.ID, s.AccountID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, s.AccountID, true)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("StatusBoostedBy: error checking blocks: %s", err))
}
@@ -48,7 +67,7 @@ func (p *processor) BoostedBy(requestingAccount *gtsmodel.Account, targetStatusI
// now we can return the masto representation of those accounts
mastoAccounts := []*apimodel.Account{}
for _, acc := range filteredAccounts {
- mastoAccount, err := p.tc.AccountToMastoPublic(acc)
+ mastoAccount, err := p.tc.AccountToMastoPublic(ctx, acc)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("StatusFavedBy: error converting account to api model: %s", err))
}
diff --git a/internal/processing/status/context.go b/internal/processing/status/context.go
index 43002545e..3e8e93d09 100644
--- a/internal/processing/status/context.go
+++ b/internal/processing/status/context.go
@@ -1,6 +1,25 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package status
import (
+ "context"
"errors"
"fmt"
"sort"
@@ -10,8 +29,8 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) Context(requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Context, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(targetStatusID)
+func (p *processor) Context(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Context, gtserror.WithCode) {
+ targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -19,7 +38,7 @@ func (p *processor) Context(requestingAccount *gtsmodel.Account, targetStatusID
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID))
}
- visible, err := p.filter.StatusVisible(targetStatus, requestingAccount)
+ visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err))
}
@@ -32,14 +51,14 @@ func (p *processor) Context(requestingAccount *gtsmodel.Account, targetStatusID
Descendants: []apimodel.Status{},
}
- parents, err := p.db.GetStatusParents(targetStatus, false)
+ parents, err := p.db.GetStatusParents(ctx, targetStatus, false)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
for _, status := range parents {
- if v, err := p.filter.StatusVisible(status, requestingAccount); err == nil && v {
- mastoStatus, err := p.tc.StatusToMasto(status, requestingAccount)
+ if v, err := p.filter.StatusVisible(ctx, status, requestingAccount); err == nil && v {
+ mastoStatus, err := p.tc.StatusToMasto(ctx, status, requestingAccount)
if err == nil {
context.Ancestors = append(context.Ancestors, *mastoStatus)
}
@@ -50,14 +69,14 @@ func (p *processor) Context(requestingAccount *gtsmodel.Account, targetStatusID
return context.Ancestors[i].ID < context.Ancestors[j].ID
})
- children, err := p.db.GetStatusChildren(targetStatus, false, "")
+ children, err := p.db.GetStatusChildren(ctx, targetStatus, false, "")
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
for _, status := range children {
- if v, err := p.filter.StatusVisible(status, requestingAccount); err == nil && v {
- mastoStatus, err := p.tc.StatusToMasto(status, requestingAccount)
+ if v, err := p.filter.StatusVisible(ctx, status, requestingAccount); err == nil && v {
+ mastoStatus, err := p.tc.StatusToMasto(ctx, status, requestingAccount)
if err == nil {
context.Descendants = append(context.Descendants, *mastoStatus)
}
diff --git a/internal/processing/status/create.go b/internal/processing/status/create.go
index fc112ed8b..2e0b30ad8 100644
--- a/internal/processing/status/create.go
+++ b/internal/processing/status/create.go
@@ -1,6 +1,25 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package status
import (
+ "context"
"fmt"
"time"
@@ -12,7 +31,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/util"
)
-func (p *processor) Create(account *gtsmodel.Account, application *gtsmodel.Application, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, gtserror.WithCode) {
+func (p *processor) Create(ctx context.Context, account *gtsmodel.Account, application *gtsmodel.Application, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, gtserror.WithCode) {
uris := util.GenerateURIsForAccount(account.Username, p.config.Protocol, p.config.Host)
thisStatusID, err := id.NewULID()
if err != nil {
@@ -38,40 +57,40 @@ func (p *processor) Create(account *gtsmodel.Account, application *gtsmodel.Appl
Text: form.Status,
}
- if err := p.ProcessReplyToID(form, account.ID, newStatus); err != nil {
+ if err := p.ProcessReplyToID(ctx, form, account.ID, newStatus); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
- if err := p.ProcessMediaIDs(form, account.ID, newStatus); err != nil {
+ if err := p.ProcessMediaIDs(ctx, form, account.ID, newStatus); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
- if err := p.ProcessVisibility(form, account.Privacy, newStatus); err != nil {
+ if err := p.ProcessVisibility(ctx, form, account.Privacy, newStatus); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
- if err := p.ProcessLanguage(form, account.Language, newStatus); err != nil {
+ if err := p.ProcessLanguage(ctx, form, account.Language, newStatus); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
- if err := p.ProcessMentions(form, account.ID, newStatus); err != nil {
+ if err := p.ProcessMentions(ctx, form, account.ID, newStatus); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
- if err := p.ProcessTags(form, account.ID, newStatus); err != nil {
+ if err := p.ProcessTags(ctx, form, account.ID, newStatus); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
- if err := p.ProcessEmojis(form, account.ID, newStatus); err != nil {
+ if err := p.ProcessEmojis(ctx, form, account.ID, newStatus); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
- if err := p.ProcessContent(form, account.ID, newStatus); err != nil {
+ if err := p.ProcessContent(ctx, form, account.ID, newStatus); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
// put the new status in the database
- if err := p.db.PutStatus(newStatus); err != nil {
+ if err := p.db.PutStatus(ctx, newStatus); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -84,7 +103,7 @@ func (p *processor) Create(account *gtsmodel.Account, application *gtsmodel.Appl
}
// return the frontend representation of the new status to the submitter
- mastoStatus, err := p.tc.StatusToMasto(newStatus, account)
+ mastoStatus, err := p.tc.StatusToMasto(ctx, newStatus, account)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status %s to frontend representation: %s", newStatus.ID, err))
}
diff --git a/internal/processing/status/delete.go b/internal/processing/status/delete.go
index 4c5dfd744..daa7a934f 100644
--- a/internal/processing/status/delete.go
+++ b/internal/processing/status/delete.go
@@ -1,6 +1,25 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package status
import (
+ "context"
"errors"
"fmt"
@@ -9,8 +28,8 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) Delete(requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(targetStatusID)
+func (p *processor) Delete(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
+ targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -22,12 +41,12 @@ func (p *processor) Delete(requestingAccount *gtsmodel.Account, targetStatusID s
return nil, gtserror.NewErrorForbidden(errors.New("status doesn't belong to requesting account"))
}
- mastoStatus, err := p.tc.StatusToMasto(targetStatus, requestingAccount)
+ mastoStatus, err := p.tc.StatusToMasto(ctx, targetStatus, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status %s to frontend representation: %s", targetStatus.ID, err))
}
- if err := p.db.DeleteByID(targetStatus.ID, &gtsmodel.Status{}); err != nil {
+ if err := p.db.DeleteByID(ctx, targetStatus.ID, &gtsmodel.Status{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error deleting status from the database: %s", err))
}
diff --git a/internal/processing/status/fave.go b/internal/processing/status/fave.go
index 7ba8c8fe8..2badf83b3 100644
--- a/internal/processing/status/fave.go
+++ b/internal/processing/status/fave.go
@@ -1,6 +1,25 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package status
import (
+ "context"
"errors"
"fmt"
@@ -12,8 +31,8 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/util"
)
-func (p *processor) Fave(requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(targetStatusID)
+func (p *processor) Fave(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
+ targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -21,7 +40,7 @@ func (p *processor) Fave(requestingAccount *gtsmodel.Account, targetStatusID str
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID))
}
- visible, err := p.filter.StatusVisible(targetStatus, requestingAccount)
+ visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err))
}
@@ -37,7 +56,7 @@ func (p *processor) Fave(requestingAccount *gtsmodel.Account, targetStatusID str
// first check if the status is already faved, if so we don't need to do anything
newFave := true
gtsFave := &gtsmodel.StatusFave{}
- if err := p.db.GetWhere([]db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err == nil {
+ if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err == nil {
// we already have a fave for this status
newFave = false
}
@@ -60,7 +79,7 @@ func (p *processor) Fave(requestingAccount *gtsmodel.Account, targetStatusID str
URI: util.GenerateURIForLike(requestingAccount.Username, p.config.Protocol, p.config.Host, thisFaveID),
}
- if err := p.db.Put(gtsFave); err != nil {
+ if err := p.db.Put(ctx, gtsFave); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error putting fave in database: %s", err))
}
@@ -75,7 +94,7 @@ func (p *processor) Fave(requestingAccount *gtsmodel.Account, targetStatusID str
}
// return the mastodon representation of the target status
- mastoStatus, err := p.tc.StatusToMasto(targetStatus, requestingAccount)
+ mastoStatus, err := p.tc.StatusToMasto(ctx, targetStatus, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status %s to frontend representation: %s", targetStatus.ID, err))
}
diff --git a/internal/processing/status/favedby.go b/internal/processing/status/favedby.go
index dffe6bba9..227fb669d 100644
--- a/internal/processing/status/favedby.go
+++ b/internal/processing/status/favedby.go
@@ -1,6 +1,25 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package status
import (
+ "context"
"errors"
"fmt"
@@ -9,8 +28,8 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) FavedBy(requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(targetStatusID)
+func (p *processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) {
+ targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -18,7 +37,7 @@ func (p *processor) FavedBy(requestingAccount *gtsmodel.Account, targetStatusID
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID))
}
- visible, err := p.filter.StatusVisible(targetStatus, requestingAccount)
+ visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err))
}
@@ -26,7 +45,7 @@ func (p *processor) FavedBy(requestingAccount *gtsmodel.Account, targetStatusID
return nil, gtserror.NewErrorNotFound(errors.New("status is not visible"))
}
- statusFaves, err := p.db.GetStatusFaves(targetStatus)
+ statusFaves, err := p.db.GetStatusFaves(ctx, targetStatus)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing who faved status: %s", err))
}
@@ -34,7 +53,7 @@ func (p *processor) FavedBy(requestingAccount *gtsmodel.Account, targetStatusID
// filter the list so the user doesn't see accounts they blocked or which blocked them
filteredAccounts := []*gtsmodel.Account{}
for _, fave := range statusFaves {
- blocked, err := p.db.IsBlocked(requestingAccount.ID, fave.AccountID, true)
+ blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, fave.AccountID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking blocks: %s", err))
}
@@ -46,7 +65,7 @@ func (p *processor) FavedBy(requestingAccount *gtsmodel.Account, targetStatusID
// now we can return the masto representation of those accounts
mastoAccounts := []*apimodel.Account{}
for _, acc := range filteredAccounts {
- mastoAccount, err := p.tc.AccountToMastoPublic(acc)
+ mastoAccount, err := p.tc.AccountToMastoPublic(ctx, acc)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status %s to frontend representation: %s", targetStatus.ID, err))
}
diff --git a/internal/processing/status/get.go b/internal/processing/status/get.go
index 9d403b901..258210faf 100644
--- a/internal/processing/status/get.go
+++ b/internal/processing/status/get.go
@@ -1,6 +1,25 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package status
import (
+ "context"
"errors"
"fmt"
@@ -9,8 +28,8 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) Get(requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(targetStatusID)
+func (p *processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
+ targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -18,7 +37,7 @@ func (p *processor) Get(requestingAccount *gtsmodel.Account, targetStatusID stri
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID))
}
- visible, err := p.filter.StatusVisible(targetStatus, requestingAccount)
+ visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err))
}
@@ -26,7 +45,7 @@ func (p *processor) Get(requestingAccount *gtsmodel.Account, targetStatusID stri
return nil, gtserror.NewErrorNotFound(errors.New("status is not visible"))
}
- mastoStatus, err := p.tc.StatusToMasto(targetStatus, requestingAccount)
+ mastoStatus, err := p.tc.StatusToMasto(ctx, targetStatus, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status %s to frontend representation: %s", targetStatus.ID, err))
}
diff --git a/internal/processing/status/status.go b/internal/processing/status/status.go
index 038ca005e..37790d062 100644
--- a/internal/processing/status/status.go
+++ b/internal/processing/status/status.go
@@ -1,6 +1,26 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package status
import (
+ "context"
+
"github.com/sirupsen/logrus"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/config"
@@ -15,38 +35,38 @@ import (
// Processor wraps a bunch of functions for processing statuses.
type Processor interface {
// Create processes the given form to create a new status, returning the api model representation of that status if it's OK.
- Create(account *gtsmodel.Account, application *gtsmodel.Application, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, gtserror.WithCode)
+ Create(ctx context.Context, account *gtsmodel.Account, application *gtsmodel.Application, form *apimodel.AdvancedStatusCreateForm) (*apimodel.Status, gtserror.WithCode)
// Delete processes the delete of a given status, returning the deleted status if the delete goes through.
- Delete(account *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
+ Delete(ctx context.Context, account *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
// Fave processes the faving of a given status, returning the updated status if the fave goes through.
- Fave(account *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
+ Fave(ctx context.Context, account *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
// Boost processes the boost/reblog of a given status, returning the newly-created boost if all is well.
- Boost(account *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
+ Boost(ctx context.Context, account *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
// Unboost processes the unboost/unreblog of a given status, returning the status if all is well.
- Unboost(account *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
+ Unboost(ctx context.Context, account *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
// BoostedBy returns a slice of accounts that have boosted the given status, filtered according to privacy settings.
- BoostedBy(account *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode)
+ BoostedBy(ctx context.Context, account *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode)
// FavedBy returns a slice of accounts that have liked the given status, filtered according to privacy settings.
- FavedBy(account *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode)
+ FavedBy(ctx context.Context, account *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode)
// Get gets the given status, taking account of privacy settings and blocks etc.
- Get(account *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
+ Get(ctx context.Context, account *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
// Unfave processes the unfaving of a given status, returning the updated status if the fave goes through.
- Unfave(account *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
+ Unfave(ctx context.Context, account *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode)
// Context returns the context (previous and following posts) from the given status ID
- Context(account *gtsmodel.Account, targetStatusID string) (*apimodel.Context, gtserror.WithCode)
+ Context(ctx context.Context, account *gtsmodel.Account, targetStatusID string) (*apimodel.Context, gtserror.WithCode)
/*
PROCESSING UTILS
*/
- ProcessVisibility(form *apimodel.AdvancedStatusCreateForm, accountDefaultVis gtsmodel.Visibility, status *gtsmodel.Status) error
- ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, thisAccountID string, status *gtsmodel.Status) error
- ProcessMediaIDs(form *apimodel.AdvancedStatusCreateForm, thisAccountID string, status *gtsmodel.Status) error
- ProcessLanguage(form *apimodel.AdvancedStatusCreateForm, accountDefaultLanguage string, status *gtsmodel.Status) error
- ProcessMentions(form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error
- ProcessTags(form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error
- ProcessEmojis(form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error
- ProcessContent(form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error
+ ProcessVisibility(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountDefaultVis gtsmodel.Visibility, status *gtsmodel.Status) error
+ ProcessReplyToID(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, thisAccountID string, status *gtsmodel.Status) error
+ ProcessMediaIDs(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, thisAccountID string, status *gtsmodel.Status) error
+ ProcessLanguage(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountDefaultLanguage string, status *gtsmodel.Status) error
+ ProcessMentions(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error
+ ProcessTags(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error
+ ProcessEmojis(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error
+ ProcessContent(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error
}
type processor struct {
diff --git a/internal/processing/status/unboost.go b/internal/processing/status/unboost.go
index 254cfe11f..c3c667a71 100644
--- a/internal/processing/status/unboost.go
+++ b/internal/processing/status/unboost.go
@@ -1,6 +1,25 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package status
import (
+ "context"
"errors"
"fmt"
@@ -10,8 +29,8 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) Unboost(requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(targetStatusID)
+func (p *processor) Unboost(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
+ targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -19,7 +38,7 @@ func (p *processor) Unboost(requestingAccount *gtsmodel.Account, application *gt
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID))
}
- visible, err := p.filter.StatusVisible(targetStatus, requestingAccount)
+ visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err))
}
@@ -41,7 +60,7 @@ func (p *processor) Unboost(requestingAccount *gtsmodel.Account, application *gt
Value: requestingAccount.ID,
},
}
- err = p.db.GetWhere(where, gtsBoost)
+ err = p.db.GetWhere(ctx, where, gtsBoost)
if err == nil {
// we have a boost
toUnboost = true
@@ -58,7 +77,7 @@ func (p *processor) Unboost(requestingAccount *gtsmodel.Account, application *gt
if toUnboost {
// we had a boost, so take some action to get rid of it
- if err := p.db.DeleteWhere(where, &gtsmodel.Status{}); err != nil {
+ if err := p.db.DeleteWhere(ctx, where, &gtsmodel.Status{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error unboosting status: %s", err))
}
@@ -79,7 +98,7 @@ func (p *processor) Unboost(requestingAccount *gtsmodel.Account, application *gt
}
}
- mastoStatus, err := p.tc.StatusToMasto(targetStatus, requestingAccount)
+ mastoStatus, err := p.tc.StatusToMasto(ctx, targetStatus, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status %s to frontend representation: %s", targetStatus.ID, err))
}
diff --git a/internal/processing/status/unfave.go b/internal/processing/status/unfave.go
index d6e5320db..3d079e2ff 100644
--- a/internal/processing/status/unfave.go
+++ b/internal/processing/status/unfave.go
@@ -1,6 +1,25 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package status
import (
+ "context"
"errors"
"fmt"
@@ -10,8 +29,8 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) Unfave(requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
- targetStatus, err := p.db.GetStatusByID(targetStatusID)
+func (p *processor) Unfave(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
+ targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
@@ -19,7 +38,7 @@ func (p *processor) Unfave(requestingAccount *gtsmodel.Account, targetStatusID s
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID))
}
- visible, err := p.filter.StatusVisible(targetStatus, requestingAccount)
+ visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err))
}
@@ -31,7 +50,7 @@ func (p *processor) Unfave(requestingAccount *gtsmodel.Account, targetStatusID s
var toUnfave bool
gtsFave := &gtsmodel.StatusFave{}
- err = p.db.GetWhere([]db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave)
+ err = p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave)
if err == nil {
// we have a fave
toUnfave = true
@@ -47,7 +66,7 @@ func (p *processor) Unfave(requestingAccount *gtsmodel.Account, targetStatusID s
if toUnfave {
// we had a fave, so take some action to get rid of it
- if err := p.db.DeleteWhere([]db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err != nil {
+ if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error unfaveing status: %s", err))
}
@@ -61,7 +80,7 @@ func (p *processor) Unfave(requestingAccount *gtsmodel.Account, targetStatusID s
}
}
- mastoStatus, err := p.tc.StatusToMasto(targetStatus, requestingAccount)
+ mastoStatus, err := p.tc.StatusToMasto(ctx, targetStatus, requestingAccount)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status %s to frontend representation: %s", targetStatus.ID, err))
}
diff --git a/internal/processing/status/util.go b/internal/processing/status/util.go
index 025607f4a..26ee5d4f7 100644
--- a/internal/processing/status/util.go
+++ b/internal/processing/status/util.go
@@ -1,6 +1,25 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package status
import (
+ "context"
"errors"
"fmt"
@@ -12,7 +31,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/util"
)
-func (p *processor) ProcessVisibility(form *apimodel.AdvancedStatusCreateForm, accountDefaultVis gtsmodel.Visibility, status *gtsmodel.Status) error {
+func (p *processor) ProcessVisibility(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountDefaultVis gtsmodel.Visibility, status *gtsmodel.Status) error {
// by default all flags are set to true
gtsAdvancedVis := &gtsmodel.VisibilityAdvanced{
Federated: true,
@@ -83,7 +102,7 @@ func (p *processor) ProcessVisibility(form *apimodel.AdvancedStatusCreateForm, a
return nil
}
-func (p *processor) ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, thisAccountID string, status *gtsmodel.Status) error {
+func (p *processor) ProcessReplyToID(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, thisAccountID string, status *gtsmodel.Status) error {
if form.InReplyToID == "" {
return nil
}
@@ -98,7 +117,7 @@ func (p *processor) ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, th
repliedStatus := &gtsmodel.Status{}
repliedAccount := &gtsmodel.Account{}
// check replied status exists + is replyable
- if err := p.db.GetByID(form.InReplyToID, repliedStatus); err != nil {
+ if err := p.db.GetByID(ctx, form.InReplyToID, repliedStatus); err != nil {
if err == db.ErrNoEntries {
return fmt.Errorf("status with id %s not replyable because it doesn't exist", form.InReplyToID)
}
@@ -112,14 +131,14 @@ func (p *processor) ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, th
}
// check replied account is known to us
- if err := p.db.GetByID(repliedStatus.AccountID, repliedAccount); err != nil {
+ if err := p.db.GetByID(ctx, repliedStatus.AccountID, repliedAccount); err != nil {
if err == db.ErrNoEntries {
return fmt.Errorf("status with id %s not replyable because account id %s is not known", form.InReplyToID, repliedStatus.AccountID)
}
return fmt.Errorf("status with id %s not replyable: %s", form.InReplyToID, err)
}
// check if a block exists
- if blocked, err := p.db.IsBlocked(thisAccountID, repliedAccount.ID, true); err != nil {
+ if blocked, err := p.db.IsBlocked(ctx, thisAccountID, repliedAccount.ID, true); err != nil {
if err != db.ErrNoEntries {
return fmt.Errorf("status with id %s not replyable: %s", form.InReplyToID, err)
}
@@ -132,7 +151,7 @@ func (p *processor) ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, th
return nil
}
-func (p *processor) ProcessMediaIDs(form *apimodel.AdvancedStatusCreateForm, thisAccountID string, status *gtsmodel.Status) error {
+func (p *processor) ProcessMediaIDs(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, thisAccountID string, status *gtsmodel.Status) error {
if form.MediaIDs == nil {
return nil
}
@@ -142,7 +161,7 @@ func (p *processor) ProcessMediaIDs(form *apimodel.AdvancedStatusCreateForm, thi
for _, mediaID := range form.MediaIDs {
// check these attachments exist
a := &gtsmodel.MediaAttachment{}
- if err := p.db.GetByID(mediaID, a); err != nil {
+ if err := p.db.GetByID(ctx, mediaID, a); err != nil {
return fmt.Errorf("invalid media type or media not found for media id %s", mediaID)
}
// check they belong to the requesting account id
@@ -161,7 +180,7 @@ func (p *processor) ProcessMediaIDs(form *apimodel.AdvancedStatusCreateForm, thi
return nil
}
-func (p *processor) ProcessLanguage(form *apimodel.AdvancedStatusCreateForm, accountDefaultLanguage string, status *gtsmodel.Status) error {
+func (p *processor) ProcessLanguage(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountDefaultLanguage string, status *gtsmodel.Status) error {
if form.Language != "" {
status.Language = form.Language
} else {
@@ -173,9 +192,9 @@ func (p *processor) ProcessLanguage(form *apimodel.AdvancedStatusCreateForm, acc
return nil
}
-func (p *processor) ProcessMentions(form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error {
+func (p *processor) ProcessMentions(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error {
menchies := []string{}
- gtsMenchies, err := p.db.MentionStringsToMentions(util.DeriveMentionsFromStatus(form.Status), accountID, status.ID)
+ gtsMenchies, err := p.db.MentionStringsToMentions(ctx, util.DeriveMentionsFromStatus(form.Status), accountID, status.ID)
if err != nil {
return fmt.Errorf("error generating mentions from status: %s", err)
}
@@ -186,7 +205,7 @@ func (p *processor) ProcessMentions(form *apimodel.AdvancedStatusCreateForm, acc
}
menchie.ID = menchieID
- if err := p.db.Put(menchie); err != nil {
+ if err := p.db.Put(ctx, menchie); err != nil {
return fmt.Errorf("error putting mentions in db: %s", err)
}
menchies = append(menchies, menchie.ID)
@@ -198,14 +217,14 @@ func (p *processor) ProcessMentions(form *apimodel.AdvancedStatusCreateForm, acc
return nil
}
-func (p *processor) ProcessTags(form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error {
+func (p *processor) ProcessTags(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error {
tags := []string{}
- gtsTags, err := p.db.TagStringsToTags(util.DeriveHashtagsFromStatus(form.Status), accountID, status.ID)
+ gtsTags, err := p.db.TagStringsToTags(ctx, util.DeriveHashtagsFromStatus(form.Status), accountID, status.ID)
if err != nil {
return fmt.Errorf("error generating hashtags from status: %s", err)
}
for _, tag := range gtsTags {
- if err := p.db.Upsert(tag, "name"); err != nil {
+ if err := p.db.Put(ctx, tag); err != nil && err != db.ErrAlreadyExists {
return fmt.Errorf("error putting tags in db: %s", err)
}
tags = append(tags, tag.ID)
@@ -217,9 +236,9 @@ func (p *processor) ProcessTags(form *apimodel.AdvancedStatusCreateForm, account
return nil
}
-func (p *processor) ProcessEmojis(form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error {
+func (p *processor) ProcessEmojis(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error {
emojis := []string{}
- gtsEmojis, err := p.db.EmojiStringsToEmojis(util.DeriveEmojisFromStatus(form.Status), accountID, status.ID)
+ gtsEmojis, err := p.db.EmojiStringsToEmojis(ctx, util.DeriveEmojisFromStatus(form.Status), accountID, status.ID)
if err != nil {
return fmt.Errorf("error generating emojis from status: %s", err)
}
@@ -233,7 +252,7 @@ func (p *processor) ProcessEmojis(form *apimodel.AdvancedStatusCreateForm, accou
return nil
}
-func (p *processor) ProcessContent(form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error {
+func (p *processor) ProcessContent(ctx context.Context, form *apimodel.AdvancedStatusCreateForm, accountID string, status *gtsmodel.Status) error {
// if there's nothing in the status at all we can just return early
if form.Status == "" {
status.Content = ""
@@ -252,9 +271,9 @@ func (p *processor) ProcessContent(form *apimodel.AdvancedStatusCreateForm, acco
var formatted string
switch form.Format {
case apimodel.StatusFormatPlain:
- formatted = p.formatter.FromPlain(content, status.Mentions, status.Tags)
+ formatted = p.formatter.FromPlain(ctx, content, status.Mentions, status.Tags)
case apimodel.StatusFormatMarkdown:
- formatted = p.formatter.FromMarkdown(content, status.Mentions, status.Tags)
+ formatted = p.formatter.FromMarkdown(ctx, content, status.Mentions, status.Tags)
default:
return fmt.Errorf("format %s not recognised as a valid status format", form.Format)
}
diff --git a/internal/processing/status/util_test.go b/internal/processing/status/util_test.go
index 9c282eb52..1ec2076b1 100644
--- a/internal/processing/status/util_test.go
+++ b/internal/processing/status/util_test.go
@@ -1,6 +1,25 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package status_test
import (
+ "context"
"fmt"
"testing"
@@ -88,7 +107,7 @@ func (suite *UtilTestSuite) TestProcessMentions1() {
ID: "01FCTDD78JJMX3K9KPXQ7ZQ8BJ",
}
- err := suite.status.ProcessMentions(form, creatingAccount.ID, status)
+ err := suite.status.ProcessMentions(context.Background(), form, creatingAccount.ID, status)
assert.NoError(suite.T(), err)
assert.Len(suite.T(), status.Mentions, 1)
@@ -138,11 +157,11 @@ func (suite *UtilTestSuite) TestProcessContentFull1() {
ID: "01FCTDD78JJMX3K9KPXQ7ZQ8BJ",
}
- err := suite.status.ProcessMentions(form, creatingAccount.ID, status)
+ err := suite.status.ProcessMentions(context.Background(), form, creatingAccount.ID, status)
assert.NoError(suite.T(), err)
assert.Empty(suite.T(), status.Content) // shouldn't be set yet
- err = suite.status.ProcessTags(form, creatingAccount.ID, status)
+ err = suite.status.ProcessTags(context.Background(), form, creatingAccount.ID, status)
assert.NoError(suite.T(), err)
assert.Empty(suite.T(), status.Content) // shouldn't be set yet
@@ -150,7 +169,7 @@ func (suite *UtilTestSuite) TestProcessContentFull1() {
ACTUAL TEST
*/
- err = suite.status.ProcessContent(form, creatingAccount.ID, status)
+ err = suite.status.ProcessContent(context.Background(), form, creatingAccount.ID, status)
assert.NoError(suite.T(), err)
assert.Equal(suite.T(), statusText1ExpectedFull, status.Content)
}
@@ -187,7 +206,7 @@ func (suite *UtilTestSuite) TestProcessContentPartial1() {
ID: "01FCTDD78JJMX3K9KPXQ7ZQ8BJ",
}
- err := suite.status.ProcessMentions(form, creatingAccount.ID, status)
+ err := suite.status.ProcessMentions(context.Background(), form, creatingAccount.ID, status)
assert.NoError(suite.T(), err)
assert.Empty(suite.T(), status.Content) // shouldn't be set yet
@@ -195,7 +214,7 @@ func (suite *UtilTestSuite) TestProcessContentPartial1() {
ACTUAL TEST
*/
- err = suite.status.ProcessContent(form, creatingAccount.ID, status)
+ err = suite.status.ProcessContent(context.Background(), form, creatingAccount.ID, status)
assert.NoError(suite.T(), err)
assert.Equal(suite.T(), statusText1ExpectedPartial, status.Content)
}
@@ -229,7 +248,7 @@ func (suite *UtilTestSuite) TestProcessMentions2() {
ID: "01FCTDD78JJMX3K9KPXQ7ZQ8BJ",
}
- err := suite.status.ProcessMentions(form, creatingAccount.ID, status)
+ err := suite.status.ProcessMentions(context.Background(), form, creatingAccount.ID, status)
assert.NoError(suite.T(), err)
assert.Len(suite.T(), status.Mentions, 1)
@@ -279,11 +298,11 @@ func (suite *UtilTestSuite) TestProcessContentFull2() {
ID: "01FCTDD78JJMX3K9KPXQ7ZQ8BJ",
}
- err := suite.status.ProcessMentions(form, creatingAccount.ID, status)
+ err := suite.status.ProcessMentions(context.Background(), form, creatingAccount.ID, status)
assert.NoError(suite.T(), err)
assert.Empty(suite.T(), status.Content) // shouldn't be set yet
- err = suite.status.ProcessTags(form, creatingAccount.ID, status)
+ err = suite.status.ProcessTags(context.Background(), form, creatingAccount.ID, status)
assert.NoError(suite.T(), err)
assert.Empty(suite.T(), status.Content) // shouldn't be set yet
@@ -291,7 +310,7 @@ func (suite *UtilTestSuite) TestProcessContentFull2() {
ACTUAL TEST
*/
- err = suite.status.ProcessContent(form, creatingAccount.ID, status)
+ err = suite.status.ProcessContent(context.Background(), form, creatingAccount.ID, status)
assert.NoError(suite.T(), err)
assert.Equal(suite.T(), status2TextExpectedFull, status.Content)
@@ -329,7 +348,7 @@ func (suite *UtilTestSuite) TestProcessContentPartial2() {
ID: "01FCTDD78JJMX3K9KPXQ7ZQ8BJ",
}
- err := suite.status.ProcessMentions(form, creatingAccount.ID, status)
+ err := suite.status.ProcessMentions(context.Background(), form, creatingAccount.ID, status)
assert.NoError(suite.T(), err)
assert.Empty(suite.T(), status.Content) // shouldn't be set yet
@@ -337,7 +356,7 @@ func (suite *UtilTestSuite) TestProcessContentPartial2() {
ACTUAL TEST
*/
- err = suite.status.ProcessContent(form, creatingAccount.ID, status)
+ err = suite.status.ProcessContent(context.Background(), form, creatingAccount.ID, status)
assert.NoError(suite.T(), err)
fmt.Println(status.Content)
diff --git a/internal/processing/streaming.go b/internal/processing/streaming.go
index 457db0576..e1c134d00 100644
--- a/internal/processing/streaming.go
+++ b/internal/processing/streaming.go
@@ -19,14 +19,16 @@
package processing
import (
+ "context"
+
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) {
- return p.streamingProcessor.AuthorizeStreamingRequest(accessToken)
+func (p *processor) AuthorizeStreamingRequest(ctx context.Context, accessToken string) (*gtsmodel.Account, error) {
+ return p.streamingProcessor.AuthorizeStreamingRequest(ctx, accessToken)
}
-func (p *processor) OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) {
- return p.streamingProcessor.OpenStreamForAccount(account, streamType)
+func (p *processor) OpenStreamForAccount(ctx context.Context, account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) {
+ return p.streamingProcessor.OpenStreamForAccount(ctx, account, streamType)
}
diff --git a/internal/processing/streaming/authorize.go b/internal/processing/streaming/authorize.go
index 8bbf1856d..f938a0c0c 100644
--- a/internal/processing/streaming/authorize.go
+++ b/internal/processing/streaming/authorize.go
@@ -7,7 +7,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (p *processor) AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) {
+func (p *processor) AuthorizeStreamingRequest(ctx context.Context, accessToken string) (*gtsmodel.Account, error) {
ti, err := p.oauthServer.LoadAccessToken(context.Background(), accessToken)
if err != nil {
return nil, fmt.Errorf("AuthorizeStreamingRequest: error loading access token: %s", err)
@@ -20,12 +20,12 @@ func (p *processor) AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Acc
// fetch user's and account for this user id
user := &gtsmodel.User{}
- if err := p.db.GetByID(uid, user); err != nil || user == nil {
+ if err := p.db.GetByID(ctx, uid, user); err != nil || user == nil {
return nil, fmt.Errorf("AuthorizeStreamingRequest: no user found for validated uid %s", uid)
}
- acct := &gtsmodel.Account{}
- if err := p.db.GetByID(user.AccountID, acct); err != nil || acct == nil {
+ acct, err := p.db.GetAccountByID(ctx, user.AccountID)
+ if err != nil || acct == nil {
return nil, fmt.Errorf("AuthorizeStreamingRequest: no account retrieved for user with id %s", uid)
}
diff --git a/internal/processing/streaming/openstream.go b/internal/processing/streaming/openstream.go
index 68446bac6..dfad5398e 100644
--- a/internal/processing/streaming/openstream.go
+++ b/internal/processing/streaming/openstream.go
@@ -1,6 +1,7 @@
package streaming
import (
+ "context"
"errors"
"fmt"
@@ -10,7 +11,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/id"
)
-func (p *processor) OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) {
+func (p *processor) OpenStreamForAccount(ctx context.Context, account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) {
l := p.log.WithFields(logrus.Fields{
"func": "OpenStreamForAccount",
"account": account.ID,
diff --git a/internal/processing/streaming/streaming.go b/internal/processing/streaming/streaming.go
index de75b8f27..f349a655a 100644
--- a/internal/processing/streaming/streaming.go
+++ b/internal/processing/streaming/streaming.go
@@ -1,6 +1,7 @@
package streaming
import (
+ "context"
"sync"
"github.com/sirupsen/logrus"
@@ -17,9 +18,9 @@ import (
// Processor wraps a bunch of functions for processing streaming.
type Processor interface {
// AuthorizeStreamingRequest returns an oauth2 token info in response to an access token query from the streaming API
- AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error)
+ AuthorizeStreamingRequest(ctx context.Context, accessToken string) (*gtsmodel.Account, error)
// OpenStreamForAccount returns a new Stream for the given account, which will contain a channel for passing messages back to the caller.
- OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode)
+ OpenStreamForAccount(ctx context.Context, account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode)
// StreamStatusToAccount streams the given status to any open, appropriate streams belonging to the given account.
StreamStatusToAccount(s *apimodel.Status, account *gtsmodel.Account) error
// StreamNotificationToAccount streams the given notification to any open, appropriate streams belonging to the given account.
diff --git a/internal/processing/timeline.go b/internal/processing/timeline.go
index afddd3e6c..6a409a6cc 100644
--- a/internal/processing/timeline.go
+++ b/internal/processing/timeline.go
@@ -19,6 +19,7 @@
package processing
import (
+ "context"
"fmt"
"net/url"
@@ -58,8 +59,8 @@ func (p *processor) packageStatusResponse(statuses []*apimodel.Status, path stri
return resp, nil
}
-func (p *processor) HomeTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) {
- statuses, err := p.timelineManager.HomeTimeline(authed.Account.ID, maxID, sinceID, minID, limit, local)
+func (p *processor) HomeTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) {
+ statuses, err := p.timelineManager.HomeTimeline(ctx, authed.Account.ID, maxID, sinceID, minID, limit, local)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -73,8 +74,8 @@ func (p *processor) HomeTimelineGet(authed *oauth.Auth, maxID string, sinceID st
return p.packageStatusResponse(statuses, "api/v1/timelines/home", statuses[len(statuses)-1].ID, statuses[0].ID, limit)
}
-func (p *processor) PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) {
- statuses, err := p.db.GetPublicTimeline(authed.Account.ID, maxID, sinceID, minID, limit, local)
+func (p *processor) PublicTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) {
+ statuses, err := p.db.GetPublicTimeline(ctx, authed.Account.ID, maxID, sinceID, minID, limit, local)
if err != nil {
if err == db.ErrNoEntries {
// there are just no entries left
@@ -86,16 +87,22 @@ func (p *processor) PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID
return nil, gtserror.NewErrorInternalError(err)
}
- s, err := p.filterPublicStatuses(authed, statuses)
+ s, err := p.filterPublicStatuses(ctx, authed, statuses)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
+ if len(s) == 0 {
+ return &apimodel.StatusTimelineResponse{
+ Statuses: []*apimodel.Status{},
+ }, nil
+ }
+
return p.packageStatusResponse(s, "api/v1/timelines/public", s[len(s)-1].ID, s[0].ID, limit)
}
-func (p *processor) FavedTimelineGet(authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.StatusTimelineResponse, gtserror.WithCode) {
- statuses, nextMaxID, prevMinID, err := p.db.GetFavedTimeline(authed.Account.ID, maxID, minID, limit)
+func (p *processor) FavedTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.StatusTimelineResponse, gtserror.WithCode) {
+ statuses, nextMaxID, prevMinID, err := p.db.GetFavedTimeline(ctx, authed.Account.ID, maxID, minID, limit)
if err != nil {
if err == db.ErrNoEntries {
// there are just no entries left
@@ -107,21 +114,27 @@ func (p *processor) FavedTimelineGet(authed *oauth.Auth, maxID string, minID str
return nil, gtserror.NewErrorInternalError(err)
}
- s, err := p.filterFavedStatuses(authed, statuses)
+ s, err := p.filterFavedStatuses(ctx, authed, statuses)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
+ if len(s) == 0 {
+ return &apimodel.StatusTimelineResponse{
+ Statuses: []*apimodel.Status{},
+ }, nil
+ }
+
return p.packageStatusResponse(s, "api/v1/favourites", nextMaxID, prevMinID, limit)
}
-func (p *processor) filterPublicStatuses(authed *oauth.Auth, statuses []*gtsmodel.Status) ([]*apimodel.Status, error) {
+func (p *processor) filterPublicStatuses(ctx context.Context, authed *oauth.Auth, statuses []*gtsmodel.Status) ([]*apimodel.Status, error) {
l := p.log.WithField("func", "filterPublicStatuses")
apiStatuses := []*apimodel.Status{}
for _, s := range statuses {
targetAccount := &gtsmodel.Account{}
- if err := p.db.GetByID(s.AccountID, targetAccount); err != nil {
+ if err := p.db.GetByID(ctx, s.AccountID, targetAccount); err != nil {
if err == db.ErrNoEntries {
l.Debugf("filterPublicStatuses: skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
continue
@@ -129,7 +142,7 @@ func (p *processor) filterPublicStatuses(authed *oauth.Auth, statuses []*gtsmode
return nil, gtserror.NewErrorInternalError(fmt.Errorf("filterPublicStatuses: error getting status author: %s", err))
}
- timelineable, err := p.filter.StatusPublictimelineable(s, authed.Account)
+ timelineable, err := p.filter.StatusPublictimelineable(ctx, s, authed.Account)
if err != nil {
l.Debugf("filterPublicStatuses: skipping status %s because of an error checking status visibility: %s", s.ID, err)
continue
@@ -138,7 +151,7 @@ func (p *processor) filterPublicStatuses(authed *oauth.Auth, statuses []*gtsmode
continue
}
- apiStatus, err := p.tc.StatusToMasto(s, authed.Account)
+ apiStatus, err := p.tc.StatusToMasto(ctx, s, authed.Account)
if err != nil {
l.Debugf("filterPublicStatuses: skipping status %s because it couldn't be converted to its mastodon representation: %s", s.ID, err)
continue
@@ -150,13 +163,13 @@ func (p *processor) filterPublicStatuses(authed *oauth.Auth, statuses []*gtsmode
return apiStatuses, nil
}
-func (p *processor) filterFavedStatuses(authed *oauth.Auth, statuses []*gtsmodel.Status) ([]*apimodel.Status, error) {
+func (p *processor) filterFavedStatuses(ctx context.Context, authed *oauth.Auth, statuses []*gtsmodel.Status) ([]*apimodel.Status, error) {
l := p.log.WithField("func", "filterFavedStatuses")
apiStatuses := []*apimodel.Status{}
for _, s := range statuses {
targetAccount := &gtsmodel.Account{}
- if err := p.db.GetByID(s.AccountID, targetAccount); err != nil {
+ if err := p.db.GetByID(ctx, s.AccountID, targetAccount); err != nil {
if err == db.ErrNoEntries {
l.Debugf("filterFavedStatuses: skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
continue
@@ -164,7 +177,7 @@ func (p *processor) filterFavedStatuses(authed *oauth.Auth, statuses []*gtsmodel
return nil, gtserror.NewErrorInternalError(fmt.Errorf("filterPublicStatuses: error getting status author: %s", err))
}
- timelineable, err := p.filter.StatusVisible(s, authed.Account)
+ timelineable, err := p.filter.StatusVisible(ctx, s, authed.Account)
if err != nil {
l.Debugf("filterFavedStatuses: skipping status %s because of an error checking status visibility: %s", s.ID, err)
continue
@@ -173,7 +186,7 @@ func (p *processor) filterFavedStatuses(authed *oauth.Auth, statuses []*gtsmodel
continue
}
- apiStatus, err := p.tc.StatusToMasto(s, authed.Account)
+ apiStatus, err := p.tc.StatusToMasto(ctx, s, authed.Account)
if err != nil {
l.Debugf("filterFavedStatuses: skipping status %s because it couldn't be converted to its mastodon representation: %s", s.ID, err)
continue
diff --git a/internal/router/router.go b/internal/router/router.go
index c5f105448..621d93ff5 100644
--- a/internal/router/router.go
+++ b/internal/router/router.go
@@ -103,7 +103,7 @@ func (r *router) Stop(ctx context.Context) error {
//
// The given DB is only used in the New function for parsing config values, and is not otherwise
// pinned to the router.
-func New(cfg *config.Config, db db.DB, logger *logrus.Logger) (Router, error) {
+func New(ctx context.Context, cfg *config.Config, db db.DB, logger *logrus.Logger) (Router, error) {
// gin has different log modes; for convenience, we match the gin log mode to
// whatever log mode has been set for logrus
@@ -141,7 +141,7 @@ func New(cfg *config.Config, db db.DB, logger *logrus.Logger) (Router, error) {
}
// enable session store middleware on the engine
- if err := useSession(cfg, db, engine); err != nil {
+ if err := useSession(ctx, cfg, db, engine); err != nil {
return nil, err
}
diff --git a/internal/router/session.go b/internal/router/session.go
index 38810572f..4359a8a60 100644
--- a/internal/router/session.go
+++ b/internal/router/session.go
@@ -19,7 +19,7 @@
package router
import (
- "crypto/rand"
+ "context"
"errors"
"fmt"
"net/http"
@@ -29,8 +29,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/id"
)
// SessionOptions returns the standard set of options to use for each session.
@@ -45,34 +43,23 @@ func SessionOptions(cfg *config.Config) sessions.Options {
}
}
-func useSession(cfg *config.Config, dbService db.DB, engine *gin.Engine) error {
+func useSession(ctx context.Context, cfg *config.Config, sessionDB db.Session, engine *gin.Engine) error {
// check if we have a saved router session already
- routerSessions := []*gtsmodel.RouterSession{}
- if err := dbService.GetAll(&routerSessions); err != nil {
+ rs, err := sessionDB.GetSession(ctx)
+ if err != nil {
if err != db.ErrNoEntries {
// proper error occurred
return err
}
- }
-
- var rs *gtsmodel.RouterSession
- if len(routerSessions) == 1 {
- // we have a router session stored
- rs = routerSessions[0]
- } else if len(routerSessions) == 0 {
- // we have no router sessions so we need to create a new one
- var err error
- rs, err = routerSession(dbService)
+ // no session saved so create a new one
+ rs, err = sessionDB.CreateSession(ctx)
if err != nil {
- return fmt.Errorf("error creating new router session: %s", err)
+ return err
}
- } else {
- // we should only have one router session stored ever
- return errors.New("we had more than one router session in the db")
}
if rs == nil {
- return errors.New("error getting or creating router session: router session was nil")
+ return errors.New("router session was nil")
}
store := memstore.NewStore(rs.Auth, rs.Crypt)
@@ -81,34 +68,3 @@ func useSession(cfg *config.Config, dbService db.DB, engine *gin.Engine) error {
engine.Use(sessions.Sessions(sessionName, store))
return nil
}
-
-// routerSession generates a new router session with random auth and crypt bytes,
-// puts it in the database for persistence, and returns it for use.
-func routerSession(dbService db.DB) (*gtsmodel.RouterSession, error) {
- auth := make([]byte, 32)
- crypt := make([]byte, 32)
-
- if _, err := rand.Read(auth); err != nil {
- return nil, err
- }
- if _, err := rand.Read(crypt); err != nil {
- return nil, err
- }
-
- rid, err := id.NewULID()
- if err != nil {
- return nil, err
- }
-
- rs := &gtsmodel.RouterSession{
- ID: rid,
- Auth: auth,
- Crypt: crypt,
- }
-
- if err := dbService.Put(rs); err != nil {
- return nil, err
- }
-
- return rs, nil
-}
diff --git a/internal/text/common.go b/internal/text/common.go
index af77521dd..a8d585a09 100644
--- a/internal/text/common.go
+++ b/internal/text/common.go
@@ -19,6 +19,7 @@
package text
import (
+ "context"
"fmt"
"html"
"strings"
@@ -59,7 +60,7 @@ func postformat(in string) string {
return mini
}
-func (f *formatter) ReplaceTags(in string, tags []*gtsmodel.Tag) string {
+func (f *formatter) ReplaceTags(ctx context.Context, in string, tags []*gtsmodel.Tag) string {
return util.HashtagFinderRegex.ReplaceAllStringFunc(in, func(match string) string {
// we have a match
matchTrimmed := strings.TrimSpace(match)
@@ -88,7 +89,7 @@ func (f *formatter) ReplaceTags(in string, tags []*gtsmodel.Tag) string {
})
}
-func (f *formatter) ReplaceMentions(in string, mentions []*gtsmodel.Mention) string {
+func (f *formatter) ReplaceMentions(ctx context.Context, in string, mentions []*gtsmodel.Mention) string {
for _, menchie := range mentions {
// make sure we have a target account, either by getting one pinned on the mention,
// or by pulling it from the database
@@ -97,8 +98,8 @@ func (f *formatter) ReplaceMentions(in string, mentions []*gtsmodel.Mention) str
// got it from the mention
targetAccount = menchie.OriginAccount
} else {
- a := &gtsmodel.Account{}
- if err := f.db.GetByID(menchie.TargetAccountID, a); err == nil {
+ a, err := f.db.GetAccountByID(ctx, menchie.TargetAccountID)
+ if err == nil {
// got it from the db
targetAccount = a
} else {
diff --git a/internal/text/common_test.go b/internal/text/common_test.go
index 69fe7d446..174b79177 100644
--- a/internal/text/common_test.go
+++ b/internal/text/common_test.go
@@ -19,6 +19,7 @@
package text_test
import (
+ "context"
"testing"
"github.com/stretchr/testify/assert"
@@ -87,7 +88,7 @@ func (suite *CommonTestSuite) TestReplaceMentions() {
suite.testMentions["zork_mention_foss_satan"],
}
- f := suite.formatter.ReplaceMentions(replaceMentionsString, foundMentions)
+ f := suite.formatter.ReplaceMentions(context.Background(), replaceMentionsString, foundMentions)
assert.Equal(suite.T(), replaceMentionsExpected, f)
}
@@ -96,7 +97,7 @@ func (suite *CommonTestSuite) TestReplaceHashtags() {
suite.testTags["Hashtag"],
}
- f := suite.formatter.ReplaceTags(replaceMentionsString, foundTags)
+ f := suite.formatter.ReplaceTags(context.Background(), replaceMentionsString, foundTags)
assert.Equal(suite.T(), replaceHashtagsExpected, f)
}
@@ -106,7 +107,7 @@ func (suite *CommonTestSuite) TestReplaceHashtagsAfterReplaceMentions() {
suite.testTags["Hashtag"],
}
- f := suite.formatter.ReplaceTags(replaceMentionsExpected, foundTags)
+ f := suite.formatter.ReplaceTags(context.Background(), replaceMentionsExpected, foundTags)
assert.Equal(suite.T(), replaceHashtagsAfterMentionsExpected, f)
}
diff --git a/internal/text/formatter.go b/internal/text/formatter.go
index 39aaae559..769ecafbb 100644
--- a/internal/text/formatter.go
+++ b/internal/text/formatter.go
@@ -19,6 +19,8 @@
package text
import (
+ "context"
+
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
@@ -28,16 +30,16 @@ import (
// Formatter wraps some logic and functions for parsing statuses and other text input into nice html.
type Formatter interface {
// FromMarkdown parses an HTML text from a markdown-formatted text.
- FromMarkdown(md string, mentions []*gtsmodel.Mention, tags []*gtsmodel.Tag) string
+ FromMarkdown(ctx context.Context, md string, mentions []*gtsmodel.Mention, tags []*gtsmodel.Tag) string
// FromPlain parses an HTML text from a plaintext.
- FromPlain(plain string, mentions []*gtsmodel.Mention, tags []*gtsmodel.Tag) string
+ FromPlain(ctx context.Context, plain string, mentions []*gtsmodel.Mention, tags []*gtsmodel.Tag) string
// ReplaceTags takes a piece of text and a slice of tags, and returns the same text with the tags nicely formatted as hrefs.
- ReplaceTags(in string, tags []*gtsmodel.Tag) string
+ ReplaceTags(ctx context.Context, in string, tags []*gtsmodel.Tag) string
// ReplaceMentions takes a piece of text and a slice of mentions, and returns the same text with the mentions nicely formatted as hrefs.
- ReplaceMentions(in string, mentions []*gtsmodel.Mention) string
+ ReplaceMentions(ctx context.Context, in string, mentions []*gtsmodel.Mention) string
// ReplaceLinks takes a piece of text, finds all recognizable links in that text, and replaces them with hrefs.
- ReplaceLinks(in string) string
+ ReplaceLinks(ctx context.Context, in string) string
}
type formatter struct {
diff --git a/internal/text/link.go b/internal/text/link.go
index d42cc3b68..0a0f0c60d 100644
--- a/internal/text/link.go
+++ b/internal/text/link.go
@@ -19,6 +19,7 @@
package text
import (
+ "context"
"fmt"
"net/url"
@@ -82,7 +83,7 @@ func contains(urls []*url.URL, url *url.URL) bool {
// Note: because Go doesn't allow negative lookbehinds in regex, it's possible that an already-formatted
// href will end up double-formatted, if the text you pass here contains one or more hrefs already.
// To avoid this, you should sanitize any HTML out of text before you pass it into this function.
-func (f *formatter) ReplaceLinks(in string) string {
+func (f *formatter) ReplaceLinks(ctx context.Context, in string) string {
rxStrict, err := xurls.StrictMatchingScheme(schemes)
if err != nil {
panic(err)
diff --git a/internal/text/link_test.go b/internal/text/link_test.go
index 83c42f045..f8d6a1adc 100644
--- a/internal/text/link_test.go
+++ b/internal/text/link_test.go
@@ -19,6 +19,7 @@
package text_test
import (
+ "context"
"testing"
"github.com/stretchr/testify/assert"
@@ -94,7 +95,7 @@ func (suite *LinkTestSuite) TearDownTest() {
}
func (suite *LinkTestSuite) TestParseSimple() {
- f := suite.formatter.FromPlain(simple, nil, nil)
+ f := suite.formatter.FromPlain(context.Background(), simple, nil, nil)
assert.Equal(suite.T(), simpleExpected, f)
}
@@ -126,7 +127,7 @@ func (suite *LinkTestSuite) TestParseURLsFromText3() {
}
func (suite *LinkTestSuite) TestReplaceLinksFromText1() {
- replaced := suite.formatter.ReplaceLinks(text1)
+ replaced := suite.formatter.ReplaceLinks(context.Background(), text1)
assert.Equal(suite.T(), `
This is a text with some links in it. Here's link number one: <a href="https://example.org/link/to/something#fragment" rel="noopener">example.org/link/to/something#fragment</a>
@@ -141,7 +142,7 @@ really.cool.website <-- this one shouldn't be parsed as a link because it doesn'
}
func (suite *LinkTestSuite) TestReplaceLinksFromText2() {
- replaced := suite.formatter.ReplaceLinks(text2)
+ replaced := suite.formatter.ReplaceLinks(context.Background(), text2)
assert.Equal(suite.T(), `
this is one link: <a href="https://example.org" rel="noopener">example.org</a>
@@ -153,14 +154,14 @@ these should be deduplicated
func (suite *LinkTestSuite) TestReplaceLinksFromText3() {
// we know mailto links won't be replaced with hrefs -- we only accept https and http
- replaced := suite.formatter.ReplaceLinks(text3)
+ replaced := suite.formatter.ReplaceLinks(context.Background(), text3)
assert.Equal(suite.T(), `
here's a mailto link: mailto:whatever@test.org
`, replaced)
}
func (suite *LinkTestSuite) TestReplaceLinksFromText4() {
- replaced := suite.formatter.ReplaceLinks(text4)
+ replaced := suite.formatter.ReplaceLinks(context.Background(), text4)
assert.Equal(suite.T(), `
two similar links:
@@ -172,7 +173,7 @@ two similar links:
func (suite *LinkTestSuite) TestReplaceLinksFromText5() {
// we know this one doesn't work properly, which is why html should always be sanitized before being passed into the ReplaceLinks function
- replaced := suite.formatter.ReplaceLinks(text5)
+ replaced := suite.formatter.ReplaceLinks(context.Background(), text5)
assert.Equal(suite.T(), `
what happens when we already have a link within an href?
diff --git a/internal/text/markdown.go b/internal/text/markdown.go
index 5a7603615..eeeae0edf 100644
--- a/internal/text/markdown.go
+++ b/internal/text/markdown.go
@@ -19,21 +19,23 @@
package text
import (
+ "context"
+
"github.com/russross/blackfriday/v2"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (f *formatter) FromMarkdown(md string, mentions []*gtsmodel.Mention, tags []*gtsmodel.Tag) string {
+func (f *formatter) FromMarkdown(ctx context.Context, md string, mentions []*gtsmodel.Mention, tags []*gtsmodel.Tag) string {
content := preformat(md)
// do the markdown parsing *first*
contentBytes := blackfriday.Run([]byte(content))
// format tags nicely
- content = f.ReplaceTags(string(contentBytes), tags)
+ content = f.ReplaceTags(ctx, string(contentBytes), tags)
// format mentions nicely
- content = f.ReplaceMentions(content, mentions)
+ content = f.ReplaceMentions(ctx, content, mentions)
return postformat(content)
}
diff --git a/internal/text/markdown_test.go b/internal/text/markdown_test.go
index 432e9a4ec..085f211d2 100644
--- a/internal/text/markdown_test.go
+++ b/internal/text/markdown_test.go
@@ -19,6 +19,7 @@
package text_test
import (
+ "context"
"fmt"
"testing"
@@ -92,13 +93,13 @@ func (suite *MarkdownTestSuite) TearDownTest() {
}
func (suite *MarkdownTestSuite) TestParseSimple() {
- s := suite.formatter.FromMarkdown(simpleMarkdown, nil, nil)
+ s := suite.formatter.FromMarkdown(context.Background(), simpleMarkdown, nil, nil)
suite.Equal(simpleMarkdownExpected, s)
}
func (suite *MarkdownTestSuite) TestParseWithCodeBlock() {
fmt.Println(withCodeBlock)
- s := suite.formatter.FromMarkdown(withCodeBlock, nil, nil)
+ s := suite.formatter.FromMarkdown(context.Background(), withCodeBlock, nil, nil)
suite.Equal(withCodeBlockExpected, s)
}
@@ -107,7 +108,7 @@ func (suite *MarkdownTestSuite) TestParseWithHashtag() {
suite.testTags["Hashtag"],
}
- s := suite.formatter.FromMarkdown(withHashtag, nil, foundTags)
+ s := suite.formatter.FromMarkdown(context.Background(), withHashtag, nil, foundTags)
suite.Equal(withHashtagExpected, s)
}
diff --git a/internal/text/plain.go b/internal/text/plain.go
index a44e02c80..34cc3fa06 100644
--- a/internal/text/plain.go
+++ b/internal/text/plain.go
@@ -19,26 +19,27 @@
package text
import (
+ "context"
"fmt"
"strings"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (f *formatter) FromPlain(plain string, mentions []*gtsmodel.Mention, tags []*gtsmodel.Tag) string {
+func (f *formatter) FromPlain(ctx context.Context, plain string, mentions []*gtsmodel.Mention, tags []*gtsmodel.Tag) string {
content := preformat(plain)
// sanitize any html elements
content = RemoveHTML(content)
// format links nicely
- content = f.ReplaceLinks(content)
+ content = f.ReplaceLinks(ctx, content)
// format tags nicely
- content = f.ReplaceTags(content, tags)
+ content = f.ReplaceTags(ctx, content, tags)
// format mentions nicely
- content = f.ReplaceMentions(content, mentions)
+ content = f.ReplaceMentions(ctx, content, mentions)
// replace newlines with breaks
content = strings.ReplaceAll(content, "\n", "<br />")
diff --git a/internal/text/plain_test.go b/internal/text/plain_test.go
index 33c95234c..62c43406d 100644
--- a/internal/text/plain_test.go
+++ b/internal/text/plain_test.go
@@ -19,6 +19,7 @@
package text_test
import (
+ "context"
"fmt"
"testing"
@@ -74,7 +75,7 @@ func (suite *PlainTestSuite) TearDownTest() {
}
func (suite *PlainTestSuite) TestParseSimple() {
- f := suite.formatter.FromPlain(simple, nil, nil)
+ f := suite.formatter.FromPlain(context.Background(), simple, nil, nil)
assert.Equal(suite.T(), simpleExpected, f)
}
@@ -84,7 +85,7 @@ func (suite *PlainTestSuite) TestParseWithTag() {
suite.testTags["welcome"],
}
- f := suite.formatter.FromPlain(withTag, nil, foundTags)
+ f := suite.formatter.FromPlain(context.Background(), withTag, nil, foundTags)
assert.Equal(suite.T(), withTagExpected, f)
}
@@ -98,7 +99,7 @@ func (suite *PlainTestSuite) TestParseMoreComplex() {
suite.testMentions["zork_mention_foss_satan"],
}
- f := suite.formatter.FromPlain(moreComplex, foundMentions, foundTags)
+ f := suite.formatter.FromPlain(context.Background(), moreComplex, foundMentions, foundTags)
fmt.Println(f)
diff --git a/internal/timeline/get.go b/internal/timeline/get.go
index d800da4e3..a00613dc0 100644
--- a/internal/timeline/get.go
+++ b/internal/timeline/get.go
@@ -20,6 +20,7 @@ package timeline
import (
"container/list"
+ "context"
"errors"
"fmt"
@@ -29,7 +30,7 @@ import (
const retries = 5
-func (t *timeline) Get(amount int, maxID string, sinceID string, minID string, prepareNext bool) ([]*apimodel.Status, error) {
+func (t *timeline) Get(ctx context.Context, amount int, maxID string, sinceID string, minID string, prepareNext bool) ([]*apimodel.Status, error) {
l := t.log.WithFields(logrus.Fields{
"func": "Get",
"accountID": t.accountID,
@@ -46,14 +47,15 @@ func (t *timeline) Get(amount int, maxID string, sinceID string, minID string, p
// no params are defined to just fetch from the top
// this is equivalent to a user asking for the top x posts from their timeline
if maxID == "" && sinceID == "" && minID == "" {
- statuses, err = t.GetXFromTop(amount)
+ statuses, err = t.GetXFromTop(ctx, amount)
// aysnchronously prepare the next predicted query so it's ready when the user asks for it
if len(statuses) != 0 {
nextMaxID := statuses[len(statuses)-1].ID
if prepareNext {
// already cache the next query to speed up scrolling
go func() {
- if err := t.prepareNextQuery(amount, nextMaxID, "", ""); err != nil {
+ // use context.Background() because we don't want the query to abort when the request finishes
+ if err := t.prepareNextQuery(context.Background(), amount, nextMaxID, "", ""); err != nil {
l.Errorf("error preparing next query: %s", err)
}
}()
@@ -65,14 +67,15 @@ func (t *timeline) Get(amount int, maxID string, sinceID string, minID string, p
// this is equivalent to a user asking for the next x posts from their timeline, starting from maxID
if maxID != "" && sinceID == "" {
attempts := 0
- statuses, err = t.GetXBehindID(amount, maxID, &attempts)
+ statuses, err = t.GetXBehindID(ctx, amount, maxID, &attempts)
// aysnchronously prepare the next predicted query so it's ready when the user asks for it
if len(statuses) != 0 {
nextMaxID := statuses[len(statuses)-1].ID
if prepareNext {
// already cache the next query to speed up scrolling
go func() {
- if err := t.prepareNextQuery(amount, nextMaxID, "", ""); err != nil {
+ // use context.Background() because we don't want the query to abort when the request finishes
+ if err := t.prepareNextQuery(context.Background(), amount, nextMaxID, "", ""); err != nil {
l.Errorf("error preparing next query: %s", err)
}
}()
@@ -83,25 +86,25 @@ func (t *timeline) Get(amount int, maxID string, sinceID string, minID string, p
// maxID is defined and sinceID || minID are as well, so take a slice between them
// this is equivalent to a user asking for posts older than x but newer than y
if maxID != "" && sinceID != "" {
- statuses, err = t.GetXBetweenID(amount, maxID, minID)
+ statuses, err = t.GetXBetweenID(ctx, amount, maxID, minID)
}
if maxID != "" && minID != "" {
- statuses, err = t.GetXBetweenID(amount, maxID, minID)
+ statuses, err = t.GetXBetweenID(ctx, amount, maxID, minID)
}
// maxID isn't defined, but sinceID || minID are, so take x before
// this is equivalent to a user asking for posts newer than x (eg., refreshing the top of their timeline)
if maxID == "" && sinceID != "" {
- statuses, err = t.GetXBeforeID(amount, sinceID, true)
+ statuses, err = t.GetXBeforeID(ctx, amount, sinceID, true)
}
if maxID == "" && minID != "" {
- statuses, err = t.GetXBeforeID(amount, minID, true)
+ statuses, err = t.GetXBeforeID(ctx, amount, minID, true)
}
return statuses, err
}
-func (t *timeline) GetXFromTop(amount int) ([]*apimodel.Status, error) {
+func (t *timeline) GetXFromTop(ctx context.Context, amount int) ([]*apimodel.Status, error) {
// make a slice of statuses with the length we need to return
statuses := make([]*apimodel.Status, 0, amount)
@@ -111,7 +114,7 @@ func (t *timeline) GetXFromTop(amount int) ([]*apimodel.Status, error) {
// make sure we have enough posts prepared to return
if t.preparedPosts.data.Len() < amount {
- if err := t.PrepareFromTop(amount); err != nil {
+ if err := t.PrepareFromTop(ctx, amount); err != nil {
return nil, err
}
}
@@ -133,7 +136,7 @@ func (t *timeline) GetXFromTop(amount int) ([]*apimodel.Status, error) {
return statuses, nil
}
-func (t *timeline) GetXBehindID(amount int, behindID string, attempts *int) ([]*apimodel.Status, error) {
+func (t *timeline) GetXBehindID(ctx context.Context, amount int, behindID string, attempts *int) ([]*apimodel.Status, error) {
l := t.log.WithFields(logrus.Fields{
"func": "GetXBehindID",
"amount": amount,
@@ -174,10 +177,10 @@ findMarkLoop:
// we didn't find it, so we need to make sure it's indexed and prepared and then try again
// this can happen when a user asks for really old posts
if behindIDMark == nil {
- if err := t.PrepareBehind(behindID, amount); err != nil {
+ if err := t.PrepareBehind(ctx, behindID, amount); err != nil {
return nil, fmt.Errorf("GetXBehindID: error preparing behind and including ID %s", behindID)
}
- oldestID, err := t.OldestPreparedPostID()
+ oldestID, err := t.OldestPreparedPostID(ctx)
if err != nil {
return nil, err
}
@@ -194,12 +197,12 @@ findMarkLoop:
return statuses, nil
}
l.Trace("trying GetXBehindID again")
- return t.GetXBehindID(amount, behindID, attempts)
+ return t.GetXBehindID(ctx, amount, behindID, attempts)
}
// make sure we have enough posts prepared behind it to return what we're being asked for
if t.preparedPosts.data.Len() < amount+position {
- if err := t.PrepareBehind(behindID, amount); err != nil {
+ if err := t.PrepareBehind(ctx, behindID, amount); err != nil {
return nil, err
}
}
@@ -224,7 +227,7 @@ serveloop:
return statuses, nil
}
-func (t *timeline) GetXBeforeID(amount int, beforeID string, startFromTop bool) ([]*apimodel.Status, error) {
+func (t *timeline) GetXBeforeID(ctx context.Context, amount int, beforeID string, startFromTop bool) ([]*apimodel.Status, error) {
// make a slice of statuses with the length we need to return
statuses := make([]*apimodel.Status, 0, amount)
@@ -295,7 +298,7 @@ findMarkLoop:
return statuses, nil
}
-func (t *timeline) GetXBetweenID(amount int, behindID string, beforeID string) ([]*apimodel.Status, error) {
+func (t *timeline) GetXBetweenID(ctx context.Context, amount int, behindID string, beforeID string) ([]*apimodel.Status, error) {
// make a slice of statuses with the length we need to return
statuses := make([]*apimodel.Status, 0, amount)
@@ -327,7 +330,7 @@ findMarkLoop:
// make sure we have enough posts prepared behind it to return what we're being asked for
if t.preparedPosts.data.Len() < amount+position {
- if err := t.PrepareBehind(behindID, amount); err != nil {
+ if err := t.PrepareBehind(ctx, behindID, amount); err != nil {
return nil, err
}
}
diff --git a/internal/timeline/get_test.go b/internal/timeline/get_test.go
index 0866f3bdd..96c333c5f 100644
--- a/internal/timeline/get_test.go
+++ b/internal/timeline/get_test.go
@@ -19,6 +19,7 @@
package timeline_test
import (
+ "context"
"testing"
"time"
@@ -45,14 +46,14 @@ func (suite *GetTestSuite) SetupTest() {
testrig.StandardDBSetup(suite.db, nil)
// let's take local_account_1 as the timeline owner
- tl, err := timeline.NewTimeline(suite.testAccounts["local_account_1"].ID, suite.db, suite.tc, suite.log)
+ tl, err := timeline.NewTimeline(context.Background(), suite.testAccounts["local_account_1"].ID, suite.db, suite.tc, suite.log)
if err != nil {
suite.FailNow(err.Error())
}
// prepare the timeline by just shoving all test statuses in it -- let's not be fussy about who sees what
for _, s := range suite.testStatuses {
- _, err := tl.IndexAndPrepareOne(s.CreatedAt, s.ID, s.BoostOfID, s.AccountID, s.BoostOfAccountID)
+ _, err := tl.IndexAndPrepareOne(context.Background(), s.CreatedAt, s.ID, s.BoostOfID, s.AccountID, s.BoostOfAccountID)
if err != nil {
suite.FailNow(err.Error())
}
@@ -67,7 +68,7 @@ func (suite *GetTestSuite) TearDownTest() {
func (suite *GetTestSuite) TestGetDefault() {
// get 10 20 the top and don't prepare the next query
- statuses, err := suite.timeline.Get(20, "", "", "", false)
+ statuses, err := suite.timeline.Get(context.Background(), 20, "", "", "", false)
if err != nil {
suite.FailNow(err.Error())
}
@@ -89,7 +90,7 @@ func (suite *GetTestSuite) TestGetDefault() {
func (suite *GetTestSuite) TestGetDefaultPrepareNext() {
// get 10 from the top and prepare the next query
- statuses, err := suite.timeline.Get(10, "", "", "", true)
+ statuses, err := suite.timeline.Get(context.Background(), 10, "", "", "", true)
if err != nil {
suite.FailNow(err.Error())
}
@@ -113,7 +114,7 @@ func (suite *GetTestSuite) TestGetDefaultPrepareNext() {
func (suite *GetTestSuite) TestGetMaxID() {
// ask for 10 with a max ID somewhere in the middle of the stack
- statuses, err := suite.timeline.Get(10, "01F8MHBQCBTDKN6X5VHGMMN4MA", "", "", false)
+ statuses, err := suite.timeline.Get(context.Background(), 10, "01F8MHBQCBTDKN6X5VHGMMN4MA", "", "", false)
if err != nil {
suite.FailNow(err.Error())
}
@@ -135,7 +136,7 @@ func (suite *GetTestSuite) TestGetMaxID() {
func (suite *GetTestSuite) TestGetMaxIDPrepareNext() {
// ask for 10 with a max ID somewhere in the middle of the stack
- statuses, err := suite.timeline.Get(10, "01F8MHBQCBTDKN6X5VHGMMN4MA", "", "", true)
+ statuses, err := suite.timeline.Get(context.Background(), 10, "01F8MHBQCBTDKN6X5VHGMMN4MA", "", "", true)
if err != nil {
suite.FailNow(err.Error())
}
@@ -160,7 +161,7 @@ func (suite *GetTestSuite) TestGetMaxIDPrepareNext() {
func (suite *GetTestSuite) TestGetMinID() {
// ask for 10 with a min ID somewhere in the middle of the stack
- statuses, err := suite.timeline.Get(10, "", "01F8MHBQCBTDKN6X5VHGMMN4MA", "", false)
+ statuses, err := suite.timeline.Get(context.Background(), 10, "", "01F8MHBQCBTDKN6X5VHGMMN4MA", "", false)
if err != nil {
suite.FailNow(err.Error())
}
@@ -182,7 +183,7 @@ func (suite *GetTestSuite) TestGetMinID() {
func (suite *GetTestSuite) TestGetSinceID() {
// ask for 10 with a since ID somewhere in the middle of the stack
- statuses, err := suite.timeline.Get(10, "", "", "01F8MHBQCBTDKN6X5VHGMMN4MA", false)
+ statuses, err := suite.timeline.Get(context.Background(), 10, "", "", "01F8MHBQCBTDKN6X5VHGMMN4MA", false)
if err != nil {
suite.FailNow(err.Error())
}
@@ -204,7 +205,7 @@ func (suite *GetTestSuite) TestGetSinceID() {
func (suite *GetTestSuite) TestGetSinceIDPrepareNext() {
// ask for 10 with a since ID somewhere in the middle of the stack
- statuses, err := suite.timeline.Get(10, "", "", "01F8MHBQCBTDKN6X5VHGMMN4MA", true)
+ statuses, err := suite.timeline.Get(context.Background(), 10, "", "", "01F8MHBQCBTDKN6X5VHGMMN4MA", true)
if err != nil {
suite.FailNow(err.Error())
}
@@ -229,7 +230,7 @@ func (suite *GetTestSuite) TestGetSinceIDPrepareNext() {
func (suite *GetTestSuite) TestGetBetweenID() {
// ask for 10 between these two IDs
- statuses, err := suite.timeline.Get(10, "01F8MHCP5P2NWYQ416SBA0XSEV", "", "01F8MHBQCBTDKN6X5VHGMMN4MA", false)
+ statuses, err := suite.timeline.Get(context.Background(), 10, "01F8MHCP5P2NWYQ416SBA0XSEV", "", "01F8MHBQCBTDKN6X5VHGMMN4MA", false)
if err != nil {
suite.FailNow(err.Error())
}
@@ -251,7 +252,7 @@ func (suite *GetTestSuite) TestGetBetweenID() {
func (suite *GetTestSuite) TestGetBetweenIDPrepareNext() {
// ask for 10 between these two IDs
- statuses, err := suite.timeline.Get(10, "01F8MHCP5P2NWYQ416SBA0XSEV", "", "01F8MHBQCBTDKN6X5VHGMMN4MA", true)
+ statuses, err := suite.timeline.Get(context.Background(), 10, "01F8MHCP5P2NWYQ416SBA0XSEV", "", "01F8MHBQCBTDKN6X5VHGMMN4MA", true)
if err != nil {
suite.FailNow(err.Error())
}
@@ -276,7 +277,7 @@ func (suite *GetTestSuite) TestGetBetweenIDPrepareNext() {
func (suite *GetTestSuite) TestGetXFromTop() {
// get 5 from the top
- statuses, err := suite.timeline.GetXFromTop(5)
+ statuses, err := suite.timeline.GetXFromTop(context.Background(), 5)
if err != nil {
suite.FailNow(err.Error())
}
@@ -300,7 +301,7 @@ func (suite *GetTestSuite) TestGetXBehindID() {
var attempts *int
a := 0
attempts = &a
- statuses, err := suite.timeline.GetXBehindID(3, "01F8MHBQCBTDKN6X5VHGMMN4MA", attempts)
+ statuses, err := suite.timeline.GetXBehindID(context.Background(), 3, "01F8MHBQCBTDKN6X5VHGMMN4MA", attempts)
if err != nil {
suite.FailNow(err.Error())
}
@@ -326,7 +327,7 @@ func (suite *GetTestSuite) TestGetXBehindID0() {
var attempts *int
a := 0
attempts = &a
- statuses, err := suite.timeline.GetXBehindID(3, "0", attempts)
+ statuses, err := suite.timeline.GetXBehindID(context.Background(), 3, "0", attempts)
if err != nil {
suite.FailNow(err.Error())
}
@@ -340,7 +341,7 @@ func (suite *GetTestSuite) TestGetXBehindNonexistentReasonableID() {
var attempts *int
a := 0
attempts = &a
- statuses, err := suite.timeline.GetXBehindID(3, "01F8MHBQCBTDKN6X5VHGMMN4MB", attempts) // change the last A to a B
+ statuses, err := suite.timeline.GetXBehindID(context.Background(), 3, "01F8MHBQCBTDKN6X5VHGMMN4MB", attempts) // change the last A to a B
if err != nil {
suite.FailNow(err.Error())
}
@@ -365,7 +366,7 @@ func (suite *GetTestSuite) TestGetXBehindVeryHighID() {
var attempts *int
a := 0
attempts = &a
- statuses, err := suite.timeline.GetXBehindID(7, "9998MHBQCBTDKN6X5VHGMMN4MA", attempts)
+ statuses, err := suite.timeline.GetXBehindID(context.Background(), 7, "9998MHBQCBTDKN6X5VHGMMN4MA", attempts)
if err != nil {
suite.FailNow(err.Error())
}
@@ -389,7 +390,7 @@ func (suite *GetTestSuite) TestGetXBehindVeryHighID() {
func (suite *GetTestSuite) TestGetXBeforeID() {
// get 3 before the 'middle' id
- statuses, err := suite.timeline.GetXBeforeID(3, "01F8MHBQCBTDKN6X5VHGMMN4MA", true)
+ statuses, err := suite.timeline.GetXBeforeID(context.Background(), 3, "01F8MHBQCBTDKN6X5VHGMMN4MA", true)
if err != nil {
suite.FailNow(err.Error())
}
@@ -412,7 +413,7 @@ func (suite *GetTestSuite) TestGetXBeforeID() {
func (suite *GetTestSuite) TestGetXBeforeIDNoStartFromTop() {
// get 3 before the 'middle' id
- statuses, err := suite.timeline.GetXBeforeID(3, "01F8MHBQCBTDKN6X5VHGMMN4MA", false)
+ statuses, err := suite.timeline.GetXBeforeID(context.Background(), 3, "01F8MHBQCBTDKN6X5VHGMMN4MA", false)
if err != nil {
suite.FailNow(err.Error())
}
diff --git a/internal/timeline/index.go b/internal/timeline/index.go
index 7cffe7ab9..7d7dc8873 100644
--- a/internal/timeline/index.go
+++ b/internal/timeline/index.go
@@ -20,6 +20,7 @@ package timeline
import (
"container/list"
+ "context"
"errors"
"fmt"
"time"
@@ -29,7 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (t *timeline) IndexBefore(statusID string, include bool, amount int) error {
+func (t *timeline) IndexBefore(ctx context.Context, statusID string, include bool, amount int) error {
// lazily initialize index if it hasn't been done already
if t.postIndex.data == nil {
t.postIndex.data = &list.List{}
@@ -42,7 +43,7 @@ func (t *timeline) IndexBefore(statusID string, include bool, amount int) error
if include {
// if we have the status with given statusID in the database, include it in the results set as well
s := &gtsmodel.Status{}
- if err := t.db.GetByID(statusID, s); err == nil {
+ if err := t.db.GetByID(ctx, statusID, s); err == nil {
filtered = append(filtered, s)
}
}
@@ -50,7 +51,7 @@ func (t *timeline) IndexBefore(statusID string, include bool, amount int) error
i := 0
grabloop:
for ; len(filtered) < amount && i < 5; i = i + 1 { // try the grabloop 5 times only
- statuses, err := t.db.GetHomeTimeline(t.accountID, "", "", offsetStatus, amount, false)
+ statuses, err := t.db.GetHomeTimeline(ctx, t.accountID, "", "", offsetStatus, amount, false)
if err != nil {
if err == db.ErrNoEntries {
break grabloop // we just don't have enough statuses left in the db so index what we've got and then bail
@@ -59,7 +60,7 @@ grabloop:
}
for _, s := range statuses {
- timelineable, err := t.filter.StatusHometimelineable(s, t.account)
+ timelineable, err := t.filter.StatusHometimelineable(ctx, s, t.account)
if err != nil {
continue
}
@@ -71,7 +72,7 @@ grabloop:
}
for _, s := range filtered {
- if _, err := t.IndexOne(s.CreatedAt, s.ID, s.BoostOfID, s.AccountID, s.BoostOfAccountID); err != nil {
+ if _, err := t.IndexOne(ctx, s.CreatedAt, s.ID, s.BoostOfID, s.AccountID, s.BoostOfAccountID); err != nil {
return fmt.Errorf("IndexBefore: error indexing status with id %s: %s", s.ID, err)
}
}
@@ -79,7 +80,7 @@ grabloop:
return nil
}
-func (t *timeline) IndexBehind(statusID string, include bool, amount int) error {
+func (t *timeline) IndexBehind(ctx context.Context, statusID string, include bool, amount int) error {
l := t.log.WithFields(logrus.Fields{
"func": "IndexBehind",
"include": include,
@@ -121,7 +122,7 @@ positionLoop:
if include {
// if we have the status with given statusID in the database, include it in the results set as well
s := &gtsmodel.Status{}
- if err := t.db.GetByID(statusID, s); err == nil {
+ if err := t.db.GetByID(ctx, statusID, s); err == nil {
filtered = append(filtered, s)
}
}
@@ -130,7 +131,7 @@ positionLoop:
grabloop:
for ; len(filtered) < amount && i < 5; i = i + 1 { // try the grabloop 5 times only
l.Tracef("entering grabloop; i is %d; len(filtered) is %d", i, len(filtered))
- statuses, err := t.db.GetHomeTimeline(t.accountID, offsetStatus, "", "", amount, false)
+ statuses, err := t.db.GetHomeTimeline(ctx, t.accountID, offsetStatus, "", "", amount, false)
if err != nil {
if err == db.ErrNoEntries {
break grabloop // we just don't have enough statuses left in the db so index what we've got and then bail
@@ -140,7 +141,7 @@ grabloop:
l.Tracef("got %d statuses", len(statuses))
for _, s := range statuses {
- timelineable, err := t.filter.StatusHometimelineable(s, t.account)
+ timelineable, err := t.filter.StatusHometimelineable(ctx, s, t.account)
if err != nil {
l.Tracef("status was not hometimelineable: %s", err)
continue
@@ -154,7 +155,7 @@ grabloop:
l.Trace("left grabloop")
for _, s := range filtered {
- if _, err := t.IndexOne(s.CreatedAt, s.ID, s.BoostOfID, s.AccountID, s.BoostOfAccountID); err != nil {
+ if _, err := t.IndexOne(ctx, s.CreatedAt, s.ID, s.BoostOfID, s.AccountID, s.BoostOfAccountID); err != nil {
return fmt.Errorf("IndexBehind: error indexing status with id %s: %s", s.ID, err)
}
}
@@ -163,7 +164,7 @@ grabloop:
return nil
}
-func (t *timeline) IndexOne(statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) {
+func (t *timeline) IndexOne(ctx context.Context, statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) {
t.Lock()
defer t.Unlock()
@@ -177,7 +178,7 @@ func (t *timeline) IndexOne(statusCreatedAt time.Time, statusID string, boostOfI
return t.postIndex.insertIndexed(postIndexEntry)
}
-func (t *timeline) IndexAndPrepareOne(statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) {
+func (t *timeline) IndexAndPrepareOne(ctx context.Context, statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error) {
t.Lock()
defer t.Unlock()
@@ -194,7 +195,7 @@ func (t *timeline) IndexAndPrepareOne(statusCreatedAt time.Time, statusID string
}
if inserted {
- if err := t.prepare(statusID); err != nil {
+ if err := t.prepare(ctx, statusID); err != nil {
return inserted, fmt.Errorf("IndexAndPrepareOne: error preparing: %s", err)
}
}
@@ -202,7 +203,7 @@ func (t *timeline) IndexAndPrepareOne(statusCreatedAt time.Time, statusID string
return inserted, nil
}
-func (t *timeline) OldestIndexedPostID() (string, error) {
+func (t *timeline) OldestIndexedPostID(ctx context.Context) (string, error) {
var id string
if t.postIndex == nil || t.postIndex.data == nil || t.postIndex.data.Back() == nil {
// return an empty string if postindex hasn't been initialized yet
@@ -217,7 +218,7 @@ func (t *timeline) OldestIndexedPostID() (string, error) {
return entry.statusID, nil
}
-func (t *timeline) NewestIndexedPostID() (string, error) {
+func (t *timeline) NewestIndexedPostID(ctx context.Context) (string, error) {
var id string
if t.postIndex == nil || t.postIndex.data == nil || t.postIndex.data.Front() == nil {
// return an empty string if postindex hasn't been initialized yet
diff --git a/internal/timeline/index_test.go b/internal/timeline/index_test.go
index 4201a27dd..25565a1de 100644
--- a/internal/timeline/index_test.go
+++ b/internal/timeline/index_test.go
@@ -19,6 +19,7 @@
package timeline_test
import (
+ "context"
"testing"
"time"
@@ -46,7 +47,7 @@ func (suite *IndexTestSuite) SetupTest() {
testrig.StandardDBSetup(suite.db, nil)
// let's take local_account_1 as the timeline owner, and start with an empty timeline
- tl, err := timeline.NewTimeline(suite.testAccounts["local_account_1"].ID, suite.db, suite.tc, suite.log)
+ tl, err := timeline.NewTimeline(context.Background(), suite.testAccounts["local_account_1"].ID, suite.db, suite.tc, suite.log)
if err != nil {
suite.FailNow(err.Error())
}
@@ -59,82 +60,82 @@ func (suite *IndexTestSuite) TearDownTest() {
func (suite *IndexTestSuite) TestIndexBeforeLowID() {
// index 10 before the lowest status ID possible
- err := suite.timeline.IndexBefore("00000000000000000000000000", true, 10)
+ err := suite.timeline.IndexBefore(context.Background(), "00000000000000000000000000", true, 10)
suite.NoError(err)
// the oldest indexed post should be the lowest one we have in our testrig
- postID, err := suite.timeline.OldestIndexedPostID()
+ postID, err := suite.timeline.OldestIndexedPostID(context.Background())
suite.NoError(err)
suite.Equal("01F8MHAAY43M6RJ473VQFCVH37", postID)
- indexLength := suite.timeline.PostIndexLength()
+ indexLength := suite.timeline.PostIndexLength(context.Background())
suite.Equal(10, indexLength)
}
func (suite *IndexTestSuite) TestIndexBeforeHighID() {
// index 10 before the highest status ID possible
- err := suite.timeline.IndexBefore("ZZZZZZZZZZZZZZZZZZZZZZZZZZ", true, 10)
+ err := suite.timeline.IndexBefore(context.Background(), "ZZZZZZZZZZZZZZZZZZZZZZZZZZ", true, 10)
suite.NoError(err)
// the oldest indexed post should be empty
- postID, err := suite.timeline.OldestIndexedPostID()
+ postID, err := suite.timeline.OldestIndexedPostID(context.Background())
suite.NoError(err)
suite.Empty(postID)
// indexLength should be 0
- indexLength := suite.timeline.PostIndexLength()
+ indexLength := suite.timeline.PostIndexLength(context.Background())
suite.Equal(0, indexLength)
}
func (suite *IndexTestSuite) TestIndexBehindHighID() {
// index 10 behind the highest status ID possible
- err := suite.timeline.IndexBehind("ZZZZZZZZZZZZZZZZZZZZZZZZZZ", true, 10)
+ err := suite.timeline.IndexBehind(context.Background(), "ZZZZZZZZZZZZZZZZZZZZZZZZZZ", true, 10)
suite.NoError(err)
// the newest indexed post should be the highest one we have in our testrig
- postID, err := suite.timeline.NewestIndexedPostID()
+ postID, err := suite.timeline.NewestIndexedPostID(context.Background())
suite.NoError(err)
suite.Equal("01FCTA44PW9H1TB328S9AQXKDS", postID)
// indexLength should be 10 because that's all this user has hometimelineable
- indexLength := suite.timeline.PostIndexLength()
+ indexLength := suite.timeline.PostIndexLength(context.Background())
suite.Equal(10, indexLength)
}
func (suite *IndexTestSuite) TestIndexBehindLowID() {
// index 10 behind the lowest status ID possible
- err := suite.timeline.IndexBehind("00000000000000000000000000", true, 10)
+ err := suite.timeline.IndexBehind(context.Background(), "00000000000000000000000000", true, 10)
suite.NoError(err)
// the newest indexed post should be empty
- postID, err := suite.timeline.NewestIndexedPostID()
+ postID, err := suite.timeline.NewestIndexedPostID(context.Background())
suite.NoError(err)
suite.Empty(postID)
// indexLength should be 0
- indexLength := suite.timeline.PostIndexLength()
+ indexLength := suite.timeline.PostIndexLength(context.Background())
suite.Equal(0, indexLength)
}
func (suite *IndexTestSuite) TestOldestIndexedPostIDEmpty() {
// the oldest indexed post should be an empty string since there's nothing indexed yet
- postID, err := suite.timeline.OldestIndexedPostID()
+ postID, err := suite.timeline.OldestIndexedPostID(context.Background())
suite.NoError(err)
suite.Empty(postID)
// indexLength should be 0
- indexLength := suite.timeline.PostIndexLength()
+ indexLength := suite.timeline.PostIndexLength(context.Background())
suite.Equal(0, indexLength)
}
func (suite *IndexTestSuite) TestNewestIndexedPostIDEmpty() {
// the newest indexed post should be an empty string since there's nothing indexed yet
- postID, err := suite.timeline.NewestIndexedPostID()
+ postID, err := suite.timeline.NewestIndexedPostID(context.Background())
suite.NoError(err)
suite.Empty(postID)
// indexLength should be 0
- indexLength := suite.timeline.PostIndexLength()
+ indexLength := suite.timeline.PostIndexLength(context.Background())
suite.Equal(0, indexLength)
}
@@ -142,12 +143,12 @@ func (suite *IndexTestSuite) TestIndexAlreadyIndexed() {
testStatus := suite.testStatuses["local_account_1_status_1"]
// index one post -- it should be indexed
- indexed, err := suite.timeline.IndexOne(testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID)
+ indexed, err := suite.timeline.IndexOne(context.Background(), testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID)
suite.NoError(err)
suite.True(indexed)
// try to index the same post again -- it should not be indexed
- indexed, err = suite.timeline.IndexOne(testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID)
+ indexed, err = suite.timeline.IndexOne(context.Background(), testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID)
suite.NoError(err)
suite.False(indexed)
}
@@ -156,12 +157,12 @@ func (suite *IndexTestSuite) TestIndexAndPrepareAlreadyIndexedAndPrepared() {
testStatus := suite.testStatuses["local_account_1_status_1"]
// index and prepare one post -- it should be indexed
- indexed, err := suite.timeline.IndexAndPrepareOne(testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID)
+ indexed, err := suite.timeline.IndexAndPrepareOne(context.Background(), testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID)
suite.NoError(err)
suite.True(indexed)
// try to index and prepare the same post again -- it should not be indexed
- indexed, err = suite.timeline.IndexAndPrepareOne(testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID)
+ indexed, err = suite.timeline.IndexAndPrepareOne(context.Background(), testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID)
suite.NoError(err)
suite.False(indexed)
}
@@ -177,12 +178,12 @@ func (suite *IndexTestSuite) TestIndexBoostOfAlreadyIndexed() {
}
// index one post -- it should be indexed
- indexed, err := suite.timeline.IndexOne(testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID)
+ indexed, err := suite.timeline.IndexOne(context.Background(), testStatus.CreatedAt, testStatus.ID, testStatus.BoostOfID, testStatus.AccountID, testStatus.BoostOfAccountID)
suite.NoError(err)
suite.True(indexed)
// try to index the a boost of that post -- it should not be indexed
- indexed, err = suite.timeline.IndexOne(boostOfTestStatus.CreatedAt, boostOfTestStatus.ID, boostOfTestStatus.BoostOfID, boostOfTestStatus.AccountID, boostOfTestStatus.BoostOfAccountID)
+ indexed, err = suite.timeline.IndexOne(context.Background(), boostOfTestStatus.CreatedAt, boostOfTestStatus.ID, boostOfTestStatus.BoostOfID, boostOfTestStatus.AccountID, boostOfTestStatus.BoostOfAccountID)
suite.NoError(err)
suite.False(indexed)
}
diff --git a/internal/timeline/manager.go b/internal/timeline/manager.go
index a592670a8..7f42e2f51 100644
--- a/internal/timeline/manager.go
+++ b/internal/timeline/manager.go
@@ -19,6 +19,7 @@
package timeline
import (
+ "context"
"fmt"
"strings"
"sync"
@@ -54,7 +55,7 @@ type Manager interface {
//
// The returned bool indicates whether the status was actually put in the timeline. This could be false in cases where
// the status is a boost, but a boost of the original post or the post itself already exists recently in the timeline.
- Ingest(status *gtsmodel.Status, timelineAccountID string) (bool, error)
+ Ingest(ctx context.Context, status *gtsmodel.Status, timelineAccountID string) (bool, error)
// IngestAndPrepare takes one status and indexes it into the timeline for the given account ID, and then immediately prepares it for serving.
// This is useful in cases where we know the status will need to be shown at the top of a user's timeline immediately (eg., a new status is created).
//
@@ -62,24 +63,24 @@ type Manager interface {
//
// The returned bool indicates whether the status was actually put in the timeline. This could be false in cases where
// the status is a boost, but a boost of the original post or the post itself already exists recently in the timeline.
- IngestAndPrepare(status *gtsmodel.Status, timelineAccountID string) (bool, error)
+ IngestAndPrepare(ctx context.Context, status *gtsmodel.Status, timelineAccountID string) (bool, error)
// HomeTimeline returns limit n amount of entries from the home timeline of the given account ID, in descending chronological order.
// If maxID is provided, it will return entries from that maxID onwards, inclusive.
- HomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*apimodel.Status, error)
+ HomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*apimodel.Status, error)
// GetIndexedLength returns the amount of posts/statuses that have been *indexed* for the given account ID.
- GetIndexedLength(timelineAccountID string) int
+ GetIndexedLength(ctx context.Context, timelineAccountID string) int
// GetDesiredIndexLength returns the amount of posts that we, ideally, index for each user.
- GetDesiredIndexLength() int
+ GetDesiredIndexLength(ctx context.Context) int
// GetOldestIndexedID returns the status ID for the oldest post that we have indexed for the given account.
- GetOldestIndexedID(timelineAccountID string) (string, error)
+ GetOldestIndexedID(ctx context.Context, timelineAccountID string) (string, error)
// PrepareXFromTop prepares limit n amount of posts, based on their indexed representations, from the top of the index.
- PrepareXFromTop(timelineAccountID string, limit int) error
+ PrepareXFromTop(ctx context.Context, timelineAccountID string, limit int) error
// Remove removes one status from the timeline of the given timelineAccountID
- Remove(timelineAccountID string, statusID string) (int, error)
+ Remove(ctx context.Context, timelineAccountID string, statusID string) (int, error)
// WipeStatusFromAllTimelines removes one status from the index and prepared posts of all timelines
- WipeStatusFromAllTimelines(statusID string) error
+ WipeStatusFromAllTimelines(ctx context.Context, statusID string) error
// WipeStatusesFromAccountID removes all statuses by the given accountID from the timelineAccountID's timelines.
- WipeStatusesFromAccountID(timelineAccountID string, accountID string) error
+ WipeStatusesFromAccountID(ctx context.Context, timelineAccountID string, accountID string) error
}
// NewManager returns a new timeline manager with the given database, typeconverter, config, and log.
@@ -101,104 +102,104 @@ type manager struct {
log *logrus.Logger
}
-func (m *manager) Ingest(status *gtsmodel.Status, timelineAccountID string) (bool, error) {
+func (m *manager) Ingest(ctx context.Context, status *gtsmodel.Status, timelineAccountID string) (bool, error) {
l := m.log.WithFields(logrus.Fields{
"func": "Ingest",
"timelineAccountID": timelineAccountID,
"statusID": status.ID,
})
- t, err := m.getOrCreateTimeline(timelineAccountID)
+ t, err := m.getOrCreateTimeline(ctx, timelineAccountID)
if err != nil {
return false, err
}
l.Trace("ingesting status")
- return t.IndexOne(status.CreatedAt, status.ID, status.BoostOfID, status.AccountID, status.BoostOfAccountID)
+ return t.IndexOne(ctx, status.CreatedAt, status.ID, status.BoostOfID, status.AccountID, status.BoostOfAccountID)
}
-func (m *manager) IngestAndPrepare(status *gtsmodel.Status, timelineAccountID string) (bool, error) {
+func (m *manager) IngestAndPrepare(ctx context.Context, status *gtsmodel.Status, timelineAccountID string) (bool, error) {
l := m.log.WithFields(logrus.Fields{
"func": "IngestAndPrepare",
"timelineAccountID": timelineAccountID,
"statusID": status.ID,
})
- t, err := m.getOrCreateTimeline(timelineAccountID)
+ t, err := m.getOrCreateTimeline(ctx, timelineAccountID)
if err != nil {
return false, err
}
l.Trace("ingesting status")
- return t.IndexAndPrepareOne(status.CreatedAt, status.ID, status.BoostOfID, status.AccountID, status.BoostOfAccountID)
+ return t.IndexAndPrepareOne(ctx, status.CreatedAt, status.ID, status.BoostOfID, status.AccountID, status.BoostOfAccountID)
}
-func (m *manager) Remove(timelineAccountID string, statusID string) (int, error) {
+func (m *manager) Remove(ctx context.Context, timelineAccountID string, statusID string) (int, error) {
l := m.log.WithFields(logrus.Fields{
"func": "Remove",
"timelineAccountID": timelineAccountID,
"statusID": statusID,
})
- t, err := m.getOrCreateTimeline(timelineAccountID)
+ t, err := m.getOrCreateTimeline(ctx, timelineAccountID)
if err != nil {
return 0, err
}
l.Trace("removing status")
- return t.Remove(statusID)
+ return t.Remove(ctx, statusID)
}
-func (m *manager) HomeTimeline(timelineAccountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*apimodel.Status, error) {
+func (m *manager) HomeTimeline(ctx context.Context, timelineAccountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*apimodel.Status, error) {
l := m.log.WithFields(logrus.Fields{
"func": "HomeTimelineGet",
"timelineAccountID": timelineAccountID,
})
- t, err := m.getOrCreateTimeline(timelineAccountID)
+ t, err := m.getOrCreateTimeline(ctx, timelineAccountID)
if err != nil {
return nil, err
}
- statuses, err := t.Get(limit, maxID, sinceID, minID, true)
+ statuses, err := t.Get(ctx, limit, maxID, sinceID, minID, true)
if err != nil {
l.Errorf("error getting statuses: %s", err)
}
return statuses, nil
}
-func (m *manager) GetIndexedLength(timelineAccountID string) int {
- t, err := m.getOrCreateTimeline(timelineAccountID)
+func (m *manager) GetIndexedLength(ctx context.Context, timelineAccountID string) int {
+ t, err := m.getOrCreateTimeline(ctx, timelineAccountID)
if err != nil {
return 0
}
- return t.PostIndexLength()
+ return t.PostIndexLength(ctx)
}
-func (m *manager) GetDesiredIndexLength() int {
+func (m *manager) GetDesiredIndexLength(ctx context.Context) int {
return desiredPostIndexLength
}
-func (m *manager) GetOldestIndexedID(timelineAccountID string) (string, error) {
- t, err := m.getOrCreateTimeline(timelineAccountID)
+func (m *manager) GetOldestIndexedID(ctx context.Context, timelineAccountID string) (string, error) {
+ t, err := m.getOrCreateTimeline(ctx, timelineAccountID)
if err != nil {
return "", err
}
- return t.OldestIndexedPostID()
+ return t.OldestIndexedPostID(ctx)
}
-func (m *manager) PrepareXFromTop(timelineAccountID string, limit int) error {
- t, err := m.getOrCreateTimeline(timelineAccountID)
+func (m *manager) PrepareXFromTop(ctx context.Context, timelineAccountID string, limit int) error {
+ t, err := m.getOrCreateTimeline(ctx, timelineAccountID)
if err != nil {
return err
}
- return t.PrepareFromTop(limit)
+ return t.PrepareFromTop(ctx, limit)
}
-func (m *manager) WipeStatusFromAllTimelines(statusID string) error {
+func (m *manager) WipeStatusFromAllTimelines(ctx context.Context, statusID string) error {
errors := []string{}
m.accountTimelines.Range(func(k interface{}, i interface{}) bool {
t, ok := i.(Timeline)
@@ -206,7 +207,7 @@ func (m *manager) WipeStatusFromAllTimelines(statusID string) error {
panic("couldn't parse entry as Timeline, this should never happen so panic")
}
- if _, err := t.Remove(statusID); err != nil {
+ if _, err := t.Remove(ctx, statusID); err != nil {
errors = append(errors, err.Error())
}
@@ -221,22 +222,22 @@ func (m *manager) WipeStatusFromAllTimelines(statusID string) error {
return err
}
-func (m *manager) WipeStatusesFromAccountID(timelineAccountID string, accountID string) error {
- t, err := m.getOrCreateTimeline(timelineAccountID)
+func (m *manager) WipeStatusesFromAccountID(ctx context.Context, timelineAccountID string, accountID string) error {
+ t, err := m.getOrCreateTimeline(ctx, timelineAccountID)
if err != nil {
return err
}
- _, err = t.RemoveAllBy(accountID)
+ _, err = t.RemoveAllBy(ctx, accountID)
return err
}
-func (m *manager) getOrCreateTimeline(timelineAccountID string) (Timeline, error) {
+func (m *manager) getOrCreateTimeline(ctx context.Context, timelineAccountID string) (Timeline, error) {
var t Timeline
i, ok := m.accountTimelines.Load(timelineAccountID)
if !ok {
var err error
- t, err = NewTimeline(timelineAccountID, m.db, m.tc, m.log)
+ t, err = NewTimeline(ctx, timelineAccountID, m.db, m.tc, m.log)
if err != nil {
return nil, err
}
diff --git a/internal/timeline/manager_test.go b/internal/timeline/manager_test.go
index 00c6dcb4a..ea4dc4c12 100644
--- a/internal/timeline/manager_test.go
+++ b/internal/timeline/manager_test.go
@@ -19,6 +19,7 @@
package timeline_test
import (
+ "context"
"testing"
"github.com/stretchr/testify/suite"
@@ -54,85 +55,85 @@ func (suite *ManagerTestSuite) TestManagerIntegration() {
testAccount := suite.testAccounts["local_account_1"]
// should start at 0
- indexedLen := suite.manager.GetIndexedLength(testAccount.ID)
+ indexedLen := suite.manager.GetIndexedLength(context.Background(), testAccount.ID)
suite.Equal(0, indexedLen)
// oldestIndexed should be empty string since there's nothing indexed
- oldestIndexed, err := suite.manager.GetOldestIndexedID(testAccount.ID)
+ oldestIndexed, err := suite.manager.GetOldestIndexedID(context.Background(), testAccount.ID)
suite.NoError(err)
suite.Empty(oldestIndexed)
// trigger status preparation
- err = suite.manager.PrepareXFromTop(testAccount.ID, 20)
+ err = suite.manager.PrepareXFromTop(context.Background(), testAccount.ID, 20)
suite.NoError(err)
// local_account_1 can see 12 statuses out of the testrig statuses in its home timeline
- indexedLen = suite.manager.GetIndexedLength(testAccount.ID)
+ indexedLen = suite.manager.GetIndexedLength(context.Background(), testAccount.ID)
suite.Equal(12, indexedLen)
// oldest should now be set
- oldestIndexed, err = suite.manager.GetOldestIndexedID(testAccount.ID)
+ oldestIndexed, err = suite.manager.GetOldestIndexedID(context.Background(), testAccount.ID)
suite.NoError(err)
suite.Equal("01F8MH75CBF9JFX4ZAD54N0W0R", oldestIndexed)
// get hometimeline
- statuses, err := suite.manager.HomeTimeline(testAccount.ID, "", "", "", 20, false)
+ statuses, err := suite.manager.HomeTimeline(context.Background(), testAccount.ID, "", "", "", 20, false)
suite.NoError(err)
suite.Len(statuses, 12)
// now wipe the last status from all timelines, as though it had been deleted by the owner
- err = suite.manager.WipeStatusFromAllTimelines("01F8MH75CBF9JFX4ZAD54N0W0R")
+ err = suite.manager.WipeStatusFromAllTimelines(context.Background(), "01F8MH75CBF9JFX4ZAD54N0W0R")
suite.NoError(err)
// timeline should be shorter
- indexedLen = suite.manager.GetIndexedLength(testAccount.ID)
+ indexedLen = suite.manager.GetIndexedLength(context.Background(), testAccount.ID)
suite.Equal(11, indexedLen)
// oldest should now be different
- oldestIndexed, err = suite.manager.GetOldestIndexedID(testAccount.ID)
+ oldestIndexed, err = suite.manager.GetOldestIndexedID(context.Background(), testAccount.ID)
suite.NoError(err)
suite.Equal("01F8MH82FYRXD2RC6108DAJ5HB", oldestIndexed)
// delete the new oldest status specifically from this timeline, as though local_account_1 had muted or blocked it
- removed, err := suite.manager.Remove(testAccount.ID, "01F8MH82FYRXD2RC6108DAJ5HB")
+ removed, err := suite.manager.Remove(context.Background(), testAccount.ID, "01F8MH82FYRXD2RC6108DAJ5HB")
suite.NoError(err)
suite.Equal(2, removed) // 1 status should be removed, but from both indexed and prepared, so 2 removals total
// timeline should be shorter
- indexedLen = suite.manager.GetIndexedLength(testAccount.ID)
+ indexedLen = suite.manager.GetIndexedLength(context.Background(), testAccount.ID)
suite.Equal(10, indexedLen)
// oldest should now be different
- oldestIndexed, err = suite.manager.GetOldestIndexedID(testAccount.ID)
+ oldestIndexed, err = suite.manager.GetOldestIndexedID(context.Background(), testAccount.ID)
suite.NoError(err)
suite.Equal("01F8MHAAY43M6RJ473VQFCVH37", oldestIndexed)
// now remove all entries by local_account_2 from the timeline
- err = suite.manager.WipeStatusesFromAccountID(testAccount.ID, suite.testAccounts["local_account_2"].ID)
+ err = suite.manager.WipeStatusesFromAccountID(context.Background(), testAccount.ID, suite.testAccounts["local_account_2"].ID)
suite.NoError(err)
// timeline should be empty now
- indexedLen = suite.manager.GetIndexedLength(testAccount.ID)
+ indexedLen = suite.manager.GetIndexedLength(context.Background(), testAccount.ID)
suite.Equal(5, indexedLen)
// ingest 1 into the timeline
status1 := suite.testStatuses["admin_account_status_1"]
- ingested, err := suite.manager.Ingest(status1, testAccount.ID)
+ ingested, err := suite.manager.Ingest(context.Background(), status1, testAccount.ID)
suite.NoError(err)
suite.True(ingested)
// ingest and prepare another one into the timeline
status2 := suite.testStatuses["local_account_2_status_1"]
- ingested, err = suite.manager.IngestAndPrepare(status2, testAccount.ID)
+ ingested, err = suite.manager.IngestAndPrepare(context.Background(), status2, testAccount.ID)
suite.NoError(err)
suite.True(ingested)
// timeline should be longer now
- indexedLen = suite.manager.GetIndexedLength(testAccount.ID)
+ indexedLen = suite.manager.GetIndexedLength(context.Background(), testAccount.ID)
suite.Equal(7, indexedLen)
// try to ingest status 2 again
- ingested, err = suite.manager.IngestAndPrepare(status2, testAccount.ID)
+ ingested, err = suite.manager.IngestAndPrepare(context.Background(), status2, testAccount.ID)
suite.NoError(err)
suite.False(ingested) // should be false since it's a duplicate
}
diff --git a/internal/timeline/prepare.go b/internal/timeline/prepare.go
index 20000b4e9..d57222ee8 100644
--- a/internal/timeline/prepare.go
+++ b/internal/timeline/prepare.go
@@ -20,6 +20,7 @@ package timeline
import (
"container/list"
+ "context"
"errors"
"fmt"
@@ -28,7 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (t *timeline) prepareNextQuery(amount int, maxID string, sinceID string, minID string) error {
+func (t *timeline) prepareNextQuery(ctx context.Context, amount int, maxID string, sinceID string, minID string) error {
l := t.log.WithFields(logrus.Fields{
"func": "prepareNextQuery",
"amount": amount,
@@ -42,30 +43,30 @@ func (t *timeline) prepareNextQuery(amount int, maxID string, sinceID string, mi
// maxID is defined but sinceID isn't so take from behind
if maxID != "" && sinceID == "" {
l.Debug("preparing behind maxID")
- err = t.PrepareBehind(maxID, amount)
+ err = t.PrepareBehind(ctx, maxID, amount)
}
// maxID isn't defined, but sinceID || minID are, so take x before
if maxID == "" && sinceID != "" {
l.Debug("preparing before sinceID")
- err = t.PrepareBefore(sinceID, false, amount)
+ err = t.PrepareBefore(ctx, sinceID, false, amount)
}
if maxID == "" && minID != "" {
l.Debug("preparing before minID")
- err = t.PrepareBefore(minID, false, amount)
+ err = t.PrepareBefore(ctx, minID, false, amount)
}
return err
}
-func (t *timeline) PrepareBehind(statusID string, amount int) error {
+func (t *timeline) PrepareBehind(ctx context.Context, statusID string, amount int) error {
// lazily initialize prepared posts if it hasn't been done already
if t.preparedPosts.data == nil {
t.preparedPosts.data = &list.List{}
t.preparedPosts.data.Init()
}
- if err := t.IndexBehind(statusID, true, amount); err != nil {
+ if err := t.IndexBehind(ctx, statusID, true, amount); err != nil {
return fmt.Errorf("PrepareBehind: error indexing behind id %s: %s", statusID, err)
}
@@ -93,7 +94,7 @@ prepareloop:
}
if preparing {
- if err := t.prepare(entry.statusID); err != nil {
+ if err := t.prepare(ctx, entry.statusID); err != nil {
// there's been an error
if err != db.ErrNoEntries {
// it's a real error
@@ -113,7 +114,7 @@ prepareloop:
return nil
}
-func (t *timeline) PrepareBefore(statusID string, include bool, amount int) error {
+func (t *timeline) PrepareBefore(ctx context.Context, statusID string, include bool, amount int) error {
t.Lock()
defer t.Unlock()
@@ -148,7 +149,7 @@ prepareloop:
}
if preparing {
- if err := t.prepare(entry.statusID); err != nil {
+ if err := t.prepare(ctx, entry.statusID); err != nil {
// there's been an error
if err != db.ErrNoEntries {
// it's a real error
@@ -168,7 +169,7 @@ prepareloop:
return nil
}
-func (t *timeline) PrepareFromTop(amount int) error {
+func (t *timeline) PrepareFromTop(ctx context.Context, amount int) error {
l := t.log.WithFields(logrus.Fields{
"func": "PrepareFromTop",
"amount": amount,
@@ -183,7 +184,7 @@ func (t *timeline) PrepareFromTop(amount int) error {
// if the postindex is nil, nothing has been indexed yet so index from the highest ID possible
if t.postIndex.data == nil {
l.Debug("postindex.data was nil, indexing behind highest possible ID")
- if err := t.IndexBehind("ZZZZZZZZZZZZZZZZZZZZZZZZZZ", false, amount); err != nil {
+ if err := t.IndexBehind(ctx, "ZZZZZZZZZZZZZZZZZZZZZZZZZZ", false, amount); err != nil {
return fmt.Errorf("PrepareFromTop: error indexing behind id %s: %s", "ZZZZZZZZZZZZZZZZZZZZZZZZZZ", err)
}
}
@@ -203,7 +204,7 @@ prepareloop:
return errors.New("PrepareFromTop: could not parse e as a postIndexEntry")
}
- if err := t.prepare(entry.statusID); err != nil {
+ if err := t.prepare(ctx, entry.statusID); err != nil {
// there's been an error
if err != db.ErrNoEntries {
// it's a real error
@@ -225,25 +226,25 @@ prepareloop:
return nil
}
-func (t *timeline) prepare(statusID string) error {
+func (t *timeline) prepare(ctx context.Context, statusID string) error {
// start by getting the status out of the database according to its indexed ID
gtsStatus := &gtsmodel.Status{}
- if err := t.db.GetByID(statusID, gtsStatus); err != nil {
+ if err := t.db.GetByID(ctx, statusID, gtsStatus); err != nil {
return err
}
// if the account pointer hasn't been set on this timeline already, set it lazily here
if t.account == nil {
timelineOwnerAccount := &gtsmodel.Account{}
- if err := t.db.GetByID(t.accountID, timelineOwnerAccount); err != nil {
+ if err := t.db.GetByID(ctx, t.accountID, timelineOwnerAccount); err != nil {
return err
}
t.account = timelineOwnerAccount
}
// serialize the status (or, at least, convert it to a form that's ready to be serialized)
- apiModelStatus, err := t.tc.StatusToMasto(gtsStatus, t.account)
+ apiModelStatus, err := t.tc.StatusToMasto(ctx, gtsStatus, t.account)
if err != nil {
return err
}
@@ -260,7 +261,7 @@ func (t *timeline) prepare(statusID string) error {
return t.preparedPosts.insertPrepared(preparedPostsEntry)
}
-func (t *timeline) OldestPreparedPostID() (string, error) {
+func (t *timeline) OldestPreparedPostID(ctx context.Context) (string, error) {
var id string
if t.preparedPosts == nil || t.preparedPosts.data == nil {
// return an empty string if prepared posts hasn't been initialized yet
diff --git a/internal/timeline/remove.go b/internal/timeline/remove.go
index cf0b0b617..031dace1f 100644
--- a/internal/timeline/remove.go
+++ b/internal/timeline/remove.go
@@ -20,12 +20,13 @@ package timeline
import (
"container/list"
+ "context"
"errors"
"github.com/sirupsen/logrus"
)
-func (t *timeline) Remove(statusID string) (int, error) {
+func (t *timeline) Remove(ctx context.Context, statusID string) (int, error) {
l := t.log.WithFields(logrus.Fields{
"func": "Remove",
"accountTimeline": t.accountID,
@@ -77,7 +78,7 @@ func (t *timeline) Remove(statusID string) (int, error) {
return removed, nil
}
-func (t *timeline) RemoveAllBy(accountID string) (int, error) {
+func (t *timeline) RemoveAllBy(ctx context.Context, accountID string) (int, error) {
l := t.log.WithFields(logrus.Fields{
"func": "RemoveAllBy",
"accountTimeline": t.accountID,
diff --git a/internal/timeline/timeline.go b/internal/timeline/timeline.go
index 6274a86ac..5f5fa1b4f 100644
--- a/internal/timeline/timeline.go
+++ b/internal/timeline/timeline.go
@@ -19,6 +19,7 @@
package timeline
import (
+ "context"
"sync"
"time"
@@ -41,24 +42,24 @@ type Timeline interface {
// Get returns an amount of statuses with the given parameters.
// If prepareNext is true, then the next predicted query will be prepared already in a goroutine,
// to make the next call to Get faster.
- Get(amount int, maxID string, sinceID string, minID string, prepareNext bool) ([]*apimodel.Status, error)
+ Get(ctx context.Context, amount int, maxID string, sinceID string, minID string, prepareNext bool) ([]*apimodel.Status, error)
// GetXFromTop returns x amount of posts from the top of the timeline, from newest to oldest.
- GetXFromTop(amount int) ([]*apimodel.Status, error)
+ GetXFromTop(ctx context.Context, amount int) ([]*apimodel.Status, error)
// GetXBehindID returns x amount of posts from the given id onwards, from newest to oldest.
// This will NOT include the status with the given ID.
//
// This corresponds to an api call to /timelines/home?max_id=WHATEVER
- GetXBehindID(amount int, fromID string, attempts *int) ([]*apimodel.Status, error)
+ GetXBehindID(ctx context.Context, amount int, fromID string, attempts *int) ([]*apimodel.Status, error)
// GetXBeforeID returns x amount of posts up to the given id, from newest to oldest.
// This will NOT include the status with the given ID.
//
// This corresponds to an api call to /timelines/home?since_id=WHATEVER
- GetXBeforeID(amount int, sinceID string, startFromTop bool) ([]*apimodel.Status, error)
+ GetXBeforeID(ctx context.Context, amount int, sinceID string, startFromTop bool) ([]*apimodel.Status, error)
// GetXBetweenID returns x amount of posts from the given maxID, up to the given id, from newest to oldest.
// This will NOT include the status with the given IDs.
//
// This corresponds to an api call to /timelines/home?since_id=WHATEVER&max_id=WHATEVER_ELSE
- GetXBetweenID(amount int, maxID string, sinceID string) ([]*apimodel.Status, error)
+ GetXBetweenID(ctx context.Context, amount int, maxID string, sinceID string) ([]*apimodel.Status, error)
/*
INDEXING FUNCTIONS
@@ -68,43 +69,43 @@ type Timeline interface {
//
// The returned bool indicates whether or not the status was actually inserted into the timeline. This will be false
// if the status is a boost and the original post or another boost of it already exists < boostReinsertionDepth back in the timeline.
- IndexOne(statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error)
+ IndexOne(ctx context.Context, statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error)
// OldestIndexedPostID returns the id of the rearmost (ie., the oldest) indexed post, or an error if something goes wrong.
// If nothing goes wrong but there's no oldest post, an empty string will be returned so make sure to check for this.
- OldestIndexedPostID() (string, error)
+ OldestIndexedPostID(ctx context.Context) (string, error)
// NewestIndexedPostID returns the id of the frontmost (ie., the newest) indexed post, or an error if something goes wrong.
// If nothing goes wrong but there's no newest post, an empty string will be returned so make sure to check for this.
- NewestIndexedPostID() (string, error)
+ NewestIndexedPostID(ctx context.Context) (string, error)
- IndexBefore(statusID string, include bool, amount int) error
- IndexBehind(statusID string, include bool, amount int) error
+ IndexBefore(ctx context.Context, statusID string, include bool, amount int) error
+ IndexBehind(ctx context.Context, statusID string, include bool, amount int) error
/*
PREPARATION FUNCTIONS
*/
// PrepareXFromTop instructs the timeline to prepare x amount of posts from the top of the timeline.
- PrepareFromTop(amount int) error
+ PrepareFromTop(ctx context.Context, amount int) error
// PrepareBehind instructs the timeline to prepare the next amount of entries for serialization, from position onwards.
// If include is true, then the given status ID will also be prepared, otherwise only entries behind it will be prepared.
- PrepareBehind(statusID string, amount int) error
+ PrepareBehind(ctx context.Context, statusID string, amount int) error
// IndexOne puts a status into the timeline at the appropriate place according to its 'createdAt' property,
// and then immediately prepares it.
//
// The returned bool indicates whether or not the status was actually inserted into the timeline. This will be false
// if the status is a boost and the original post or another boost of it already exists < boostReinsertionDepth back in the timeline.
- IndexAndPrepareOne(statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error)
+ IndexAndPrepareOne(ctx context.Context, statusCreatedAt time.Time, statusID string, boostOfID string, accountID string, boostOfAccountID string) (bool, error)
// OldestPreparedPostID returns the id of the rearmost (ie., the oldest) prepared post, or an error if something goes wrong.
// If nothing goes wrong but there's no oldest post, an empty string will be returned so make sure to check for this.
- OldestPreparedPostID() (string, error)
+ OldestPreparedPostID(ctx context.Context) (string, error)
/*
INFO FUNCTIONS
*/
// ActualPostIndexLength returns the actual length of the post index at this point in time.
- PostIndexLength() int
+ PostIndexLength(ctx context.Context) int
/*
UTILITY FUNCTIONS
@@ -117,11 +118,11 @@ type Timeline interface {
// If a status has multiple entries in a timeline, they will all be removed.
//
// The returned int indicates the amount of entries that were removed.
- Remove(statusID string) (int, error)
+ Remove(ctx context.Context, statusID string) (int, error)
// RemoveAllBy removes all statuses by the given accountID, from both the index and prepared posts.
//
// The returned int indicates the amount of entries that were removed.
- RemoveAllBy(accountID string) (int, error)
+ RemoveAllBy(ctx context.Context, accountID string) (int, error)
}
// timeline fulfils the Timeline interface
@@ -138,9 +139,9 @@ type timeline struct {
}
// NewTimeline returns a new Timeline for the given account ID
-func NewTimeline(accountID string, db db.DB, typeConverter typeutils.TypeConverter, log *logrus.Logger) (Timeline, error) {
+func NewTimeline(ctx context.Context, accountID string, db db.DB, typeConverter typeutils.TypeConverter, log *logrus.Logger) (Timeline, error) {
timelineOwnerAccount := &gtsmodel.Account{}
- if err := db.GetByID(accountID, timelineOwnerAccount); err != nil {
+ if err := db.GetByID(ctx, accountID, timelineOwnerAccount); err != nil {
return nil, err
}
@@ -160,7 +161,7 @@ func (t *timeline) Reset() error {
return nil
}
-func (t *timeline) PostIndexLength() int {
+func (t *timeline) PostIndexLength(ctx context.Context) int {
if t.postIndex == nil || t.postIndex.data == nil {
return 0
}
diff --git a/internal/transport/controller.go b/internal/transport/controller.go
index 4eb6b5658..c2f5026e0 100644
--- a/internal/transport/controller.go
+++ b/internal/transport/controller.go
@@ -19,6 +19,7 @@
package transport
import (
+ "context"
"crypto"
"fmt"
"sync"
@@ -33,7 +34,7 @@ import (
// Controller generates transports for use in making federation requests to other servers.
type Controller interface {
NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error)
- NewTransportForUsername(username string) (Transport, error)
+ NewTransportForUsername(ctx context.Context, username string) (Transport, error)
}
type controller struct {
@@ -90,7 +91,7 @@ func (c *controller) NewTransport(pubKeyID string, privkey crypto.PrivateKey) (T
}, nil
}
-func (c *controller) NewTransportForUsername(username string) (Transport, error) {
+func (c *controller) NewTransportForUsername(ctx context.Context, username string) (Transport, error) {
// We need an account to use to create a transport for dereferecing something.
// If a username has been given, we can fetch the account with that username and use it.
// Otherwise, we can take the instance account and use those credentials to make the request.
@@ -101,7 +102,7 @@ func (c *controller) NewTransportForUsername(username string) (Transport, error)
u = username
}
- ourAccount, err := c.db.GetLocalAccountByUsername(u)
+ ourAccount, err := c.db.GetLocalAccountByUsername(ctx, u)
if err != nil {
return nil, fmt.Errorf("error getting account %s from db: %s", username, err)
}
diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go
index 844cb6bea..fd0fb576f 100644
--- a/internal/transport/deliver.go
+++ b/internal/transport/deliver.go
@@ -1,3 +1,21 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package transport
import (
@@ -5,12 +23,12 @@ import (
"net/url"
)
-func (t *transport) BatchDeliver(c context.Context, b []byte, recipients []*url.URL) error {
- return t.sigTransport.BatchDeliver(c, b, recipients)
+func (t *transport) BatchDeliver(ctx context.Context, b []byte, recipients []*url.URL) error {
+ return t.sigTransport.BatchDeliver(ctx, b, recipients)
}
-func (t *transport) Deliver(c context.Context, b []byte, to *url.URL) error {
+func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error {
l := t.log.WithField("func", "Deliver")
l.Debugf("performing POST to %s", to.String())
- return t.sigTransport.Deliver(c, b, to)
+ return t.sigTransport.Deliver(ctx, b, to)
}
diff --git a/internal/transport/dereference.go b/internal/transport/dereference.go
index d7a28fe17..85fa370ee 100644
--- a/internal/transport/dereference.go
+++ b/internal/transport/dereference.go
@@ -1,3 +1,21 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package transport
import (
@@ -5,8 +23,8 @@ import (
"net/url"
)
-func (t *transport) Dereference(c context.Context, iri *url.URL) ([]byte, error) {
+func (t *transport) Dereference(ctx context.Context, iri *url.URL) ([]byte, error) {
l := t.log.WithField("func", "Dereference")
l.Debugf("performing GET to %s", iri.String())
- return t.sigTransport.Dereference(c, iri)
+ return t.sigTransport.Dereference(ctx, iri)
}
diff --git a/internal/transport/derefinstance.go b/internal/transport/derefinstance.go
index a8b2ddfc7..3d72d7581 100644
--- a/internal/transport/derefinstance.go
+++ b/internal/transport/derefinstance.go
@@ -1,3 +1,21 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package transport
import (
@@ -16,7 +34,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/util"
)
-func (t *transport) DereferenceInstance(c context.Context, iri *url.URL) (*gtsmodel.Instance, error) {
+func (t *transport) DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error) {
l := t.log.WithField("func", "DereferenceInstance")
var i *gtsmodel.Instance
@@ -27,7 +45,7 @@ func (t *transport) DereferenceInstance(c context.Context, iri *url.URL) (*gtsmo
//
// This will only work with Mastodon-api compatible instances: Mastodon, some Pleroma instances, GoToSocial.
l.Debugf("trying to dereference instance %s by /api/v1/instance", iri.Host)
- i, err = dereferenceByAPIV1Instance(c, t, iri)
+ i, err = dereferenceByAPIV1Instance(ctx, t, iri)
if err == nil {
l.Debugf("successfully dereferenced instance using /api/v1/instance")
return i, nil
@@ -37,7 +55,7 @@ func (t *transport) DereferenceInstance(c context.Context, iri *url.URL) (*gtsmo
// If that doesn't work, try to dereference using /.well-known/nodeinfo.
// This will involve two API calls and return less info overall, but should be more widely compatible.
l.Debugf("trying to dereference instance %s by /.well-known/nodeinfo", iri.Host)
- i, err = dereferenceByNodeInfo(c, t, iri)
+ i, err = dereferenceByNodeInfo(ctx, t, iri)
if err == nil {
l.Debugf("successfully dereferenced instance using /.well-known/nodeinfo")
return i, nil
@@ -58,7 +76,7 @@ func (t *transport) DereferenceInstance(c context.Context, iri *url.URL) (*gtsmo
}, nil
}
-func dereferenceByAPIV1Instance(c context.Context, t *transport, iri *url.URL) (*gtsmodel.Instance, error) {
+func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL) (*gtsmodel.Instance, error) {
l := t.log.WithField("func", "dereferenceByAPIV1Instance")
cleanIRI := &url.URL{
@@ -68,11 +86,10 @@ func dereferenceByAPIV1Instance(c context.Context, t *transport, iri *url.URL) (
}
l.Debugf("performing GET to %s", cleanIRI.String())
- req, err := http.NewRequest("GET", cleanIRI.String(), nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
if err != nil {
return nil, err
}
- req = req.WithContext(c)
req.Header.Add("Accept", "application/json")
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
@@ -216,7 +233,7 @@ func dereferenceByNodeInfo(c context.Context, t *transport, iri *url.URL) (*gtsm
return i, nil
}
-func callNodeInfoWellKnown(c context.Context, t *transport, iri *url.URL) (*url.URL, error) {
+func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*url.URL, error) {
l := t.log.WithField("func", "callNodeInfoWellKnown")
cleanIRI := &url.URL{
@@ -226,11 +243,11 @@ func callNodeInfoWellKnown(c context.Context, t *transport, iri *url.URL) (*url.
}
l.Debugf("performing GET to %s", cleanIRI.String())
- req, err := http.NewRequest("GET", cleanIRI.String(), nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
if err != nil {
return nil, err
}
- req = req.WithContext(c)
+
req.Header.Add("Accept", "application/json")
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
@@ -281,15 +298,15 @@ func callNodeInfoWellKnown(c context.Context, t *transport, iri *url.URL) (*url.
return nodeinfoHref, nil
}
-func callNodeInfo(c context.Context, t *transport, iri *url.URL) (*apimodel.Nodeinfo, error) {
+func callNodeInfo(ctx context.Context, t *transport, iri *url.URL) (*apimodel.Nodeinfo, error) {
l := t.log.WithField("func", "callNodeInfo")
l.Debugf("performing GET to %s", iri.String())
- req, err := http.NewRequest("GET", iri.String(), nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
if err != nil {
return nil, err
}
- req = req.WithContext(c)
+
req.Header.Add("Accept", "application/json")
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
diff --git a/internal/transport/derefmedia.go b/internal/transport/derefmedia.go
index 5fa901100..e265bfdd4 100644
--- a/internal/transport/derefmedia.go
+++ b/internal/transport/derefmedia.go
@@ -1,3 +1,21 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package transport
import (
@@ -8,14 +26,13 @@ import (
"net/url"
)
-func (t *transport) DereferenceMedia(c context.Context, iri *url.URL, expectedContentType string) ([]byte, error) {
+func (t *transport) DereferenceMedia(ctx context.Context, iri *url.URL, expectedContentType string) ([]byte, error) {
l := t.log.WithField("func", "DereferenceMedia")
l.Debugf("performing GET to %s", iri.String())
- req, err := http.NewRequest("GET", iri.String(), nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
if err != nil {
return nil, err
}
- req = req.WithContext(c)
if expectedContentType == "" {
req.Header.Add("Accept", "*/*")
} else {
diff --git a/internal/transport/finger.go b/internal/transport/finger.go
index 12cd2fb64..ce092e83f 100644
--- a/internal/transport/finger.go
+++ b/internal/transport/finger.go
@@ -1,3 +1,21 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package transport
import (
@@ -8,7 +26,7 @@ import (
"net/url"
)
-func (t *transport) Finger(c context.Context, targetUsername string, targetDomain string) ([]byte, error) {
+func (t *transport) Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) {
l := t.log.WithField("func", "Finger")
urlString := fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s@%s", targetDomain, targetUsername, targetDomain)
l.Debugf("performing GET to %s", urlString)
@@ -20,11 +38,11 @@ func (t *transport) Finger(c context.Context, targetUsername string, targetDomai
l.Debugf("performing GET to %s", iri.String())
- req, err := http.NewRequest("GET", iri.String(), nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
if err != nil {
return nil, err
}
- req = req.WithContext(c)
+
req.Header.Add("Accept", "application/json")
req.Header.Add("Accept", "application/jrd+json")
req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
diff --git a/internal/transport/transport.go b/internal/transport/transport.go
index 04c72de5c..8d8262834 100644
--- a/internal/transport/transport.go
+++ b/internal/transport/transport.go
@@ -1,3 +1,21 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
package transport
import (
@@ -17,11 +35,11 @@ import (
type Transport interface {
pub.Transport
// DereferenceMedia fetches the bytes of the given media attachment IRI, with the expectedContentType.
- DereferenceMedia(c context.Context, iri *url.URL, expectedContentType string) ([]byte, error)
+ DereferenceMedia(ctx context.Context, iri *url.URL, expectedContentType string) ([]byte, error)
// DereferenceInstance dereferences remote instance information, first by checking /api/v1/instance, and then by checking /.well-known/nodeinfo.
- DereferenceInstance(c context.Context, iri *url.URL) (*gtsmodel.Instance, error)
+ DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error)
// Finger performs a webfinger request with the given username and domain, and returns the bytes from the response body.
- Finger(c context.Context, targetUsername string, targetDomains string) ([]byte, error)
+ Finger(ctx context.Context, targetUsername string, targetDomains string) ([]byte, error)
}
// transport implements the Transport interface
diff --git a/internal/typeutils/astointernal.go b/internal/typeutils/astointernal.go
index 887716a69..46132233b 100644
--- a/internal/typeutils/astointernal.go
+++ b/internal/typeutils/astointernal.go
@@ -19,6 +19,7 @@
package typeutils
import (
+ "context"
"errors"
"fmt"
"net/url"
@@ -29,7 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (c *converter) ASRepresentationToAccount(accountable ap.Accountable, update bool) (*gtsmodel.Account, error) {
+func (c *converter) ASRepresentationToAccount(ctx context.Context, accountable ap.Accountable, update bool) (*gtsmodel.Account, error) {
// first check if we actually already know this account
uriProp := accountable.GetJSONLDId()
if uriProp == nil || !uriProp.IsIRI() {
@@ -38,7 +39,7 @@ func (c *converter) ASRepresentationToAccount(accountable ap.Accountable, update
uri := uriProp.GetIRI()
if !update {
- acct, err := c.db.GetAccountByURI(uri.String())
+ acct, err := c.db.GetAccountByURI(ctx, uri.String())
if err == nil {
// we already know this account so we can skip generating it
return acct, nil
@@ -170,7 +171,7 @@ func (c *converter) ASRepresentationToAccount(accountable ap.Accountable, update
return acct, nil
}
-func (c *converter) ASStatusToStatus(statusable ap.Statusable) (*gtsmodel.Status, error) {
+func (c *converter) ASStatusToStatus(ctx context.Context, statusable ap.Statusable) (*gtsmodel.Status, error) {
status := &gtsmodel.Status{}
// uri at which this status is reachable
@@ -219,6 +220,7 @@ func (c *converter) ASStatusToStatus(statusable ap.Statusable) (*gtsmodel.Status
published, err := ap.ExtractPublished(statusable)
if err == nil {
status.CreatedAt = published
+ status.UpdatedAt = published
}
// which account posted this status?
@@ -229,7 +231,7 @@ func (c *converter) ASStatusToStatus(statusable ap.Statusable) (*gtsmodel.Status
}
status.AccountURI = attributedTo.String()
- statusOwner, err := c.db.GetAccountByURI(attributedTo.String())
+ statusOwner, err := c.db.GetAccountByURI(ctx, attributedTo.String())
if err != nil {
return nil, fmt.Errorf("couldn't get status owner from db: %s", err)
}
@@ -245,14 +247,14 @@ func (c *converter) ASStatusToStatus(statusable ap.Statusable) (*gtsmodel.Status
status.InReplyToURI = inReplyToURI.String()
// now we can check if we have the replied-to status in our db already
- if inReplyToStatus, err := c.db.GetStatusByURI(inReplyToURI.String()); err == nil {
+ if inReplyToStatus, err := c.db.GetStatusByURI(ctx, inReplyToURI.String()); err == nil {
// we have the status in our database already
// so we can set these fields here and now...
status.InReplyToID = inReplyToStatus.ID
status.InReplyToAccountID = inReplyToStatus.AccountID
status.InReplyTo = inReplyToStatus
if status.InReplyToAccount == nil {
- if inReplyToAccount, err := c.db.GetAccountByID(inReplyToStatus.AccountID); err == nil {
+ if inReplyToAccount, err := c.db.GetAccountByID(ctx, inReplyToStatus.AccountID); err == nil {
status.InReplyToAccount = inReplyToAccount
}
}
@@ -318,7 +320,7 @@ func (c *converter) ASStatusToStatus(statusable ap.Statusable) (*gtsmodel.Status
return status, nil
}
-func (c *converter) ASFollowToFollowRequest(followable ap.Followable) (*gtsmodel.FollowRequest, error) {
+func (c *converter) ASFollowToFollowRequest(ctx context.Context, followable ap.Followable) (*gtsmodel.FollowRequest, error) {
idProp := followable.GetJSONLDId()
if idProp == nil || !idProp.IsIRI() {
@@ -330,7 +332,7 @@ func (c *converter) ASFollowToFollowRequest(followable ap.Followable) (*gtsmodel
if err != nil {
return nil, errors.New("error extracting actor property from follow")
}
- originAccount, err := c.db.GetAccountByURI(origin.String())
+ originAccount, err := c.db.GetAccountByURI(ctx, origin.String())
if err != nil {
return nil, fmt.Errorf("error extracting account with uri %s from the database: %s", origin.String(), err)
}
@@ -339,7 +341,7 @@ func (c *converter) ASFollowToFollowRequest(followable ap.Followable) (*gtsmodel
if err != nil {
return nil, errors.New("error extracting object property from follow")
}
- targetAccount, err := c.db.GetAccountByURI(target.String())
+ targetAccount, err := c.db.GetAccountByURI(ctx, target.String())
if err != nil {
return nil, fmt.Errorf("error extracting account with uri %s from the database: %s", origin.String(), err)
}
@@ -353,7 +355,7 @@ func (c *converter) ASFollowToFollowRequest(followable ap.Followable) (*gtsmodel
return followRequest, nil
}
-func (c *converter) ASFollowToFollow(followable ap.Followable) (*gtsmodel.Follow, error) {
+func (c *converter) ASFollowToFollow(ctx context.Context, followable ap.Followable) (*gtsmodel.Follow, error) {
idProp := followable.GetJSONLDId()
if idProp == nil || !idProp.IsIRI() {
return nil, errors.New("no id property set on follow, or was not an iri")
@@ -364,7 +366,7 @@ func (c *converter) ASFollowToFollow(followable ap.Followable) (*gtsmodel.Follow
if err != nil {
return nil, errors.New("error extracting actor property from follow")
}
- originAccount, err := c.db.GetAccountByURI(origin.String())
+ originAccount, err := c.db.GetAccountByURI(ctx, origin.String())
if err != nil {
return nil, fmt.Errorf("error extracting account with uri %s from the database: %s", origin.String(), err)
}
@@ -373,7 +375,7 @@ func (c *converter) ASFollowToFollow(followable ap.Followable) (*gtsmodel.Follow
if err != nil {
return nil, errors.New("error extracting object property from follow")
}
- targetAccount, err := c.db.GetAccountByURI(target.String())
+ targetAccount, err := c.db.GetAccountByURI(ctx, target.String())
if err != nil {
return nil, fmt.Errorf("error extracting account with uri %s from the database: %s", origin.String(), err)
}
@@ -387,7 +389,7 @@ func (c *converter) ASFollowToFollow(followable ap.Followable) (*gtsmodel.Follow
return follow, nil
}
-func (c *converter) ASLikeToFave(likeable ap.Likeable) (*gtsmodel.StatusFave, error) {
+func (c *converter) ASLikeToFave(ctx context.Context, likeable ap.Likeable) (*gtsmodel.StatusFave, error) {
idProp := likeable.GetJSONLDId()
if idProp == nil || !idProp.IsIRI() {
return nil, errors.New("no id property set on like, or was not an iri")
@@ -398,7 +400,7 @@ func (c *converter) ASLikeToFave(likeable ap.Likeable) (*gtsmodel.StatusFave, er
if err != nil {
return nil, errors.New("error extracting actor property from like")
}
- originAccount, err := c.db.GetAccountByURI(origin.String())
+ originAccount, err := c.db.GetAccountByURI(ctx, origin.String())
if err != nil {
return nil, fmt.Errorf("error extracting account with uri %s from the database: %s", origin.String(), err)
}
@@ -408,7 +410,7 @@ func (c *converter) ASLikeToFave(likeable ap.Likeable) (*gtsmodel.StatusFave, er
return nil, errors.New("error extracting object property from like")
}
- targetStatus, err := c.db.GetStatusByURI(target.String())
+ targetStatus, err := c.db.GetStatusByURI(ctx, target.String())
if err != nil {
return nil, fmt.Errorf("error extracting status with uri %s from the database: %s", target.String(), err)
}
@@ -417,7 +419,7 @@ func (c *converter) ASLikeToFave(likeable ap.Likeable) (*gtsmodel.StatusFave, er
if targetStatus.Account != nil {
targetAccount = targetStatus.Account
} else {
- a, err := c.db.GetAccountByID(targetStatus.AccountID)
+ a, err := c.db.GetAccountByID(ctx, targetStatus.AccountID)
if err != nil {
return nil, fmt.Errorf("error extracting account with id %s from the database: %s", targetStatus.AccountID, err)
}
@@ -435,7 +437,7 @@ func (c *converter) ASLikeToFave(likeable ap.Likeable) (*gtsmodel.StatusFave, er
}, nil
}
-func (c *converter) ASBlockToBlock(blockable ap.Blockable) (*gtsmodel.Block, error) {
+func (c *converter) ASBlockToBlock(ctx context.Context, blockable ap.Blockable) (*gtsmodel.Block, error) {
idProp := blockable.GetJSONLDId()
if idProp == nil || !idProp.IsIRI() {
return nil, errors.New("ASBlockToBlock: no id property set on block, or was not an iri")
@@ -446,7 +448,7 @@ func (c *converter) ASBlockToBlock(blockable ap.Blockable) (*gtsmodel.Block, err
if err != nil {
return nil, errors.New("ASBlockToBlock: error extracting actor property from block")
}
- originAccount, err := c.db.GetAccountByURI(origin.String())
+ originAccount, err := c.db.GetAccountByURI(ctx, origin.String())
if err != nil {
return nil, fmt.Errorf("error extracting account with uri %s from the database: %s", origin.String(), err)
}
@@ -456,7 +458,7 @@ func (c *converter) ASBlockToBlock(blockable ap.Blockable) (*gtsmodel.Block, err
return nil, errors.New("ASBlockToBlock: error extracting object property from block")
}
- targetAccount, err := c.db.GetAccountByURI(target.String())
+ targetAccount, err := c.db.GetAccountByURI(ctx, target.String())
if err != nil {
return nil, fmt.Errorf("error extracting account with uri %s from the database: %s", origin.String(), err)
}
@@ -470,7 +472,7 @@ func (c *converter) ASBlockToBlock(blockable ap.Blockable) (*gtsmodel.Block, err
}, nil
}
-func (c *converter) ASAnnounceToStatus(announceable ap.Announceable) (*gtsmodel.Status, bool, error) {
+func (c *converter) ASAnnounceToStatus(ctx context.Context, announceable ap.Announceable) (*gtsmodel.Status, bool, error) {
status := &gtsmodel.Status{}
isNew := true
@@ -481,7 +483,7 @@ func (c *converter) ASAnnounceToStatus(announceable ap.Announceable) (*gtsmodel.
}
uri := idProp.GetIRI().String()
- if status, err := c.db.GetStatusByURI(uri); err == nil {
+ if status, err := c.db.GetStatusByURI(ctx, uri); err == nil {
// we already have it, great, just return it as-is :)
isNew = false
return status, isNew, nil
@@ -515,7 +517,7 @@ func (c *converter) ASAnnounceToStatus(announceable ap.Announceable) (*gtsmodel.
// get the boosting account based on the URI
// this should have been dereferenced already before we hit this point so we can confidently error out if we don't have it
- boostingAccount, err := c.db.GetAccountByURI(actor.String())
+ boostingAccount, err := c.db.GetAccountByURI(ctx, actor.String())
if err != nil {
return nil, isNew, fmt.Errorf("ASAnnounceToStatus: error in db fetching account with uri %s: %s", actor.String(), err)
}
diff --git a/internal/typeutils/astointernal_test.go b/internal/typeutils/astointernal_test.go
index 1d02dec5a..a01e79202 100644
--- a/internal/typeutils/astointernal_test.go
+++ b/internal/typeutils/astointernal_test.go
@@ -348,7 +348,7 @@ func (suite *ASToInternalTestSuite) SetupTest() {
func (suite *ASToInternalTestSuite) TestParsePerson() {
testPerson := suite.people["new_person_1"]
- acct, err := suite.typeconverter.ASRepresentationToAccount(testPerson, false)
+ acct, err := suite.typeconverter.ASRepresentationToAccount(context.Background(), testPerson, false)
assert.NoError(suite.T(), err)
suite.Equal("https://unknown-instance.com/users/brand_new_person", acct.URI)
@@ -379,7 +379,7 @@ func (suite *ASToInternalTestSuite) TestParseGargron() {
rep, ok := t.(ap.Accountable)
assert.True(suite.T(), ok)
- acct, err := suite.typeconverter.ASRepresentationToAccount(rep, false)
+ acct, err := suite.typeconverter.ASRepresentationToAccount(context.Background(), rep, false)
assert.NoError(suite.T(), err)
fmt.Printf("%+v", acct)
diff --git a/internal/typeutils/converter.go b/internal/typeutils/converter.go
index e477a6135..4af9767bc 100644
--- a/internal/typeutils/converter.go
+++ b/internal/typeutils/converter.go
@@ -19,6 +19,7 @@
package typeutils
import (
+ "context"
"net/url"
"github.com/go-fed/activity/streams/vocab"
@@ -47,45 +48,45 @@ type TypeConverter interface {
// AccountToMastoSensitive takes a db model account as a param, and returns a populated mastotype account, or an error
// if something goes wrong. The returned account should be ready to serialize on an API level, and may have sensitive fields,
// so serve it only to an authorized user who should have permission to see it.
- AccountToMastoSensitive(account *gtsmodel.Account) (*model.Account, error)
+ AccountToMastoSensitive(ctx context.Context, account *gtsmodel.Account) (*model.Account, error)
// AccountToMastoPublic takes a db model account as a param, and returns a populated mastotype account, or an error
// if something goes wrong. The returned account should be ready to serialize on an API level, and may NOT have sensitive fields.
// In other words, this is the public record that the server has of an account.
- AccountToMastoPublic(account *gtsmodel.Account) (*model.Account, error)
+ AccountToMastoPublic(ctx context.Context, account *gtsmodel.Account) (*model.Account, error)
// AccountToMastoBlocked takes a db model account as a param, and returns a mastotype account, or an error if
// something goes wrong. The returned account will be a bare minimum representation of the account. This function should be used
// when someone wants to view an account they've blocked.
- AccountToMastoBlocked(account *gtsmodel.Account) (*model.Account, error)
+ AccountToMastoBlocked(ctx context.Context, account *gtsmodel.Account) (*model.Account, error)
// AppToMastoSensitive takes a db model application as a param, and returns a populated mastotype application, or an error
// if something goes wrong. The returned application should be ready to serialize on an API level, and may have sensitive fields
// (such as client id and client secret), so serve it only to an authorized user who should have permission to see it.
- AppToMastoSensitive(application *gtsmodel.Application) (*model.Application, error)
+ AppToMastoSensitive(ctx context.Context, application *gtsmodel.Application) (*model.Application, error)
// AppToMastoPublic takes a db model application as a param, and returns a populated mastotype application, or an error
// if something goes wrong. The returned application should be ready to serialize on an API level, and has sensitive
// fields sanitized so that it can be served to non-authorized accounts without revealing any private information.
- AppToMastoPublic(application *gtsmodel.Application) (*model.Application, error)
+ AppToMastoPublic(ctx context.Context, application *gtsmodel.Application) (*model.Application, error)
// AttachmentToMasto converts a gts model media attacahment into its mastodon representation for serialization on the API.
- AttachmentToMasto(attachment *gtsmodel.MediaAttachment) (model.Attachment, error)
+ AttachmentToMasto(ctx context.Context, attachment *gtsmodel.MediaAttachment) (model.Attachment, error)
// MentionToMasto converts a gts model mention into its mastodon (frontend) representation for serialization on the API.
- MentionToMasto(m *gtsmodel.Mention) (model.Mention, error)
+ MentionToMasto(ctx context.Context, m *gtsmodel.Mention) (model.Mention, error)
// EmojiToMasto converts a gts model emoji into its mastodon (frontend) representation for serialization on the API.
- EmojiToMasto(e *gtsmodel.Emoji) (model.Emoji, error)
+ EmojiToMasto(ctx context.Context, e *gtsmodel.Emoji) (model.Emoji, error)
// TagToMasto converts a gts model tag into its mastodon (frontend) representation for serialization on the API.
- TagToMasto(t *gtsmodel.Tag) (model.Tag, error)
+ TagToMasto(ctx context.Context, t *gtsmodel.Tag) (model.Tag, error)
// StatusToMasto converts a gts model status into its mastodon (frontend) representation for serialization on the API.
//
// Requesting account can be nil.
- StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmodel.Account) (*model.Status, error)
+ StatusToMasto(ctx context.Context, s *gtsmodel.Status, requestingAccount *gtsmodel.Account) (*model.Status, error)
// VisToMasto converts a gts visibility into its mastodon equivalent
- VisToMasto(m gtsmodel.Visibility) model.Visibility
+ VisToMasto(ctx context.Context, m gtsmodel.Visibility) model.Visibility
// InstanceToMasto converts a gts instance into its mastodon equivalent for serving at /api/v1/instance
- InstanceToMasto(i *gtsmodel.Instance) (*model.Instance, error)
+ InstanceToMasto(ctx context.Context, i *gtsmodel.Instance) (*model.Instance, error)
// RelationshipToMasto converts a gts relationship into its mastodon equivalent for serving in various places
- RelationshipToMasto(r *gtsmodel.Relationship) (*model.Relationship, error)
+ RelationshipToMasto(ctx context.Context, r *gtsmodel.Relationship) (*model.Relationship, error)
// NotificationToMasto converts a gts notification into a mastodon notification
- NotificationToMasto(n *gtsmodel.Notification) (*model.Notification, error)
+ NotificationToMasto(ctx context.Context, n *gtsmodel.Notification) (*model.Notification, error)
// DomainBlockTomasto converts a gts model domin block into a mastodon domain block, for serving at /api/v1/admin/domain_blocks
- DomainBlockToMasto(b *gtsmodel.DomainBlock, export bool) (*model.DomainBlock, error)
+ DomainBlockToMasto(ctx context.Context, b *gtsmodel.DomainBlock, export bool) (*model.DomainBlock, error)
/*
FRONTEND (mastodon) MODEL TO INTERNAL (gts) MODEL
@@ -103,17 +104,17 @@ type TypeConverter interface {
// If update is false, and the account is already known in the database, then the existing account entry will be returned.
// If update is true, then even if the account is already known, all fields in the accountable will be parsed and a new *gtsmodel.Account
// will be generated. This is useful when one needs to force refresh of an account, eg., during an Update of a Profile.
- ASRepresentationToAccount(accountable ap.Accountable, update bool) (*gtsmodel.Account, error)
+ ASRepresentationToAccount(ctx context.Context, accountable ap.Accountable, update bool) (*gtsmodel.Account, error)
// ASStatus converts a remote activitystreams 'status' representation into a gts model status.
- ASStatusToStatus(statusable ap.Statusable) (*gtsmodel.Status, error)
+ ASStatusToStatus(ctx context.Context, statusable ap.Statusable) (*gtsmodel.Status, error)
// ASFollowToFollowRequest converts a remote activitystreams `follow` representation into gts model follow request.
- ASFollowToFollowRequest(followable ap.Followable) (*gtsmodel.FollowRequest, error)
+ ASFollowToFollowRequest(ctx context.Context, followable ap.Followable) (*gtsmodel.FollowRequest, error)
// ASFollowToFollowRequest converts a remote activitystreams `follow` representation into gts model follow.
- ASFollowToFollow(followable ap.Followable) (*gtsmodel.Follow, error)
+ ASFollowToFollow(ctx context.Context, followable ap.Followable) (*gtsmodel.Follow, error)
// ASLikeToFave converts a remote activitystreams 'like' representation into a gts model status fave.
- ASLikeToFave(likeable ap.Likeable) (*gtsmodel.StatusFave, error)
+ ASLikeToFave(ctx context.Context, likeable ap.Likeable) (*gtsmodel.StatusFave, error)
// ASBlockToBlock converts a remote activity streams 'block' representation into a gts model block.
- ASBlockToBlock(blockable ap.Blockable) (*gtsmodel.Block, error)
+ ASBlockToBlock(ctx context.Context, blockable ap.Blockable) (*gtsmodel.Block, error)
// ASAnnounceToStatus converts an activitystreams 'announce' into a status.
//
// The returned bool indicates whether this status is new (true) or not new (false).
@@ -126,46 +127,46 @@ type TypeConverter interface {
// This is useful when multiple users on an instance might receive the same boost, and we only want to process the boost once.
//
// NOTE -- this is different from one status being boosted multiple times! In this case, new boosts should indeed be created.
- ASAnnounceToStatus(announceable ap.Announceable) (status *gtsmodel.Status, new bool, err error)
+ ASAnnounceToStatus(ctx context.Context, announceable ap.Announceable) (status *gtsmodel.Status, new bool, err error)
/*
INTERNAL (gts) MODEL TO ACTIVITYSTREAMS MODEL
*/
// AccountToAS converts a gts model account into an activity streams person, suitable for federation
- AccountToAS(a *gtsmodel.Account) (vocab.ActivityStreamsPerson, error)
+ AccountToAS(ctx context.Context, a *gtsmodel.Account) (vocab.ActivityStreamsPerson, error)
// AccountToASMinimal converts a gts model account into an activity streams person, suitable for federation.
//
// The returned account will just have the Type, Username, PublicKey, and ID properties set. This is
// suitable for serving to requesters to whom we want to give as little information as possible because
// we don't trust them (yet).
- AccountToASMinimal(a *gtsmodel.Account) (vocab.ActivityStreamsPerson, error)
+ AccountToASMinimal(ctx context.Context, a *gtsmodel.Account) (vocab.ActivityStreamsPerson, error)
// StatusToAS converts a gts model status into an activity streams note, suitable for federation
- StatusToAS(s *gtsmodel.Status) (vocab.ActivityStreamsNote, error)
+ StatusToAS(ctx context.Context, s *gtsmodel.Status) (vocab.ActivityStreamsNote, error)
// FollowToASFollow converts a gts model Follow into an activity streams Follow, suitable for federation
- FollowToAS(f *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (vocab.ActivityStreamsFollow, error)
+ FollowToAS(ctx context.Context, f *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (vocab.ActivityStreamsFollow, error)
// MentionToAS converts a gts model mention into an activity streams Mention, suitable for federation
- MentionToAS(m *gtsmodel.Mention) (vocab.ActivityStreamsMention, error)
+ MentionToAS(ctx context.Context, m *gtsmodel.Mention) (vocab.ActivityStreamsMention, error)
// AttachmentToAS converts a gts model media attachment into an activity streams Attachment, suitable for federation
- AttachmentToAS(a *gtsmodel.MediaAttachment) (vocab.ActivityStreamsDocument, error)
+ AttachmentToAS(ctx context.Context, a *gtsmodel.MediaAttachment) (vocab.ActivityStreamsDocument, error)
// FaveToAS converts a gts model status fave into an activityStreams LIKE, suitable for federation.
- FaveToAS(f *gtsmodel.StatusFave) (vocab.ActivityStreamsLike, error)
+ FaveToAS(ctx context.Context, f *gtsmodel.StatusFave) (vocab.ActivityStreamsLike, error)
// BoostToAS converts a gts model boost into an activityStreams ANNOUNCE, suitable for federation
- BoostToAS(boostWrapperStatus *gtsmodel.Status, boostingAccount *gtsmodel.Account, boostedAccount *gtsmodel.Account) (vocab.ActivityStreamsAnnounce, error)
+ BoostToAS(ctx context.Context, boostWrapperStatus *gtsmodel.Status, boostingAccount *gtsmodel.Account, boostedAccount *gtsmodel.Account) (vocab.ActivityStreamsAnnounce, error)
// BlockToAS converts a gts model block into an activityStreams BLOCK, suitable for federation.
- BlockToAS(block *gtsmodel.Block) (vocab.ActivityStreamsBlock, error)
+ BlockToAS(ctx context.Context, block *gtsmodel.Block) (vocab.ActivityStreamsBlock, error)
// StatusToASRepliesCollection converts a gts model status into an activityStreams REPLIES collection.
- StatusToASRepliesCollection(status *gtsmodel.Status, onlyOtherAccounts bool) (vocab.ActivityStreamsCollection, error)
+ StatusToASRepliesCollection(ctx context.Context, status *gtsmodel.Status, onlyOtherAccounts bool) (vocab.ActivityStreamsCollection, error)
// StatusURIsToASRepliesPage returns a collection page with appropriate next/part of pagination.
- StatusURIsToASRepliesPage(status *gtsmodel.Status, onlyOtherAccounts bool, minID string, replies map[string]*url.URL) (vocab.ActivityStreamsCollectionPage, error)
+ StatusURIsToASRepliesPage(ctx context.Context, status *gtsmodel.Status, onlyOtherAccounts bool, minID string, replies map[string]*url.URL) (vocab.ActivityStreamsCollectionPage, error)
/*
INTERNAL (gts) MODEL TO INTERNAL MODEL
*/
// FollowRequestToFollow just converts a follow request into a follow, that's it! No bells and whistles.
- FollowRequestToFollow(f *gtsmodel.FollowRequest) *gtsmodel.Follow
+ FollowRequestToFollow(ctx context.Context, f *gtsmodel.FollowRequest) *gtsmodel.Follow
// StatusToBoost wraps the given status into a boosting status.
- StatusToBoost(s *gtsmodel.Status, boostingAccount *gtsmodel.Account) (*gtsmodel.Status, error)
+ StatusToBoost(ctx context.Context, s *gtsmodel.Status, boostingAccount *gtsmodel.Account) (*gtsmodel.Status, error)
/*
WRAPPER CONVENIENCE FUNCTIONS
diff --git a/internal/typeutils/internal.go b/internal/typeutils/internal.go
index ad15ecbee..23839b9a8 100644
--- a/internal/typeutils/internal.go
+++ b/internal/typeutils/internal.go
@@ -1,6 +1,7 @@
package typeutils
import (
+ "context"
"fmt"
"time"
@@ -9,7 +10,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/util"
)
-func (c *converter) FollowRequestToFollow(f *gtsmodel.FollowRequest) *gtsmodel.Follow {
+func (c *converter) FollowRequestToFollow(ctx context.Context, f *gtsmodel.FollowRequest) *gtsmodel.Follow {
return &gtsmodel.Follow{
ID: f.ID,
CreatedAt: f.CreatedAt,
@@ -22,7 +23,7 @@ func (c *converter) FollowRequestToFollow(f *gtsmodel.FollowRequest) *gtsmodel.F
}
}
-func (c *converter) StatusToBoost(s *gtsmodel.Status, boostingAccount *gtsmodel.Account) (*gtsmodel.Status, error) {
+func (c *converter) StatusToBoost(ctx context.Context, s *gtsmodel.Status, boostingAccount *gtsmodel.Account) (*gtsmodel.Status, error) {
// the wrapper won't use the same ID as the boosted status so we generate some new UUIDs
uris := util.GenerateURIsForAccount(boostingAccount.Username, c.config.Protocol, c.config.Host)
boostWrapperStatusID, err := id.NewULID()
diff --git a/internal/typeutils/internaltoas.go b/internal/typeutils/internaltoas.go
index 178567dc6..14ed094c5 100644
--- a/internal/typeutils/internaltoas.go
+++ b/internal/typeutils/internaltoas.go
@@ -19,6 +19,7 @@
package typeutils
import (
+ "context"
"crypto/x509"
"encoding/pem"
"fmt"
@@ -33,7 +34,7 @@ import (
// Converts a gts model account into an Activity Streams person type, following
// the spec laid out for mastodon here: https://docs.joinmastodon.org/spec/activitypub/
-func (c *converter) AccountToAS(a *gtsmodel.Account) (vocab.ActivityStreamsPerson, error) {
+func (c *converter) AccountToAS(ctx context.Context, a *gtsmodel.Account) (vocab.ActivityStreamsPerson, error) {
// first check if we have this person in our asCache already
if personI, err := c.asCache.Fetch(a.ID); err == nil {
if person, ok := personI.(vocab.ActivityStreamsPerson); ok {
@@ -213,9 +214,12 @@ func (c *converter) AccountToAS(a *gtsmodel.Account) (vocab.ActivityStreamsPerso
// icon
// Used as profile avatar.
if a.AvatarMediaAttachmentID != "" {
- avatar := &gtsmodel.MediaAttachment{}
- if err := c.db.GetByID(a.AvatarMediaAttachmentID, avatar); err != nil {
- return nil, err
+ if a.AvatarMediaAttachment == nil {
+ avatar := &gtsmodel.MediaAttachment{}
+ if err := c.db.GetByID(ctx, a.AvatarMediaAttachmentID, avatar); err != nil {
+ return nil, err
+ }
+ a.AvatarMediaAttachment = avatar
}
iconProperty := streams.NewActivityStreamsIconProperty()
@@ -223,11 +227,11 @@ func (c *converter) AccountToAS(a *gtsmodel.Account) (vocab.ActivityStreamsPerso
iconImage := streams.NewActivityStreamsImage()
mediaType := streams.NewActivityStreamsMediaTypeProperty()
- mediaType.Set(avatar.File.ContentType)
+ mediaType.Set(a.AvatarMediaAttachment.File.ContentType)
iconImage.SetActivityStreamsMediaType(mediaType)
avatarURLProperty := streams.NewActivityStreamsUrlProperty()
- avatarURL, err := url.Parse(avatar.URL)
+ avatarURL, err := url.Parse(a.AvatarMediaAttachment.URL)
if err != nil {
return nil, err
}
@@ -241,9 +245,12 @@ func (c *converter) AccountToAS(a *gtsmodel.Account) (vocab.ActivityStreamsPerso
// image
// Used as profile header.
if a.HeaderMediaAttachmentID != "" {
- header := &gtsmodel.MediaAttachment{}
- if err := c.db.GetByID(a.HeaderMediaAttachmentID, header); err != nil {
- return nil, err
+ if a.HeaderMediaAttachment == nil {
+ header := &gtsmodel.MediaAttachment{}
+ if err := c.db.GetByID(ctx, a.HeaderMediaAttachmentID, header); err != nil {
+ return nil, err
+ }
+ a.HeaderMediaAttachment = header
}
headerProperty := streams.NewActivityStreamsImageProperty()
@@ -251,11 +258,11 @@ func (c *converter) AccountToAS(a *gtsmodel.Account) (vocab.ActivityStreamsPerso
headerImage := streams.NewActivityStreamsImage()
mediaType := streams.NewActivityStreamsMediaTypeProperty()
- mediaType.Set(header.File.ContentType)
+ mediaType.Set(a.HeaderMediaAttachment.File.ContentType)
headerImage.SetActivityStreamsMediaType(mediaType)
headerURLProperty := streams.NewActivityStreamsUrlProperty()
- headerURL, err := url.Parse(header.URL)
+ headerURL, err := url.Parse(a.HeaderMediaAttachment.URL)
if err != nil {
return nil, err
}
@@ -278,7 +285,7 @@ func (c *converter) AccountToAS(a *gtsmodel.Account) (vocab.ActivityStreamsPerso
// the spec laid out for mastodon here: https://docs.joinmastodon.org/spec/activitypub/
//
// The returned account will just have the Type, Username, PublicKey, and ID properties set.
-func (c *converter) AccountToASMinimal(a *gtsmodel.Account) (vocab.ActivityStreamsPerson, error) {
+func (c *converter) AccountToASMinimal(ctx context.Context, a *gtsmodel.Account) (vocab.ActivityStreamsPerson, error) {
person := streams.NewActivityStreamsPerson()
// id should be the activitypub URI of this user
@@ -340,7 +347,7 @@ func (c *converter) AccountToASMinimal(a *gtsmodel.Account) (vocab.ActivityStrea
return person, nil
}
-func (c *converter) StatusToAS(s *gtsmodel.Status) (vocab.ActivityStreamsNote, error) {
+func (c *converter) StatusToAS(ctx context.Context, s *gtsmodel.Status) (vocab.ActivityStreamsNote, error) {
// first check if we have this note in our asCache already
if noteI, err := c.asCache.Fetch(s.ID); err == nil {
if note, ok := noteI.(vocab.ActivityStreamsNote); ok {
@@ -354,7 +361,7 @@ func (c *converter) StatusToAS(s *gtsmodel.Status) (vocab.ActivityStreamsNote, e
// check if author account is already attached to status and attach it if not
// if we can't retrieve this, bail here already because we can't attribute the status to anyone
if s.Account == nil {
- a, err := c.db.GetAccountByID(s.AccountID)
+ a, err := c.db.GetAccountByID(ctx, s.AccountID)
if err != nil {
return nil, fmt.Errorf("StatusToAS: error retrieving author account from db: %s", err)
}
@@ -386,7 +393,7 @@ func (c *converter) StatusToAS(s *gtsmodel.Status) (vocab.ActivityStreamsNote, e
// fetch the replied status if we don't have it on hand already
if s.InReplyTo == nil {
rs := &gtsmodel.Status{}
- if err := c.db.GetByID(s.InReplyToID, rs); err != nil {
+ if err := c.db.GetByID(ctx, s.InReplyToID, rs); err != nil {
return nil, fmt.Errorf("StatusToAS: error retrieving replied-to status from db: %s", err)
}
s.InReplyTo = rs
@@ -432,7 +439,7 @@ func (c *converter) StatusToAS(s *gtsmodel.Status) (vocab.ActivityStreamsNote, e
// tag -- mentions
for _, m := range s.Mentions {
- asMention, err := c.MentionToAS(m)
+ asMention, err := c.MentionToAS(ctx, m)
if err != nil {
return nil, fmt.Errorf("StatusToAS: error converting mention to AS mention: %s", err)
}
@@ -520,7 +527,7 @@ func (c *converter) StatusToAS(s *gtsmodel.Status) (vocab.ActivityStreamsNote, e
// attachment
attachmentProp := streams.NewActivityStreamsAttachmentProperty()
for _, a := range s.Attachments {
- doc, err := c.AttachmentToAS(a)
+ doc, err := c.AttachmentToAS(ctx, a)
if err != nil {
return nil, fmt.Errorf("StatusToAS: error converting attachment: %s", err)
}
@@ -529,7 +536,7 @@ func (c *converter) StatusToAS(s *gtsmodel.Status) (vocab.ActivityStreamsNote, e
status.SetActivityStreamsAttachment(attachmentProp)
// replies
- repliesCollection, err := c.StatusToASRepliesCollection(s, false)
+ repliesCollection, err := c.StatusToASRepliesCollection(ctx, s, false)
if err != nil {
return nil, fmt.Errorf("error creating repliesCollection: %s", err)
}
@@ -546,7 +553,7 @@ func (c *converter) StatusToAS(s *gtsmodel.Status) (vocab.ActivityStreamsNote, e
return status, nil
}
-func (c *converter) FollowToAS(f *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (vocab.ActivityStreamsFollow, error) {
+func (c *converter) FollowToAS(ctx context.Context, f *gtsmodel.Follow, originAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (vocab.ActivityStreamsFollow, error) {
// parse out the various URIs we need for this
// origin account (who's doing the follow)
originAccountURI, err := url.Parse(originAccount.URI)
@@ -592,10 +599,10 @@ func (c *converter) FollowToAS(f *gtsmodel.Follow, originAccount *gtsmodel.Accou
return follow, nil
}
-func (c *converter) MentionToAS(m *gtsmodel.Mention) (vocab.ActivityStreamsMention, error) {
+func (c *converter) MentionToAS(ctx context.Context, m *gtsmodel.Mention) (vocab.ActivityStreamsMention, error) {
if m.OriginAccount == nil {
a := &gtsmodel.Account{}
- if err := c.db.GetWhere([]db.Where{{Key: "target_account_id", Value: m.TargetAccountID}}, a); err != nil {
+ if err := c.db.GetWhere(ctx, []db.Where{{Key: "target_account_id", Value: m.TargetAccountID}}, a); err != nil {
return nil, fmt.Errorf("MentionToAS: error getting target account from db: %s", err)
}
m.OriginAccount = a
@@ -629,7 +636,7 @@ func (c *converter) MentionToAS(m *gtsmodel.Mention) (vocab.ActivityStreamsMenti
return mention, nil
}
-func (c *converter) AttachmentToAS(a *gtsmodel.MediaAttachment) (vocab.ActivityStreamsDocument, error) {
+func (c *converter) AttachmentToAS(ctx context.Context, a *gtsmodel.MediaAttachment) (vocab.ActivityStreamsDocument, error) {
// type -- Document
doc := streams.NewActivityStreamsDocument()
@@ -674,11 +681,11 @@ func (c *converter) AttachmentToAS(a *gtsmodel.MediaAttachment) (vocab.ActivityS
"type": "Like"
}
*/
-func (c *converter) FaveToAS(f *gtsmodel.StatusFave) (vocab.ActivityStreamsLike, error) {
+func (c *converter) FaveToAS(ctx context.Context, f *gtsmodel.StatusFave) (vocab.ActivityStreamsLike, error) {
// check if targetStatus is already pinned to this fave, and fetch it if not
if f.Status == nil {
s := &gtsmodel.Status{}
- if err := c.db.GetByID(f.StatusID, s); err != nil {
+ if err := c.db.GetByID(ctx, f.StatusID, s); err != nil {
return nil, fmt.Errorf("FaveToAS: error fetching target status from database: %s", err)
}
f.Status = s
@@ -687,7 +694,7 @@ func (c *converter) FaveToAS(f *gtsmodel.StatusFave) (vocab.ActivityStreamsLike,
// check if the targetAccount is already pinned to this fave, and fetch it if not
if f.TargetAccount == nil {
a := &gtsmodel.Account{}
- if err := c.db.GetByID(f.TargetAccountID, a); err != nil {
+ if err := c.db.GetByID(ctx, f.TargetAccountID, a); err != nil {
return nil, fmt.Errorf("FaveToAS: error fetching target account from database: %s", err)
}
f.TargetAccount = a
@@ -696,7 +703,7 @@ func (c *converter) FaveToAS(f *gtsmodel.StatusFave) (vocab.ActivityStreamsLike,
// check if the faving account is already pinned to this fave, and fetch it if not
if f.Account == nil {
a := &gtsmodel.Account{}
- if err := c.db.GetByID(f.AccountID, a); err != nil {
+ if err := c.db.GetByID(ctx, f.AccountID, a); err != nil {
return nil, fmt.Errorf("FaveToAS: error fetching faving account from database: %s", err)
}
f.Account = a
@@ -744,11 +751,11 @@ func (c *converter) FaveToAS(f *gtsmodel.StatusFave) (vocab.ActivityStreamsLike,
return like, nil
}
-func (c *converter) BoostToAS(boostWrapperStatus *gtsmodel.Status, boostingAccount *gtsmodel.Account, boostedAccount *gtsmodel.Account) (vocab.ActivityStreamsAnnounce, error) {
+func (c *converter) BoostToAS(ctx context.Context, boostWrapperStatus *gtsmodel.Status, boostingAccount *gtsmodel.Account, boostedAccount *gtsmodel.Account) (vocab.ActivityStreamsAnnounce, error) {
// the boosted status is probably pinned to the boostWrapperStatus but double check to make sure
if boostWrapperStatus.BoostOf == nil {
b := &gtsmodel.Status{}
- if err := c.db.GetByID(boostWrapperStatus.BoostOfID, b); err != nil {
+ if err := c.db.GetByID(ctx, boostWrapperStatus.BoostOfID, b); err != nil {
return nil, fmt.Errorf("BoostToAS: error getting status with ID %s from the db: %s", boostWrapperStatus.BoostOfID, err)
}
boostWrapperStatus.BoostOf = b
@@ -828,10 +835,10 @@ func (c *converter) BoostToAS(boostWrapperStatus *gtsmodel.Status, boostingAccou
"type":"Block"
}
*/
-func (c *converter) BlockToAS(b *gtsmodel.Block) (vocab.ActivityStreamsBlock, error) {
+func (c *converter) BlockToAS(ctx context.Context, b *gtsmodel.Block) (vocab.ActivityStreamsBlock, error) {
if b.Account == nil {
a := &gtsmodel.Account{}
- if err := c.db.GetByID(b.AccountID, a); err != nil {
+ if err := c.db.GetByID(ctx, b.AccountID, a); err != nil {
return nil, fmt.Errorf("BlockToAS: error getting block account from database: %s", err)
}
b.Account = a
@@ -839,7 +846,7 @@ func (c *converter) BlockToAS(b *gtsmodel.Block) (vocab.ActivityStreamsBlock, er
if b.TargetAccount == nil {
a := &gtsmodel.Account{}
- if err := c.db.GetByID(b.TargetAccountID, a); err != nil {
+ if err := c.db.GetByID(ctx, b.TargetAccountID, a); err != nil {
return nil, fmt.Errorf("BlockToAS: error getting block target account from database: %s", err)
}
b.TargetAccount = a
@@ -903,7 +910,7 @@ func (c *converter) BlockToAS(b *gtsmodel.Block) (vocab.ActivityStreamsBlock, er
}
}
*/
-func (c *converter) StatusToASRepliesCollection(status *gtsmodel.Status, onlyOtherAccounts bool) (vocab.ActivityStreamsCollection, error) {
+func (c *converter) StatusToASRepliesCollection(ctx context.Context, status *gtsmodel.Status, onlyOtherAccounts bool) (vocab.ActivityStreamsCollection, error) {
collectionID := fmt.Sprintf("%s/replies", status.URI)
collectionIDURI, err := url.Parse(collectionID)
if err != nil {
@@ -966,7 +973,7 @@ func (c *converter) StatusToASRepliesCollection(status *gtsmodel.Status, onlyOth
]
}
*/
-func (c *converter) StatusURIsToASRepliesPage(status *gtsmodel.Status, onlyOtherAccounts bool, minID string, replies map[string]*url.URL) (vocab.ActivityStreamsCollectionPage, error) {
+func (c *converter) StatusURIsToASRepliesPage(ctx context.Context, status *gtsmodel.Status, onlyOtherAccounts bool, minID string, replies map[string]*url.URL) (vocab.ActivityStreamsCollectionPage, error) {
collectionID := fmt.Sprintf("%s/replies", status.URI)
page := streams.NewActivityStreamsCollectionPage()
diff --git a/internal/typeutils/internaltoas_test.go b/internal/typeutils/internaltoas_test.go
index caa56ce0d..46f04df2f 100644
--- a/internal/typeutils/internaltoas_test.go
+++ b/internal/typeutils/internaltoas_test.go
@@ -19,6 +19,7 @@
package typeutils_test
import (
+ "context"
"encoding/json"
"fmt"
"testing"
@@ -58,7 +59,7 @@ func (suite *InternalToASTestSuite) TearDownTest() {
func (suite *InternalToASTestSuite) TestAccountToAS() {
testAccount := suite.accounts["local_account_1"] // take zork for this test
- asPerson, err := suite.typeconverter.AccountToAS(testAccount)
+ asPerson, err := suite.typeconverter.AccountToAS(context.Background(), testAccount)
assert.NoError(suite.T(), err)
ser, err := streams.Serialize(asPerson)
diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go
index caa14e211..89da9eb01 100644
--- a/internal/typeutils/internaltofrontend.go
+++ b/internal/typeutils/internaltofrontend.go
@@ -19,6 +19,7 @@
package typeutils
import (
+ "context"
"fmt"
"strings"
"time"
@@ -28,9 +29,9 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (c *converter) AccountToMastoSensitive(a *gtsmodel.Account) (*model.Account, error) {
+func (c *converter) AccountToMastoSensitive(ctx context.Context, a *gtsmodel.Account) (*model.Account, error) {
// we can build this sensitive account easily by first getting the public account....
- mastoAccount, err := c.AccountToMastoPublic(a)
+ mastoAccount, err := c.AccountToMastoPublic(ctx, a)
if err != nil {
return nil, err
}
@@ -38,7 +39,7 @@ func (c *converter) AccountToMastoSensitive(a *gtsmodel.Account) (*model.Account
// then adding the Source object to it...
// check pending follow requests aimed at this account
- frs, err := c.db.GetAccountFollowRequests(a.ID)
+ frs, err := c.db.GetAccountFollowRequests(ctx, a.ID)
if err != nil {
if err != db.ErrNoEntries {
return nil, fmt.Errorf("error getting follow requests: %s", err)
@@ -50,7 +51,7 @@ func (c *converter) AccountToMastoSensitive(a *gtsmodel.Account) (*model.Account
}
mastoAccount.Source = &model.Source{
- Privacy: c.VisToMasto(a.Privacy),
+ Privacy: c.VisToMasto(ctx, a.Privacy),
Sensitive: a.Sensitive,
Language: a.Language,
Note: a.Note,
@@ -61,7 +62,11 @@ func (c *converter) AccountToMastoSensitive(a *gtsmodel.Account) (*model.Account
return mastoAccount, nil
}
-func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, error) {
+func (c *converter) AccountToMastoPublic(ctx context.Context, a *gtsmodel.Account) (*model.Account, error) {
+ if a == nil {
+ return nil, fmt.Errorf("given account was nil")
+ }
+
// first check if we have this account in our frontEnd cache
if accountI, err := c.frontendCache.Fetch(a.ID); err == nil {
if account, ok := accountI.(*model.Account); ok {
@@ -71,26 +76,26 @@ func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, e
}
// count followers
- followersCount, err := c.db.CountAccountFollowedBy(a.ID, false)
+ followersCount, err := c.db.CountAccountFollowedBy(ctx, a.ID, false)
if err != nil {
return nil, fmt.Errorf("error counting followers: %s", err)
}
// count following
- followingCount, err := c.db.CountAccountFollows(a.ID, false)
+ followingCount, err := c.db.CountAccountFollows(ctx, a.ID, false)
if err != nil {
return nil, fmt.Errorf("error counting following: %s", err)
}
// count statuses
- statusesCount, err := c.db.CountAccountStatuses(a.ID)
+ statusesCount, err := c.db.CountAccountStatuses(ctx, a.ID)
if err != nil {
return nil, fmt.Errorf("error counting statuses: %s", err)
}
// check when the last status was
var lastStatusAt string
- lastPosted, err := c.db.GetAccountLastPosted(a.ID)
+ lastPosted, err := c.db.GetAccountLastPosted(ctx, a.ID)
if err == nil && !lastPosted.IsZero() {
lastStatusAt = lastPosted.Format(time.RFC3339)
}
@@ -101,7 +106,7 @@ func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, e
if a.AvatarMediaAttachmentID != "" {
// make sure avi is pinned to this account
if a.AvatarMediaAttachment == nil {
- avi, err := c.db.GetAttachmentByID(a.AvatarMediaAttachmentID)
+ avi, err := c.db.GetAttachmentByID(ctx, a.AvatarMediaAttachmentID)
if err != nil {
return nil, fmt.Errorf("error retrieving avatar: %s", err)
}
@@ -116,7 +121,7 @@ func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, e
if a.HeaderMediaAttachmentID != "" {
// make sure header is pinned to this account
if a.HeaderMediaAttachment == nil {
- avi, err := c.db.GetAttachmentByID(a.HeaderMediaAttachmentID)
+ avi, err := c.db.GetAttachmentByID(ctx, a.HeaderMediaAttachmentID)
if err != nil {
return nil, fmt.Errorf("error retrieving avatar: %s", err)
}
@@ -187,7 +192,7 @@ func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, e
return accountFrontend, nil
}
-func (c *converter) AccountToMastoBlocked(a *gtsmodel.Account) (*model.Account, error) {
+func (c *converter) AccountToMastoBlocked(ctx context.Context, a *gtsmodel.Account) (*model.Account, error) {
var acct string
if a.Domain != "" {
// this is a remote user
@@ -214,7 +219,7 @@ func (c *converter) AccountToMastoBlocked(a *gtsmodel.Account) (*model.Account,
}, nil
}
-func (c *converter) AppToMastoSensitive(a *gtsmodel.Application) (*model.Application, error) {
+func (c *converter) AppToMastoSensitive(ctx context.Context, a *gtsmodel.Application) (*model.Application, error) {
return &model.Application{
ID: a.ID,
Name: a.Name,
@@ -226,14 +231,14 @@ func (c *converter) AppToMastoSensitive(a *gtsmodel.Application) (*model.Applica
}, nil
}
-func (c *converter) AppToMastoPublic(a *gtsmodel.Application) (*model.Application, error) {
+func (c *converter) AppToMastoPublic(ctx context.Context, a *gtsmodel.Application) (*model.Application, error) {
return &model.Application{
Name: a.Name,
Website: a.Website,
}, nil
}
-func (c *converter) AttachmentToMasto(a *gtsmodel.MediaAttachment) (model.Attachment, error) {
+func (c *converter) AttachmentToMasto(ctx context.Context, a *gtsmodel.MediaAttachment) (model.Attachment, error) {
return model.Attachment{
ID: a.ID,
Type: strings.ToLower(string(a.Type)),
@@ -264,33 +269,36 @@ func (c *converter) AttachmentToMasto(a *gtsmodel.MediaAttachment) (model.Attach
}, nil
}
-func (c *converter) MentionToMasto(m *gtsmodel.Mention) (model.Mention, error) {
- target := &gtsmodel.Account{}
- if err := c.db.GetByID(m.TargetAccountID, target); err != nil {
- return model.Mention{}, err
+func (c *converter) MentionToMasto(ctx context.Context, m *gtsmodel.Mention) (model.Mention, error) {
+ if m.TargetAccount == nil {
+ targetAccount, err := c.db.GetAccountByID(ctx, m.TargetAccountID)
+ if err != nil {
+ return model.Mention{}, err
+ }
+ m.TargetAccount = targetAccount
}
var local bool
- if target.Domain == "" {
+ if m.TargetAccount.Domain == "" {
local = true
}
var acct string
if local {
- acct = target.Username
+ acct = m.TargetAccount.Username
} else {
- acct = fmt.Sprintf("%s@%s", target.Username, target.Domain)
+ acct = fmt.Sprintf("%s@%s", m.TargetAccount.Username, m.TargetAccount.Domain)
}
return model.Mention{
- ID: target.ID,
- Username: target.Username,
- URL: target.URL,
+ ID: m.TargetAccount.ID,
+ Username: m.TargetAccount.Username,
+ URL: m.TargetAccount.URL,
Acct: acct,
}, nil
}
-func (c *converter) EmojiToMasto(e *gtsmodel.Emoji) (model.Emoji, error) {
+func (c *converter) EmojiToMasto(ctx context.Context, e *gtsmodel.Emoji) (model.Emoji, error) {
return model.Emoji{
Shortcode: e.Shortcode,
URL: e.ImageURL,
@@ -300,27 +308,25 @@ func (c *converter) EmojiToMasto(e *gtsmodel.Emoji) (model.Emoji, error) {
}, nil
}
-func (c *converter) TagToMasto(t *gtsmodel.Tag) (model.Tag, error) {
- tagURL := fmt.Sprintf("%s://%s/tags/%s", c.config.Protocol, c.config.Host, t.Name)
-
+func (c *converter) TagToMasto(ctx context.Context, t *gtsmodel.Tag) (model.Tag, error) {
return model.Tag{
Name: t.Name,
- URL: tagURL, // we don't serve URLs with collections of tagged statuses (FOR NOW) so this is purely for mastodon compatibility ¯\_(ツ)_/¯
+ URL: t.URL,
}, nil
}
-func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmodel.Account) (*model.Status, error) {
- repliesCount, err := c.db.CountStatusReplies(s)
+func (c *converter) StatusToMasto(ctx context.Context, s *gtsmodel.Status, requestingAccount *gtsmodel.Account) (*model.Status, error) {
+ repliesCount, err := c.db.CountStatusReplies(ctx, s)
if err != nil {
return nil, fmt.Errorf("error counting replies: %s", err)
}
- reblogsCount, err := c.db.CountStatusReblogs(s)
+ reblogsCount, err := c.db.CountStatusReblogs(ctx, s)
if err != nil {
return nil, fmt.Errorf("error counting reblogs: %s", err)
}
- favesCount, err := c.db.CountStatusFaves(s)
+ favesCount, err := c.db.CountStatusFaves(ctx, s)
if err != nil {
return nil, fmt.Errorf("error counting faves: %s", err)
}
@@ -330,8 +336,8 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode
// the boosted status might have been set on this struct already so check first before doing db calls
if s.BoostOf == nil {
// it's not set so fetch it from the db
- bs := &gtsmodel.Status{}
- if err := c.db.GetByID(s.BoostOfID, bs); err != nil {
+ bs, err := c.db.GetStatusByID(ctx, s.BoostOfID)
+ if err != nil {
return nil, fmt.Errorf("error getting boosted status with id %s: %s", s.BoostOfID, err)
}
s.BoostOf = bs
@@ -340,15 +346,15 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode
// the boosted account might have been set on this struct already or passed as a param so check first before doing db calls
if s.BoostOfAccount == nil {
// it's not set so fetch it from the db
- ba := &gtsmodel.Account{}
- if err := c.db.GetByID(s.BoostOf.AccountID, ba); err != nil {
+ ba, err := c.db.GetAccountByID(ctx, s.BoostOf.AccountID)
+ if err != nil {
return nil, fmt.Errorf("error getting boosted account %s from status with id %s: %s", s.BoostOf.AccountID, s.BoostOfID, err)
}
s.BoostOfAccount = ba
s.BoostOf.Account = ba
}
- mastoRebloggedStatus, err = c.StatusToMasto(s.BoostOf, requestingAccount)
+ mastoRebloggedStatus, err = c.StatusToMasto(ctx, s.BoostOf, requestingAccount)
if err != nil {
return nil, fmt.Errorf("error converting boosted status to mastotype: %s", err)
}
@@ -357,24 +363,24 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode
var mastoApplication *model.Application
if s.CreatedWithApplicationID != "" {
gtsApplication := &gtsmodel.Application{}
- if err := c.db.GetByID(s.CreatedWithApplicationID, gtsApplication); err != nil {
+ if err := c.db.GetByID(ctx, s.CreatedWithApplicationID, gtsApplication); err != nil {
return nil, fmt.Errorf("error fetching application used to create status: %s", err)
}
- mastoApplication, err = c.AppToMastoPublic(gtsApplication)
+ mastoApplication, err = c.AppToMastoPublic(ctx, gtsApplication)
if err != nil {
return nil, fmt.Errorf("error parsing application used to create status: %s", err)
}
}
if s.Account == nil {
- a := &gtsmodel.Account{}
- if err := c.db.GetByID(s.AccountID, a); err != nil {
+ a, err := c.db.GetAccountByID(ctx, s.AccountID)
+ if err != nil {
return nil, fmt.Errorf("error getting status author: %s", err)
}
s.Account = a
}
- mastoAuthorAccount, err := c.AccountToMastoPublic(s.Account)
+ mastoAuthorAccount, err := c.AccountToMastoPublic(ctx, s.Account)
if err != nil {
return nil, fmt.Errorf("error parsing account of status author: %s", err)
}
@@ -384,7 +390,7 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode
// if so, we can directly convert the gts attachments into masto ones
if s.Attachments != nil {
for _, gtsAttachment := range s.Attachments {
- mastoAttachment, err := c.AttachmentToMasto(gtsAttachment)
+ mastoAttachment, err := c.AttachmentToMasto(ctx, gtsAttachment)
if err != nil {
return nil, fmt.Errorf("error converting attachment with id %s: %s", gtsAttachment.ID, err)
}
@@ -393,14 +399,14 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode
// the status doesn't have gts attachments on it, but it does have attachment IDs
// in this case, we need to pull the gts attachments from the db to convert them into masto ones
} else {
- for _, a := range s.AttachmentIDs {
- gtsAttachment := &gtsmodel.MediaAttachment{}
- if err := c.db.GetByID(a, gtsAttachment); err != nil {
- return nil, fmt.Errorf("error getting attachment with id %s: %s", a, err)
+ for _, aID := range s.AttachmentIDs {
+ gtsAttachment, err := c.db.GetAttachmentByID(ctx, aID)
+ if err != nil {
+ return nil, fmt.Errorf("error getting attachment with id %s: %s", aID, err)
}
- mastoAttachment, err := c.AttachmentToMasto(gtsAttachment)
+ mastoAttachment, err := c.AttachmentToMasto(ctx, gtsAttachment)
if err != nil {
- return nil, fmt.Errorf("error converting attachment with id %s: %s", a, err)
+ return nil, fmt.Errorf("error converting attachment with id %s: %s", aID, err)
}
mastoAttachments = append(mastoAttachments, mastoAttachment)
}
@@ -411,7 +417,7 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode
// if so, we can directly convert the gts mentions into masto ones
if s.Mentions != nil {
for _, gtsMention := range s.Mentions {
- mastoMention, err := c.MentionToMasto(gtsMention)
+ mastoMention, err := c.MentionToMasto(ctx, gtsMention)
if err != nil {
return nil, fmt.Errorf("error converting mention with id %s: %s", gtsMention.ID, err)
}
@@ -420,12 +426,12 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode
// the status doesn't have gts mentions on it, but it does have mention IDs
// in this case, we need to pull the gts mentions from the db to convert them into masto ones
} else {
- for _, m := range s.MentionIDs {
- gtsMention := &gtsmodel.Mention{}
- if err := c.db.GetByID(m, gtsMention); err != nil {
- return nil, fmt.Errorf("error getting mention with id %s: %s", m, err)
+ for _, mID := range s.MentionIDs {
+ gtsMention, err := c.db.GetMention(ctx, mID)
+ if err != nil {
+ return nil, fmt.Errorf("error getting mention with id %s: %s", mID, err)
}
- mastoMention, err := c.MentionToMasto(gtsMention)
+ mastoMention, err := c.MentionToMasto(ctx, gtsMention)
if err != nil {
return nil, fmt.Errorf("error converting mention with id %s: %s", gtsMention.ID, err)
}
@@ -438,7 +444,7 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode
// if so, we can directly convert the gts tags into masto ones
if s.Tags != nil {
for _, gtsTag := range s.Tags {
- mastoTag, err := c.TagToMasto(gtsTag)
+ mastoTag, err := c.TagToMasto(ctx, gtsTag)
if err != nil {
return nil, fmt.Errorf("error converting tag with id %s: %s", gtsTag.ID, err)
}
@@ -449,10 +455,10 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode
} else {
for _, t := range s.TagIDs {
gtsTag := &gtsmodel.Tag{}
- if err := c.db.GetByID(t, gtsTag); err != nil {
+ if err := c.db.GetByID(ctx, t, gtsTag); err != nil {
return nil, fmt.Errorf("error getting tag with id %s: %s", t, err)
}
- mastoTag, err := c.TagToMasto(gtsTag)
+ mastoTag, err := c.TagToMasto(ctx, gtsTag)
if err != nil {
return nil, fmt.Errorf("error converting tag with id %s: %s", gtsTag.ID, err)
}
@@ -465,7 +471,7 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode
// if so, we can directly convert the gts emojis into masto ones
if s.Emojis != nil {
for _, gtsEmoji := range s.Emojis {
- mastoEmoji, err := c.EmojiToMasto(gtsEmoji)
+ mastoEmoji, err := c.EmojiToMasto(ctx, gtsEmoji)
if err != nil {
return nil, fmt.Errorf("error converting emoji with id %s: %s", gtsEmoji.ID, err)
}
@@ -476,10 +482,10 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode
} else {
for _, e := range s.EmojiIDs {
gtsEmoji := &gtsmodel.Emoji{}
- if err := c.db.GetByID(e, gtsEmoji); err != nil {
+ if err := c.db.GetByID(ctx, e, gtsEmoji); err != nil {
return nil, fmt.Errorf("error getting emoji with id %s: %s", e, err)
}
- mastoEmoji, err := c.EmojiToMasto(gtsEmoji)
+ mastoEmoji, err := c.EmojiToMasto(ctx, gtsEmoji)
if err != nil {
return nil, fmt.Errorf("error converting emoji with id %s: %s", gtsEmoji.ID, err)
}
@@ -491,7 +497,7 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode
var mastoPoll *model.Poll
statusInteractions := &statusInteractions{}
- si, err := c.interactionsWithStatusForAccount(s, requestingAccount)
+ si, err := c.interactionsWithStatusForAccount(ctx, s, requestingAccount)
if err == nil {
statusInteractions = si
}
@@ -503,7 +509,7 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode
InReplyToAccountID: s.InReplyToAccountID,
Sensitive: s.Sensitive,
SpoilerText: s.ContentWarning,
- Visibility: c.VisToMasto(s.Visibility),
+ Visibility: c.VisToMasto(ctx, s.Visibility),
Language: s.Language,
URI: s.URI,
URL: s.URL,
@@ -535,7 +541,7 @@ func (c *converter) StatusToMasto(s *gtsmodel.Status, requestingAccount *gtsmode
}
// VisToMasto converts a gts visibility into its mastodon equivalent
-func (c *converter) VisToMasto(m gtsmodel.Visibility) model.Visibility {
+func (c *converter) VisToMasto(ctx context.Context, m gtsmodel.Visibility) model.Visibility {
switch m {
case gtsmodel.VisibilityPublic:
return model.VisibilityPublic
@@ -549,7 +555,7 @@ func (c *converter) VisToMasto(m gtsmodel.Visibility) model.Visibility {
return ""
}
-func (c *converter) InstanceToMasto(i *gtsmodel.Instance) (*model.Instance, error) {
+func (c *converter) InstanceToMasto(ctx context.Context, i *gtsmodel.Instance) (*model.Instance, error) {
mi := &model.Instance{
URI: i.URI,
Title: i.Title,
@@ -567,17 +573,17 @@ func (c *converter) InstanceToMasto(i *gtsmodel.Instance) (*model.Instance, erro
statusCountKey := "status_count"
domainCountKey := "domain_count"
- userCount, err := c.db.CountInstanceUsers(c.config.Host)
+ userCount, err := c.db.CountInstanceUsers(ctx, c.config.Host)
if err == nil {
mi.Stats[userCountKey] = userCount
}
- statusCount, err := c.db.CountInstanceStatuses(c.config.Host)
+ statusCount, err := c.db.CountInstanceStatuses(ctx, c.config.Host)
if err == nil {
mi.Stats[statusCountKey] = statusCount
}
- domainCount, err := c.db.CountInstanceDomains(c.config.Host)
+ domainCount, err := c.db.CountInstanceDomains(ctx, c.config.Host)
if err == nil {
mi.Stats[domainCountKey] = domainCount
}
@@ -593,7 +599,7 @@ func (c *converter) InstanceToMasto(i *gtsmodel.Instance) (*model.Instance, erro
}
// get the instance account if it exists and just skip if it doesn't
- ia, err := c.db.GetInstanceAccount("")
+ ia, err := c.db.GetInstanceAccount(ctx, "")
if err == nil {
if ia.HeaderMediaAttachment != nil {
mi.Thumbnail = ia.HeaderMediaAttachment.URL
@@ -602,19 +608,22 @@ func (c *converter) InstanceToMasto(i *gtsmodel.Instance) (*model.Instance, erro
// contact account is optional but let's try to get it
if i.ContactAccountID != "" {
- ia := &gtsmodel.Account{}
- if err := c.db.GetByID(i.ContactAccountID, ia); err == nil {
- ma, err := c.AccountToMastoPublic(ia)
+ if i.ContactAccount == nil {
+ contactAccount, err := c.db.GetAccountByID(ctx, i.ContactAccountID)
if err == nil {
- mi.ContactAccount = ma
+ i.ContactAccount = contactAccount
}
}
+ ma, err := c.AccountToMastoPublic(ctx, i.ContactAccount)
+ if err == nil {
+ mi.ContactAccount = ma
+ }
}
return mi, nil
}
-func (c *converter) RelationshipToMasto(r *gtsmodel.Relationship) (*model.Relationship, error) {
+func (c *converter) RelationshipToMasto(ctx context.Context, r *gtsmodel.Relationship) (*model.Relationship, error) {
return &model.Relationship{
ID: r.ID,
Following: r.Following,
@@ -632,9 +641,9 @@ func (c *converter) RelationshipToMasto(r *gtsmodel.Relationship) (*model.Relati
}, nil
}
-func (c *converter) NotificationToMasto(n *gtsmodel.Notification) (*model.Notification, error) {
+func (c *converter) NotificationToMasto(ctx context.Context, n *gtsmodel.Notification) (*model.Notification, error) {
if n.TargetAccount == nil {
- tAccount, err := c.db.GetAccountByID(n.TargetAccountID)
+ tAccount, err := c.db.GetAccountByID(ctx, n.TargetAccountID)
if err != nil {
return nil, fmt.Errorf("NotificationToMasto: error getting target account with id %s from the db: %s", n.TargetAccountID, err)
}
@@ -642,14 +651,14 @@ func (c *converter) NotificationToMasto(n *gtsmodel.Notification) (*model.Notifi
}
if n.OriginAccount == nil {
- ogAccount, err := c.db.GetAccountByID(n.OriginAccountID)
+ ogAccount, err := c.db.GetAccountByID(ctx, n.OriginAccountID)
if err != nil {
return nil, fmt.Errorf("NotificationToMasto: error getting origin account with id %s from the db: %s", n.OriginAccountID, err)
}
n.OriginAccount = ogAccount
}
- mastoAccount, err := c.AccountToMastoPublic(n.OriginAccount)
+ mastoAccount, err := c.AccountToMastoPublic(ctx, n.OriginAccount)
if err != nil {
return nil, fmt.Errorf("NotificationToMasto: error converting account to masto: %s", err)
}
@@ -657,7 +666,7 @@ func (c *converter) NotificationToMasto(n *gtsmodel.Notification) (*model.Notifi
var mastoStatus *model.Status
if n.StatusID != "" {
if n.Status == nil {
- status, err := c.db.GetStatusByID(n.StatusID)
+ status, err := c.db.GetStatusByID(ctx, n.StatusID)
if err != nil {
return nil, fmt.Errorf("NotificationToMasto: error getting status with id %s from the db: %s", n.StatusID, err)
}
@@ -673,7 +682,7 @@ func (c *converter) NotificationToMasto(n *gtsmodel.Notification) (*model.Notifi
}
var err error
- mastoStatus, err = c.StatusToMasto(n.Status, nil)
+ mastoStatus, err = c.StatusToMasto(ctx, n.Status, nil)
if err != nil {
return nil, fmt.Errorf("NotificationToMasto: error converting status to masto: %s", err)
}
@@ -688,7 +697,7 @@ func (c *converter) NotificationToMasto(n *gtsmodel.Notification) (*model.Notifi
}, nil
}
-func (c *converter) DomainBlockToMasto(b *gtsmodel.DomainBlock, export bool) (*model.DomainBlock, error) {
+func (c *converter) DomainBlockToMasto(ctx context.Context, b *gtsmodel.DomainBlock, export bool) (*model.DomainBlock, error) {
domainBlock := &model.DomainBlock{
Domain: b.Domain,
diff --git a/internal/typeutils/util.go b/internal/typeutils/util.go
index 5751fbc84..1d1903afc 100644
--- a/internal/typeutils/util.go
+++ b/internal/typeutils/util.go
@@ -1,34 +1,35 @@
package typeutils
import (
+ "context"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (c *converter) interactionsWithStatusForAccount(s *gtsmodel.Status, requestingAccount *gtsmodel.Account) (*statusInteractions, error) {
+func (c *converter) interactionsWithStatusForAccount(ctx context.Context, s *gtsmodel.Status, requestingAccount *gtsmodel.Account) (*statusInteractions, error) {
si := &statusInteractions{}
if requestingAccount != nil {
- faved, err := c.db.IsStatusFavedBy(s, requestingAccount.ID)
+ faved, err := c.db.IsStatusFavedBy(ctx, s, requestingAccount.ID)
if err != nil {
return nil, fmt.Errorf("error checking if requesting account has faved status: %s", err)
}
si.Faved = faved
- reblogged, err := c.db.IsStatusRebloggedBy(s, requestingAccount.ID)
+ reblogged, err := c.db.IsStatusRebloggedBy(ctx, s, requestingAccount.ID)
if err != nil {
return nil, fmt.Errorf("error checking if requesting account has reblogged status: %s", err)
}
si.Reblogged = reblogged
- muted, err := c.db.IsStatusMutedBy(s, requestingAccount.ID)
+ muted, err := c.db.IsStatusMutedBy(ctx, s, requestingAccount.ID)
if err != nil {
return nil, fmt.Errorf("error checking if requesting account has muted status: %s", err)
}
si.Muted = muted
- bookmarked, err := c.db.IsStatusBookmarkedBy(s, requestingAccount.ID)
+ bookmarked, err := c.db.IsStatusBookmarkedBy(ctx, s, requestingAccount.ID)
if err != nil {
return nil, fmt.Errorf("error checking if requesting account has bookmarked status: %s", err)
}
diff --git a/internal/visibility/filter.go b/internal/visibility/filter.go
index 2c43fa4ee..644e85b35 100644
--- a/internal/visibility/filter.go
+++ b/internal/visibility/filter.go
@@ -19,6 +19,8 @@
package visibility
import (
+ "context"
+
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -29,17 +31,17 @@ type Filter interface {
// StatusVisible returns true if targetStatus is visible to requestingAccount, based on the
// privacy settings of the status, and any blocks/mutes that might exist between the two accounts
// or account domains, and other relevant accounts mentioned in or replied to by the status.
- StatusVisible(targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error)
+ StatusVisible(ctx context.Context, targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error)
// StatusHometimelineable returns true if targetStatus should be in the home timeline of the requesting account.
//
// This function will call StatusVisible internally, so it's not necessary to call it beforehand.
- StatusHometimelineable(targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error)
+ StatusHometimelineable(ctx context.Context, targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error)
// StatusPublictimelineable returns true if targetStatus should be in the public timeline of the requesting account.
//
// This function will call StatusVisible internally, so it's not necessary to call it beforehand.
- StatusPublictimelineable(targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error)
+ StatusPublictimelineable(ctx context.Context, targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error)
}
type filter struct {
diff --git a/internal/visibility/relevantaccounts.go b/internal/visibility/relevantaccounts.go
index 5957d3111..d19d26ff4 100644
--- a/internal/visibility/relevantaccounts.go
+++ b/internal/visibility/relevantaccounts.go
@@ -19,6 +19,7 @@
package visibility
import (
+ "context"
"errors"
"fmt"
@@ -41,7 +42,7 @@ type relevantAccounts struct {
BoostedMentionedAccounts []*gtsmodel.Account
}
-func (f *filter) relevantAccounts(status *gtsmodel.Status, getBoosted bool) (*relevantAccounts, error) {
+func (f *filter) relevantAccounts(ctx context.Context, status *gtsmodel.Status, getBoosted bool) (*relevantAccounts, error) {
relAccts := &relevantAccounts{
MentionedAccounts: []*gtsmodel.Account{},
BoostedMentionedAccounts: []*gtsmodel.Account{},
@@ -77,7 +78,7 @@ func (f *filter) relevantAccounts(status *gtsmodel.Status, getBoosted bool) (*re
relAccts.Account = status.Account
} else {
// it wasn't set, so get it from the db
- account, err := f.db.GetAccountByID(status.AccountID)
+ account, err := f.db.GetAccountByID(ctx, status.AccountID)
if err != nil {
return nil, fmt.Errorf("relevantAccounts: error getting account with id %s: %s", status.AccountID, err)
}
@@ -96,7 +97,7 @@ func (f *filter) relevantAccounts(status *gtsmodel.Status, getBoosted bool) (*re
relAccts.InReplyToAccount = status.InReplyToAccount
} else {
// it wasn't set, so get it from the db
- inReplyToAccount, err := f.db.GetAccountByID(status.InReplyToAccountID)
+ inReplyToAccount, err := f.db.GetAccountByID(ctx, status.InReplyToAccountID)
if err != nil {
return nil, fmt.Errorf("relevantAccounts: error getting inReplyToAccount with id %s: %s", status.InReplyToAccountID, err)
}
@@ -115,7 +116,7 @@ func (f *filter) relevantAccounts(status *gtsmodel.Status, getBoosted bool) (*re
}
if !idIn(mID, status.Mentions) {
// mention with ID isn't in status.Mentions
- mention, err := f.db.GetMention(mID)
+ mention, err := f.db.GetMention(ctx, mID)
if err != nil {
return nil, fmt.Errorf("relevantAccounts: error getting mention with id %s: %s", mID, err)
}
@@ -146,7 +147,7 @@ func (f *filter) relevantAccounts(status *gtsmodel.Status, getBoosted bool) (*re
// 4, 5, 6. Boosted status items
// get the boosted status if it's not set on the status already
if status.BoostOfID != "" && status.BoostOf == nil {
- boostedStatus, err := f.db.GetStatusByID(status.BoostOfID)
+ boostedStatus, err := f.db.GetStatusByID(ctx, status.BoostOfID)
if err != nil {
return nil, fmt.Errorf("relevantAccounts: error getting boosted status with id %s: %s", status.BoostOfID, err)
}
@@ -155,7 +156,7 @@ func (f *filter) relevantAccounts(status *gtsmodel.Status, getBoosted bool) (*re
if status.BoostOf != nil {
// return relevant accounts for the boosted status
- boostedRelAccts, err := f.relevantAccounts(status.BoostOf, false) // false because we don't want to recurse
+ boostedRelAccts, err := f.relevantAccounts(ctx, status.BoostOf, false) // false because we don't want to recurse
if err != nil {
return nil, fmt.Errorf("relevantAccounts: error getting relevant accounts of boosted status %s: %s", status.BoostOf.ID, err)
}
@@ -170,7 +171,7 @@ func (f *filter) relevantAccounts(status *gtsmodel.Status, getBoosted bool) (*re
// domainBlockedRelevant checks through all relevant accounts attached to a status
// to make sure none of them are domain blocked by this instance.
-func (f *filter) domainBlockedRelevant(r *relevantAccounts) (bool, error) {
+func (f *filter) domainBlockedRelevant(ctx context.Context, r *relevantAccounts) (bool, error) {
domains := []string{}
if r.Account != nil {
@@ -201,7 +202,7 @@ func (f *filter) domainBlockedRelevant(r *relevantAccounts) (bool, error) {
}
}
- return f.db.AreDomainsBlocked(domains)
+ return f.db.AreDomainsBlocked(ctx, domains)
}
func idIn(id string, mentions []*gtsmodel.Mention) bool {
diff --git a/internal/visibility/statushometimelineable.go b/internal/visibility/statushometimelineable.go
index a3ca62fb3..dd0ca079b 100644
--- a/internal/visibility/statushometimelineable.go
+++ b/internal/visibility/statushometimelineable.go
@@ -19,13 +19,14 @@
package visibility
import (
+ "context"
"fmt"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (f *filter) StatusHometimelineable(targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error) {
+func (f *filter) StatusHometimelineable(ctx context.Context, targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error) {
l := f.log.WithFields(logrus.Fields{
"func": "StatusHometimelineable",
"statusID": targetStatus.ID,
@@ -36,7 +37,7 @@ func (f *filter) StatusHometimelineable(targetStatus *gtsmodel.Status, timelineO
return true, nil
}
- v, err := f.StatusVisible(targetStatus, timelineOwnerAccount)
+ v, err := f.StatusVisible(ctx, targetStatus, timelineOwnerAccount)
if err != nil {
return false, fmt.Errorf("StatusHometimelineable: error checking visibility of status with id %s: %s", targetStatus.ID, err)
}
@@ -63,7 +64,7 @@ func (f *filter) StatusHometimelineable(targetStatus *gtsmodel.Status, timelineO
if targetStatus.InReplyToID != "" {
// pin the reply to status on to this status if it hasn't been done already
if targetStatus.InReplyTo == nil {
- rs, err := f.db.GetStatusByID(targetStatus.InReplyToID)
+ rs, err := f.db.GetStatusByID(ctx, targetStatus.InReplyToID)
if err != nil {
return false, fmt.Errorf("StatusHometimelineable: error getting replied to status with id %s: %s", targetStatus.InReplyToID, err)
}
@@ -72,7 +73,7 @@ func (f *filter) StatusHometimelineable(targetStatus *gtsmodel.Status, timelineO
// pin the reply to account on to this status if it hasn't been done already
if targetStatus.InReplyToAccount == nil {
- ra, err := f.db.GetAccountByID(targetStatus.InReplyToAccountID)
+ ra, err := f.db.GetAccountByID(ctx, targetStatus.InReplyToAccountID)
if err != nil {
return false, fmt.Errorf("StatusHometimelineable: error getting replied to account with id %s: %s", targetStatus.InReplyToAccountID, err)
}
@@ -85,7 +86,7 @@ func (f *filter) StatusHometimelineable(targetStatus *gtsmodel.Status, timelineO
}
// the replied-to account != timelineOwnerAccount, so make sure the timelineOwnerAccount follows the replied-to account
- follows, err := f.db.IsFollowing(timelineOwnerAccount, targetStatus.InReplyToAccount)
+ follows, err := f.db.IsFollowing(ctx, timelineOwnerAccount, targetStatus.InReplyToAccount)
if err != nil {
return false, fmt.Errorf("StatusHometimelineable: error checking follow from account %s to account %s: %s", timelineOwnerAccount.ID, targetStatus.InReplyToAccountID, err)
}
diff --git a/internal/visibility/statuspublictimelineable.go b/internal/visibility/statuspublictimelineable.go
index f07e06aae..8d0a7aa28 100644
--- a/internal/visibility/statuspublictimelineable.go
+++ b/internal/visibility/statuspublictimelineable.go
@@ -19,13 +19,14 @@
package visibility
import (
+ "context"
"fmt"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (f *filter) StatusPublictimelineable(targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error) {
+func (f *filter) StatusPublictimelineable(ctx context.Context, targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error) {
l := f.log.WithFields(logrus.Fields{
"func": "StatusPublictimelineable",
"statusID": targetStatus.ID,
@@ -41,7 +42,7 @@ func (f *filter) StatusPublictimelineable(targetStatus *gtsmodel.Status, timelin
return true, nil
}
- v, err := f.StatusVisible(targetStatus, timelineOwnerAccount)
+ v, err := f.StatusVisible(ctx, targetStatus, timelineOwnerAccount)
if err != nil {
return false, fmt.Errorf("StatusPublictimelineable: error checking visibility of status with id %s: %s", targetStatus.ID, err)
}
diff --git a/internal/visibility/statusvisible.go b/internal/visibility/statusvisible.go
index 15e545881..5b6fe0c1e 100644
--- a/internal/visibility/statusvisible.go
+++ b/internal/visibility/statusvisible.go
@@ -19,6 +19,7 @@
package visibility
import (
+ "context"
"errors"
"fmt"
@@ -28,20 +29,20 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error) {
+func (f *filter) StatusVisible(ctx context.Context, targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error) {
l := f.log.WithFields(logrus.Fields{
"func": "StatusVisible",
"statusID": targetStatus.ID,
})
getBoosted := true
- relevantAccounts, err := f.relevantAccounts(targetStatus, getBoosted)
+ relevantAccounts, err := f.relevantAccounts(ctx, targetStatus, getBoosted)
if err != nil {
l.Debugf("error pulling relevant accounts for status %s: %s", targetStatus.ID, err)
return false, fmt.Errorf("StatusVisible: error pulling relevant accounts for status %s: %s", targetStatus.ID, err)
}
- domainBlocked, err := f.domainBlockedRelevant(relevantAccounts)
+ domainBlocked, err := f.domainBlockedRelevant(ctx, relevantAccounts)
if err != nil {
l.Debugf("error checking domain block: %s", err)
return false, fmt.Errorf("error checking domain block: %s", err)
@@ -67,7 +68,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount
// note: we only do this for local users
if targetAccount.Domain == "" {
targetUser := &gtsmodel.User{}
- if err := f.db.GetWhere([]db.Where{{Key: "account_id", Value: targetAccount.ID}}, targetUser); err != nil {
+ if err := f.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: targetAccount.ID}}, targetUser); err != nil {
l.Debug("target user could not be selected")
if err == db.ErrNoEntries {
return false, nil
@@ -97,7 +98,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount
// note: we only do this for local users
if requestingAccount.Domain == "" {
requestingUser := &gtsmodel.User{}
- if err := f.db.GetWhere([]db.Where{{Key: "account_id", Value: requestingAccount.ID}}, requestingUser); err != nil {
+ if err := f.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: requestingAccount.ID}}, requestingUser); err != nil {
// if the requesting account is local but doesn't have a corresponding user in the db this is a problem
l.Debug("requesting user could not be selected")
if err == db.ErrNoEntries {
@@ -126,7 +127,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount
// At this point we have a populated targetAccount, targetStatus, and requestingAccount, so we can check for blocks and whathaveyou
// First check if a block exists directly between the target account (which authored the status) and the requesting account.
- if blocked, err := f.db.IsBlocked(targetAccount.ID, requestingAccount.ID, true); err != nil {
+ if blocked, err := f.db.IsBlocked(ctx, targetAccount.ID, requestingAccount.ID, true); err != nil {
l.Debugf("something went wrong figuring out if the accounts have a block: %s", err)
return false, err
} else if blocked {
@@ -137,7 +138,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount
// status replies to account id
if relevantAccounts.InReplyToAccount != nil && relevantAccounts.InReplyToAccount.ID != requestingAccount.ID {
- if blocked, err := f.db.IsBlocked(relevantAccounts.InReplyToAccount.ID, requestingAccount.ID, true); err != nil {
+ if blocked, err := f.db.IsBlocked(ctx, relevantAccounts.InReplyToAccount.ID, requestingAccount.ID, true); err != nil {
return false, err
} else if blocked {
l.Trace("a block exists between requesting account and reply to account")
@@ -146,7 +147,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount
// check reply to ID
if targetStatus.InReplyToID != "" && (targetStatus.Visibility == gtsmodel.VisibilityFollowersOnly || targetStatus.Visibility == gtsmodel.VisibilityDirect) {
- followsRepliedAccount, err := f.db.IsFollowing(requestingAccount, relevantAccounts.InReplyToAccount)
+ followsRepliedAccount, err := f.db.IsFollowing(ctx, requestingAccount, relevantAccounts.InReplyToAccount)
if err != nil {
return false, err
}
@@ -159,7 +160,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount
// status boosts accounts id
if relevantAccounts.BoostedAccount != nil {
- if blocked, err := f.db.IsBlocked(relevantAccounts.BoostedAccount.ID, requestingAccount.ID, true); err != nil {
+ if blocked, err := f.db.IsBlocked(ctx, relevantAccounts.BoostedAccount.ID, requestingAccount.ID, true); err != nil {
return false, err
} else if blocked {
l.Trace("a block exists between requesting account and boosted account")
@@ -169,7 +170,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount
// status boosts a reply to account id
if relevantAccounts.BoostedInReplyToAccount != nil {
- if blocked, err := f.db.IsBlocked(relevantAccounts.BoostedInReplyToAccount.ID, requestingAccount.ID, true); err != nil {
+ if blocked, err := f.db.IsBlocked(ctx, relevantAccounts.BoostedInReplyToAccount.ID, requestingAccount.ID, true); err != nil {
return false, err
} else if blocked {
l.Trace("a block exists between requesting account and boosted reply to account")
@@ -182,7 +183,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount
if a == nil {
continue
}
- if blocked, err := f.db.IsBlocked(a.ID, requestingAccount.ID, true); err != nil {
+ if blocked, err := f.db.IsBlocked(ctx, a.ID, requestingAccount.ID, true); err != nil {
return false, err
} else if blocked {
l.Trace("a block exists between requesting account and a mentioned account")
@@ -195,7 +196,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount
if a == nil {
continue
}
- if blocked, err := f.db.IsBlocked(a.ID, requestingAccount.ID, true); err != nil {
+ if blocked, err := f.db.IsBlocked(ctx, a.ID, requestingAccount.ID, true); err != nil {
return false, err
} else if blocked {
l.Trace("a block exists between requesting account and a boosted mentioned account")
@@ -221,7 +222,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount
return true, nil
case gtsmodel.VisibilityFollowersOnly:
// check one-way follow
- follows, err := f.db.IsFollowing(requestingAccount, targetAccount)
+ follows, err := f.db.IsFollowing(ctx, requestingAccount, targetAccount)
if err != nil {
return false, err
}
@@ -232,7 +233,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount
return true, nil
case gtsmodel.VisibilityMutualsOnly:
// check mutual follow
- mutuals, err := f.db.IsMutualFollowing(requestingAccount, targetAccount)
+ mutuals, err := f.db.IsMutualFollowing(ctx, requestingAccount, targetAccount)
if err != nil {
return false, err
}
diff --git a/internal/web/base.go b/internal/web/base.go
index c0b85b613..eabde676c 100644
--- a/internal/web/base.go
+++ b/internal/web/base.go
@@ -50,7 +50,7 @@ func (m *Module) baseHandler(c *gin.Context) {
l := m.log.WithField("func", "BaseGETHandler")
l.Trace("serving index html")
- instance, err := m.processor.InstanceGet(m.config.Host)
+ instance, err := m.processor.InstanceGet(c.Request.Context(), m.config.Host)
if err != nil {
l.Debugf("error getting instance from processor: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"})
@@ -71,7 +71,7 @@ func (m *Module) NotFoundHandler(c *gin.Context) {
l := m.log.WithField("func", "404")
l.Trace("serving 404 html")
- instance, err := m.processor.InstanceGet(m.config.Host)
+ instance, err := m.processor.InstanceGet(c.Request.Context(), m.config.Host)
if err != nil {
l.Debugf("error getting instance from processor: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"})