summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorLibravatar kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com>2022-05-15 10:16:43 +0100
committerLibravatar GitHub <noreply@github.com>2022-05-15 11:16:43 +0200
commit223025fc27ef636206027b360201877848d426a4 (patch)
treed2f5f293caabdd82fbb87fed3730eb8f6f2e1c1f /internal
parent[chore] Update LE server to use copy of main http.Server{} to maintain server... (diff)
downloadgotosocial-223025fc27ef636206027b360201877848d426a4.tar.xz
[security] transport.Controller{} and transport.Transport{} security and performance improvements (#564)
* cache transports in controller by privkey-generated pubkey, add retry logic to transport requests Signed-off-by: kim <grufwub@gmail.com> * update code comments, defer mutex unlocks Signed-off-by: kim <grufwub@gmail.com> * add count to 'performing request' log message Signed-off-by: kim <grufwub@gmail.com> * reduce repeated conversions of same url.URL object Signed-off-by: kim <grufwub@gmail.com> * move worker.Worker to concurrency subpackage, add WorkQueue type, limit transport http client use by WorkQueue Signed-off-by: kim <grufwub@gmail.com> * fix security advisories regarding max outgoing conns, max rsp body size - implemented by a new httpclient.Client{} that wraps an underlying client with a queue to limit connections, and limit reader wrapping a response body with a configured maximum size - update pub.HttpClient args passed around to be this new httpclient.Client{} Signed-off-by: kim <grufwub@gmail.com> * add httpclient tests, move ip validation to separate package + change mechanism Signed-off-by: kim <grufwub@gmail.com> * fix merge conflicts Signed-off-by: kim <grufwub@gmail.com> * use singular mutex in transport rather than separate signer mus Signed-off-by: kim <grufwub@gmail.com> * improved useragent string Signed-off-by: kim <grufwub@gmail.com> * add note regarding missing test Signed-off-by: kim <grufwub@gmail.com> * remove useragent field from transport (instead store in controller) Signed-off-by: kim <grufwub@gmail.com> * shutup linter Signed-off-by: kim <grufwub@gmail.com> * reset other signing headers on each loop iteration Signed-off-by: kim <grufwub@gmail.com> * respect request ctx during retry-backoff sleep period Signed-off-by: kim <grufwub@gmail.com> * use external pkg with docs explaining performance "hack" Signed-off-by: kim <grufwub@gmail.com> * use http package constants instead of string method literals Signed-off-by: kim <grufwub@gmail.com> * add license file headers Signed-off-by: kim <grufwub@gmail.com> * update code comment to match new func names Signed-off-by: kim <grufwub@gmail.com> * updates to user-agent string Signed-off-by: kim <grufwub@gmail.com> * update signed testrig models to fit with new transport logic (instead uses separate signer now) Signed-off-by: kim <grufwub@gmail.com> * fuck you linter Signed-off-by: kim <grufwub@gmail.com>
Diffstat (limited to 'internal')
-rw-r--r--internal/api/client/account/account_test.go6
-rw-r--r--internal/api/client/admin/admin_test.go6
-rw-r--r--internal/api/client/fileserver/servefile_test.go6
-rw-r--r--internal/api/client/followrequest/followrequest_test.go6
-rw-r--r--internal/api/client/media/mediacreate_test.go6
-rw-r--r--internal/api/client/media/mediaupdate_test.go6
-rw-r--r--internal/api/client/status/status_test.go6
-rw-r--r--internal/api/client/user/user_test.go6
-rw-r--r--internal/api/s2s/user/inboxpost_test.go18
-rw-r--r--internal/api/s2s/user/outboxget_test.go14
-rw-r--r--internal/api/s2s/user/repliesget_test.go14
-rw-r--r--internal/api/s2s/user/statusget_test.go10
-rw-r--r--internal/api/s2s/user/user_test.go6
-rw-r--r--internal/api/s2s/user/userget_test.go10
-rw-r--r--internal/api/s2s/webfinger/webfinger_test.go6
-rw-r--r--internal/api/s2s/webfinger/webfingerget_test.go10
-rw-r--r--internal/concurrency/workers.go (renamed from internal/worker/workers.go)20
-rw-r--r--internal/federation/dereferencing/dereferencer_test.go4
-rw-r--r--internal/federation/federatingactor_test.go6
-rw-r--r--internal/federation/federatingdb/db.go6
-rw-r--r--internal/federation/federatingdb/federatingdb_test.go6
-rw-r--r--internal/federation/federatingprotocol_test.go6
-rw-r--r--internal/httpclient/client.go199
-rw-r--r--internal/httpclient/client_test.go154
-rw-r--r--internal/httpclient/sanitizer.go64
-rw-r--r--internal/media/manager.go12
-rw-r--r--internal/netutil/validate.go78
-rw-r--r--internal/processing/account/account.go6
-rw-r--r--internal/processing/account/account_test.go6
-rw-r--r--internal/processing/admin/admin.go6
-rw-r--r--internal/processing/media/media_test.go4
-rw-r--r--internal/processing/processor.go10
-rw-r--r--internal/processing/processor_test.go6
-rw-r--r--internal/processing/status/status.go6
-rw-r--r--internal/processing/status/status_test.go8
-rw-r--r--internal/transport/controller.go196
-rw-r--r--internal/transport/deliver.go29
-rw-r--r--internal/transport/dereference.go39
-rw-r--r--internal/transport/derefinstance.go85
-rw-r--r--internal/transport/derefmedia.go33
-rw-r--r--internal/transport/finger.go50
-rw-r--r--internal/transport/signing.go43
-rw-r--r--internal/transport/transport.go163
43 files changed, 1041 insertions, 340 deletions
diff --git a/internal/api/client/account/account_test.go b/internal/api/client/account/account_test.go
index d65b49550..d6bb5a5c0 100644
--- a/internal/api/client/account/account_test.go
+++ b/internal/api/client/account/account_test.go
@@ -11,6 +11,7 @@ import (
"github.com/spf13/viper"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/account"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@@ -20,7 +21,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -62,8 +62,8 @@ func (suite *AccountStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()
diff --git a/internal/api/client/admin/admin_test.go b/internal/api/client/admin/admin_test.go
index 578ab167c..11e2f8354 100644
--- a/internal/api/client/admin/admin_test.go
+++ b/internal/api/client/admin/admin_test.go
@@ -29,6 +29,7 @@ import (
"github.com/spf13/viper"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/admin"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@@ -38,7 +39,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -80,8 +80,8 @@ func (suite *AdminStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()
diff --git a/internal/api/client/fileserver/servefile_test.go b/internal/api/client/fileserver/servefile_test.go
index 49d813981..d7de2f4f9 100644
--- a/internal/api/client/fileserver/servefile_test.go
+++ b/internal/api/client/fileserver/servefile_test.go
@@ -31,6 +31,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/fileserver"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@@ -40,7 +41,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -77,8 +77,8 @@ func (suite *ServeFileTestSuite) SetupSuite() {
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()
diff --git a/internal/api/client/followrequest/followrequest_test.go b/internal/api/client/followrequest/followrequest_test.go
index 072025931..14b5656b6 100644
--- a/internal/api/client/followrequest/followrequest_test.go
+++ b/internal/api/client/followrequest/followrequest_test.go
@@ -28,6 +28,7 @@ import (
"github.com/spf13/viper"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/followrequest"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@@ -37,7 +38,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -77,8 +77,8 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()
diff --git a/internal/api/client/media/mediacreate_test.go b/internal/api/client/media/mediacreate_test.go
index 4d08697ef..e16b9f5eb 100644
--- a/internal/api/client/media/mediacreate_test.go
+++ b/internal/api/client/media/mediacreate_test.go
@@ -37,6 +37,7 @@ import (
"github.com/stretchr/testify/suite"
mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
"github.com/superseriousbusiness/gotosocial/internal/api/model"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@@ -47,7 +48,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -84,8 +84,8 @@ func (suite *MediaCreateTestSuite) SetupSuite() {
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()
diff --git a/internal/api/client/media/mediaupdate_test.go b/internal/api/client/media/mediaupdate_test.go
index b87e6ec8d..a87718438 100644
--- a/internal/api/client/media/mediaupdate_test.go
+++ b/internal/api/client/media/mediaupdate_test.go
@@ -35,6 +35,7 @@ import (
"github.com/stretchr/testify/suite"
mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
"github.com/superseriousbusiness/gotosocial/internal/api/model"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@@ -45,7 +46,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -82,8 +82,8 @@ func (suite *MediaUpdateTestSuite) SetupSuite() {
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewTestStorage()
diff --git a/internal/api/client/status/status_test.go b/internal/api/client/status/status_test.go
index a4a56aa0b..e2e2819b5 100644
--- a/internal/api/client/status/status_test.go
+++ b/internal/api/client/status/status_test.go
@@ -32,6 +32,7 @@ import (
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/gotosocial/internal/api/client/status"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@@ -40,7 +41,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -90,8 +90,8 @@ func (suite *StatusStandardTestSuite) SetupTest() {
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(suite.testHttpClient(), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
diff --git a/internal/api/client/user/user_test.go b/internal/api/client/user/user_test.go
index b0fd2b2e9..6e9c46525 100644
--- a/internal/api/client/user/user_test.go
+++ b/internal/api/client/user/user_test.go
@@ -22,6 +22,7 @@ import (
"codeberg.org/gruf/go-store/kv"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/user"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@@ -30,7 +31,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -58,8 +58,8 @@ type UserStandardTestSuite struct {
func (suite *UserStandardTestSuite) SetupTest() {
testrig.InitTestLog()
testrig.InitTestConfig()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
diff --git a/internal/api/s2s/user/inboxpost_test.go b/internal/api/s2s/user/inboxpost_test.go
index 6f2909430..388a9fbbb 100644
--- a/internal/api/s2s/user/inboxpost_test.go
+++ b/internal/api/s2s/user/inboxpost_test.go
@@ -33,11 +33,11 @@ import (
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -85,8 +85,8 @@ func (suite *InboxPostTestSuite) TestPostBlock() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -188,8 +188,8 @@ func (suite *InboxPostTestSuite) TestPostUnblock() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -281,8 +281,8 @@ func (suite *InboxPostTestSuite) TestPostUpdate() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -403,8 +403,8 @@ func (suite *InboxPostTestSuite) TestPostDelete() {
suite.NoError(err)
body := bytes.NewReader(bodyJson)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
diff --git a/internal/api/s2s/user/outboxget_test.go b/internal/api/s2s/user/outboxget_test.go
index ea9259b0f..79122731f 100644
--- a/internal/api/s2s/user/outboxget_test.go
+++ b/internal/api/s2s/user/outboxget_test.go
@@ -31,8 +31,8 @@ import (
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -46,8 +46,8 @@ func (suite *OutboxGetTestSuite) TestGetOutbox() {
signedRequest := derefRequests["foss_satan_dereference_zork_outbox"]
targetAccount := suite.testAccounts["local_account_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -104,8 +104,8 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() {
signedRequest := derefRequests["foss_satan_dereference_zork_outbox_first"]
targetAccount := suite.testAccounts["local_account_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -162,8 +162,8 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() {
signedRequest := derefRequests["foss_satan_dereference_zork_outbox_next"]
targetAccount := suite.testAccounts["local_account_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
diff --git a/internal/api/s2s/user/repliesget_test.go b/internal/api/s2s/user/repliesget_test.go
index 4b8364318..845c07bdb 100644
--- a/internal/api/s2s/user/repliesget_test.go
+++ b/internal/api/s2s/user/repliesget_test.go
@@ -33,8 +33,8 @@ import (
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -49,8 +49,8 @@ func (suite *RepliesGetTestSuite) TestGetReplies() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -113,8 +113,8 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -180,8 +180,8 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
diff --git a/internal/api/s2s/user/statusget_test.go b/internal/api/s2s/user/statusget_test.go
index c28e4e567..6696bd7e9 100644
--- a/internal/api/s2s/user/statusget_test.go
+++ b/internal/api/s2s/user/statusget_test.go
@@ -32,8 +32,8 @@ import (
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -48,8 +48,8 @@ func (suite *StatusGetTestSuite) TestGetStatus() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -116,8 +116,8 @@ func (suite *StatusGetTestSuite) TestGetStatusLowercase() {
targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
diff --git a/internal/api/s2s/user/user_test.go b/internal/api/s2s/user/user_test.go
index 1ed960544..e8d305d06 100644
--- a/internal/api/s2s/user/user_test.go
+++ b/internal/api/s2s/user/user_test.go
@@ -23,6 +23,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/api/security"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@@ -32,7 +33,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -78,8 +78,8 @@ func (suite *UserStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.db = testrig.NewTestDB()
suite.tc = testrig.NewTestTypeConverter(suite.db)
diff --git a/internal/api/s2s/user/userget_test.go b/internal/api/s2s/user/userget_test.go
index 5c9e4f0d8..5ac2197ff 100644
--- a/internal/api/s2s/user/userget_test.go
+++ b/internal/api/s2s/user/userget_test.go
@@ -33,9 +33,9 @@ import (
"github.com/superseriousbusiness/activity/streams/vocab"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -49,8 +49,8 @@ func (suite *UserGetTestSuite) TestGetUser() {
signedRequest := derefRequests["foss_satan_dereference_zork"]
targetAccount := suite.testAccounts["local_account_1"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
@@ -130,8 +130,8 @@ func (suite *UserGetTestSuite) TestGetUserPublicKeyDeleted() {
derefRequests := testrig.NewTestDereferenceRequests(suite.testAccounts)
signedRequest := derefRequests["foss_satan_dereference_zork_public_key"]
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
diff --git a/internal/api/s2s/webfinger/webfinger_test.go b/internal/api/s2s/webfinger/webfinger_test.go
index 1f597d3f9..0df50c503 100644
--- a/internal/api/s2s/webfinger/webfinger_test.go
+++ b/internal/api/s2s/webfinger/webfinger_test.go
@@ -28,6 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
"github.com/superseriousbusiness/gotosocial/internal/api/security"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@@ -37,7 +38,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -81,8 +81,8 @@ func (suite *WebfingerStandardTestSuite) SetupTest() {
testrig.InitTestLog()
testrig.InitTestConfig()
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.db = testrig.NewTestDB()
suite.tc = testrig.NewTestTypeConverter(suite.db)
diff --git a/internal/api/s2s/webfinger/webfingerget_test.go b/internal/api/s2s/webfinger/webfingerget_test.go
index 55de30f34..7871b6a3f 100644
--- a/internal/api/s2s/webfinger/webfingerget_test.go
+++ b/internal/api/s2s/webfinger/webfingerget_test.go
@@ -31,10 +31,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -71,8 +71,8 @@ func (suite *WebfingerGetTestSuite) TestFingerUser() {
func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHost() {
viper.Set(config.Keys.Host, "gts.example.org")
viper.Set(config.Keys.AccountDomain, "example.org")
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
suite.webfingerModule = webfinger.New(suite.processor).(*webfinger.Module)
@@ -107,8 +107,8 @@ func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHo
func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByAccountDomain() {
viper.Set(config.Keys.Host, "gts.example.org")
viper.Set(config.Keys.AccountDomain, "example.org")
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
suite.webfingerModule = webfinger.New(suite.processor).(*webfinger.Module)
diff --git a/internal/worker/workers.go b/internal/concurrency/workers.go
index 6adf9ad30..2e344aece 100644
--- a/internal/worker/workers.go
+++ b/internal/concurrency/workers.go
@@ -1,4 +1,4 @@
-package worker
+package concurrency
import (
"context"
@@ -12,17 +12,17 @@ import (
"github.com/sirupsen/logrus"
)
-// Worker represents a proccessor for MsgType objects, using a worker pool to allocate resources.
-type Worker[MsgType any] struct {
+// WorkerPool represents a proccessor for MsgType objects, using a worker pool to allocate resources.
+type WorkerPool[MsgType any] struct {
workers runners.WorkerPool
process func(context.Context, MsgType) error
prefix string // contains type prefix for logging
}
-// New returns a new Worker[MsgType] with given number of workers and queue ratio,
+// New returns a new WorkerPool[MsgType] with given number of workers and queue ratio,
// where the queue ratio is multiplied by no. workers to get queue size. If args < 1
// then suitable defaults are determined from the runtime's GOMAXPROCS variable.
-func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
+func NewWorkerPool[MsgType any](workers int, queueRatio int) *WorkerPool[MsgType] {
var zero MsgType
if workers < 1 {
@@ -38,7 +38,7 @@ func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
msgType := reflect.TypeOf(zero).String()
_, msgType = path.Split(msgType)
- w := &Worker[MsgType]{
+ w := &WorkerPool[MsgType]{
workers: runners.NewWorkerPool(workers, workers*queueRatio),
process: nil,
prefix: fmt.Sprintf("worker.Worker[%s]", msgType),
@@ -55,7 +55,7 @@ func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] {
}
// Start will attempt to start the underlying worker pool, or return error.
-func (w *Worker[MsgType]) Start() error {
+func (w *WorkerPool[MsgType]) Start() error {
logrus.Infof("%s starting", w.prefix)
// Check processor was set
@@ -72,7 +72,7 @@ func (w *Worker[MsgType]) Start() error {
}
// Stop will attempt to stop the underlying worker pool, or return error.
-func (w *Worker[MsgType]) Stop() error {
+func (w *WorkerPool[MsgType]) Stop() error {
logrus.Infof("%s stopping", w.prefix)
// Attempt to stop pool
@@ -84,7 +84,7 @@ func (w *Worker[MsgType]) Stop() error {
}
// SetProcessor will set the Worker's processor function, which is called for each queued message.
-func (w *Worker[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) {
+func (w *WorkerPool[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) {
if w.process != nil {
logrus.Panicf("%s Worker.process is already set", w.prefix)
}
@@ -92,7 +92,7 @@ func (w *Worker[MsgType]) SetProcessor(fn func(context.Context, MsgType) error)
}
// Queue will queue provided message to be processed with there's a free worker.
-func (w *Worker[MsgType]) Queue(msg MsgType) {
+func (w *WorkerPool[MsgType]) Queue(msg MsgType) {
logrus.Tracef("%s queueing message (workers=%d queue=%d): %+v",
w.prefix, w.workers.Workers(), w.workers.Queue(), msg,
)
diff --git a/internal/federation/dereferencing/dereferencer_test.go b/internal/federation/dereferencing/dereferencer_test.go
index 441019866..339490e5d 100644
--- a/internal/federation/dereferencing/dereferencer_test.go
+++ b/internal/federation/dereferencing/dereferencer_test.go
@@ -29,12 +29,12 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/transport"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -150,7 +150,7 @@ func (suite *DereferencerStandardTestSuite) mockTransportController() transport.
return response, nil
}
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
mockClient := testrig.NewMockHTTPClient(do)
return testrig.NewTestTransportController(mockClient, suite.db, fedWorker)
}
diff --git a/internal/federation/federatingactor_test.go b/internal/federation/federatingactor_test.go
index 4039783a4..fdf907030 100644
--- a/internal/federation/federatingactor_test.go
+++ b/internal/federation/federatingactor_test.go
@@ -28,10 +28,10 @@ import (
"time"
"github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -57,7 +57,7 @@ func (suite *FederatingActorTestSuite) TestSendNoRemoteFollowers() {
)
testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// setup transport controller with a no-op client so we don't make external calls
sentMessages := []*url.URL{}
@@ -112,7 +112,7 @@ func (suite *FederatingActorTestSuite) TestSendRemoteFollower() {
)
testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// setup transport controller with a no-op client so we don't make external calls
sentMessages := []*url.URL{}
diff --git a/internal/federation/federatingdb/db.go b/internal/federation/federatingdb/db.go
index 60f09b909..cbe65e922 100644
--- a/internal/federation/federatingdb/db.go
+++ b/internal/federation/federatingdb/db.go
@@ -24,10 +24,10 @@ import (
"codeberg.org/gruf/go-mutexes"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams/vocab"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
)
// DB wraps the pub.Database interface with a couple of custom functions for GoToSocial.
@@ -44,12 +44,12 @@ type DB interface {
type federatingDB struct {
locks mutexes.MutexMap
db db.DB
- fedWorker *worker.Worker[messages.FromFederator]
+ fedWorker *concurrency.WorkerPool[messages.FromFederator]
typeConverter typeutils.TypeConverter
}
// New returns a DB interface using the given database and config
-func New(db db.DB, fedWorker *worker.Worker[messages.FromFederator]) DB {
+func New(db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) DB {
fdb := federatingDB{
locks: mutexes.NewMap(-1, -1), // use defaults
db: db,
diff --git a/internal/federation/federatingdb/federatingdb_test.go b/internal/federation/federatingdb/federatingdb_test.go
index d53294c1c..8e6c1802d 100644
--- a/internal/federation/federatingdb/federatingdb_test.go
+++ b/internal/federation/federatingdb/federatingdb_test.go
@@ -23,12 +23,12 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -36,7 +36,7 @@ type FederatingDBTestSuite struct {
suite.Suite
db db.DB
tc typeutils.TypeConverter
- fedWorker *worker.Worker[messages.FromFederator]
+ fedWorker *concurrency.WorkerPool[messages.FromFederator]
fromFederator chan messages.FromFederator
federatingDB federatingdb.DB
@@ -65,7 +65,7 @@ func (suite *FederatingDBTestSuite) SetupSuite() {
func (suite *FederatingDBTestSuite) SetupTest() {
testrig.InitTestLog()
testrig.InitTestConfig()
- suite.fedWorker = worker.New[messages.FromFederator](-1, -1)
+ suite.fedWorker = concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.fromFederator = make(chan messages.FromFederator, 10)
suite.fedWorker.SetProcessor(func(ctx context.Context, msg messages.FromFederator) error {
suite.fromFederator <- msg
diff --git a/internal/federation/federatingprotocol_test.go b/internal/federation/federatingprotocol_test.go
index 09817cff3..b4769a70f 100644
--- a/internal/federation/federatingprotocol_test.go
+++ b/internal/federation/federatingprotocol_test.go
@@ -28,10 +28,10 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/gotosocial/internal/ap"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -44,7 +44,7 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook() {
// the activity we're gonna use
activity := suite.testActivities["dm_for_zork"]
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// setup transport controller with a no-op client so we don't make external calls
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) {
@@ -78,7 +78,7 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostInbox() {
sendingAccount := suite.testAccounts["remote_account_1"]
inboxAccount := suite.testAccounts["local_account_1"]
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
// now setup module being tested, with the mock transport controller
diff --git a/internal/httpclient/client.go b/internal/httpclient/client.go
new file mode 100644
index 000000000..1a1f5e53b
--- /dev/null
+++ b/internal/httpclient/client.go
@@ -0,0 +1,199 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package httpclient
+
+import (
+ "errors"
+ "io"
+ "net"
+ "net/http"
+ "net/netip"
+ "runtime"
+ "time"
+)
+
+// ErrReservedAddr is returned if a dialed address resolves to an IP within a blocked or reserved net.
+var ErrReservedAddr = errors.New("dial within blocked / reserved IP range")
+
+// ErrBodyTooLarge is returned when a received response body is above predefined limit (default 40MB).
+var ErrBodyTooLarge = errors.New("body size too large")
+
+// dialer is the base net.Dialer used by all package-created http.Transports.
+var dialer = &net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+ Resolver: &net.Resolver{Dial: nil},
+}
+
+// Config provides configuration details for setting up a new
+// instance of httpclient.Client{}. Within are a subset of the
+// configuration values passed to initialized http.Transport{}
+// and http.Client{}, along with httpclient.Client{} specific.
+type Config struct {
+ // MaxOpenConns limits the max number of concurrent open connections.
+ MaxOpenConns int
+
+ // MaxIdleConns: see http.Transport{}.MaxIdleConns.
+ MaxIdleConns int
+
+ // ReadBufferSize: see http.Transport{}.ReadBufferSize.
+ ReadBufferSize int
+
+ // WriteBufferSize: see http.Transport{}.WriteBufferSize.
+ WriteBufferSize int
+
+ // MaxBodySize determines the maximum fetchable body size.
+ MaxBodySize int64
+
+ // Timeout: see http.Client{}.Timeout.
+ Timeout time.Duration
+
+ // DisableCompression: see http.Transport{}.DisableCompression.
+ DisableCompression bool
+
+ // AllowRanges allows outgoing communications to given IP nets.
+ AllowRanges []netip.Prefix
+
+ // BlockRanges blocks outgoing communiciations to given IP nets.
+ BlockRanges []netip.Prefix
+}
+
+// Client wraps an underlying http.Client{} to provide the following:
+// - setting a maximum received request body size, returning error on
+// large content lengths, and using a limited reader in all other
+// cases to protect against forged / unknown content-lengths
+// - protection from server side request forgery (SSRF) by only dialing
+// out to known public IP prefixes, configurable with allows/blocks
+// - limit number of concurrent requests, else blocking until a slot
+// is available (context channels still respected)
+type Client struct {
+ client http.Client
+ queue chan struct{}
+ bmax int64
+}
+
+// New returns a new instance of Client initialized using configuration.
+func New(cfg Config) *Client {
+ var c Client
+
+ // Copy global
+ d := dialer
+
+ if cfg.MaxOpenConns <= 0 {
+ // By default base this value on GOMAXPROCS.
+ maxprocs := runtime.GOMAXPROCS(0)
+ cfg.MaxOpenConns = maxprocs * 10
+ }
+
+ if cfg.MaxIdleConns <= 0 {
+ // By default base this value on MaxOpenConns
+ cfg.MaxIdleConns = cfg.MaxOpenConns * 10
+ }
+
+ if cfg.MaxBodySize <= 0 {
+ // By default set this to a reasonable 40MB
+ cfg.MaxBodySize = 40 * 1024 * 1024
+ }
+
+ // Protect dialer with IP range sanitizer
+ d.Control = (&sanitizer{
+ allow: cfg.AllowRanges,
+ block: cfg.BlockRanges,
+ }).Sanitize
+
+ // Prepare client fields
+ c.bmax = cfg.MaxBodySize
+ c.queue = make(chan struct{}, cfg.MaxOpenConns)
+ c.client.Timeout = cfg.Timeout
+
+ // Set underlying HTTP client roundtripper
+ c.client.Transport = &http.Transport{
+ Proxy: http.ProxyFromEnvironment,
+ ForceAttemptHTTP2: true,
+ DialContext: d.DialContext,
+ MaxIdleConns: cfg.MaxIdleConns,
+ IdleConnTimeout: 90 * time.Second,
+ TLSHandshakeTimeout: 10 * time.Second,
+ ExpectContinueTimeout: 1 * time.Second,
+ ReadBufferSize: cfg.ReadBufferSize,
+ WriteBufferSize: cfg.WriteBufferSize,
+ DisableCompression: cfg.DisableCompression,
+ }
+
+ return &c
+}
+
+// Do will perform given request when an available slot in the queue is available,
+// and block until this time. For returned values, this follows the same semantics
+// as the standard http.Client{}.Do() implementation except that response body will
+// be wrapped by an io.LimitReader() to limit response body sizes.
+func (c *Client) Do(req *http.Request) (*http.Response, error) {
+ select {
+ // Request context cancelled
+ case <-req.Context().Done():
+ return nil, req.Context().Err()
+
+ // Slot in queue acquired
+ case c.queue <- struct{}{}:
+ // NOTE:
+ // Ideally here we would set the slot release to happen either
+ // on error return, or via callback from the response body closer.
+ // However when implementing this, there appear deadlocks between
+ // the channel queue here and the media manager worker pool. So
+ // currently we only place a limit on connections dialing out, but
+ // there may still be more connections open than len(c.queue) given
+ // that connections may not be closed until response body is closed.
+ // The current implementation will reduce the viability of denial of
+ // service attacks, but if there are future issues heed this advice :]
+ defer func() { <-c.queue }()
+ }
+
+ // Perform the HTTP request
+ rsp, err := c.client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+
+ // Check response body not too large
+ if rsp.ContentLength > c.bmax {
+ return nil, ErrBodyTooLarge
+ }
+
+ // Seperate the body implementers
+ rbody := (io.Reader)(rsp.Body)
+ cbody := (io.Closer)(rsp.Body)
+
+ var limit int64
+
+ if limit = rsp.ContentLength; limit < 0 {
+ // If unknown, use max as reader limit
+ limit = c.bmax
+ }
+
+ // Don't trust them, limit body reads
+ rbody = io.LimitReader(rbody, limit)
+
+ // Wrap body with limit
+ rsp.Body = &struct {
+ io.Reader
+ io.Closer
+ }{rbody, cbody}
+
+ return rsp, nil
+}
diff --git a/internal/httpclient/client_test.go b/internal/httpclient/client_test.go
new file mode 100644
index 000000000..dc190d430
--- /dev/null
+++ b/internal/httpclient/client_test.go
@@ -0,0 +1,154 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package httpclient_test
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/netip"
+ "testing"
+
+ "github.com/superseriousbusiness/gotosocial/internal/httpclient"
+)
+
+var privateIPs = []string{
+ "http://127.0.0.1:80",
+ "http://0.0.0.0:80",
+ "http://192.168.0.1:80",
+ "http://192.168.1.0:80",
+ "http://10.0.0.0:80",
+ "http://172.16.0.0:80",
+ "http://10.255.255.255:80",
+ "http://172.31.255.255:80",
+ "http://255.255.255.255:80",
+}
+
+var bodies = []string{
+ "hello world!",
+ "{}",
+ `{"key": "value", "some": "kinda bullshit"}`,
+ "body with\r\nnewlines",
+}
+
+// Note:
+// There is no test for the .MaxOpenConns implementation
+// in the httpclient.Client{}, due to the difficult to test
+// this. The block is only held for the actual dial out to
+// the connection, so the usual test of blocking and holding
+// open this queue slot to check we can't open another isn't
+// an easy test here.
+
+func TestHTTPClientSmallBody(t *testing.T) {
+ for _, body := range bodies {
+ _TestHTTPClientWithBody(t, []byte(body), int(^uint16(0)))
+ }
+}
+
+func TestHTTPClientExactBody(t *testing.T) {
+ for _, body := range bodies {
+ _TestHTTPClientWithBody(t, []byte(body), len(body))
+ }
+}
+
+func TestHTTPClientLargeBody(t *testing.T) {
+ for _, body := range bodies {
+ _TestHTTPClientWithBody(t, []byte(body), len(body)-1)
+ }
+}
+
+func _TestHTTPClientWithBody(t *testing.T, body []byte, max int) {
+ var (
+ handler http.HandlerFunc
+
+ expect []byte
+
+ expectErr error
+ )
+
+ // If this is a larger body, reslice and
+ // set error so we know what to expect
+ expect = body
+ if max < len(body) {
+ expect = expect[:max]
+ expectErr = httpclient.ErrBodyTooLarge
+ }
+
+ // Create new HTTP client with maximum body size
+ client := httpclient.New(httpclient.Config{
+ MaxBodySize: int64(max),
+ DisableCompression: true,
+ AllowRanges: []netip.Prefix{
+ // Loopback (used by server)
+ netip.MustParsePrefix("127.0.0.1/8"),
+ },
+ })
+
+ // Set simple body-writing test handler
+ handler = func(rw http.ResponseWriter, r *http.Request) {
+ _, _ = rw.Write(body)
+ }
+
+ // Start the test server
+ srv := httptest.NewServer(handler)
+ defer srv.Close()
+
+ // Wrap body to provide reader iface
+ rbody := bytes.NewReader(body)
+
+ // Create the test HTTP request
+ req, _ := http.NewRequest("POST", srv.URL, rbody)
+
+ // Perform the test request
+ rsp, err := client.Do(req)
+ if !errors.Is(err, expectErr) {
+ t.Fatalf("error performing client request: %v", err)
+ } else if err != nil {
+ return // expected error
+ }
+ defer rsp.Body.Close()
+
+ // Read response body into memory
+ check, err := io.ReadAll(rsp.Body)
+ if err != nil {
+ t.Fatalf("error reading response body: %v", err)
+ }
+
+ // Check actual response body matches expected
+ if !bytes.Equal(expect, check) {
+ t.Errorf("response body did not match expected: expect=%q actual=%q", string(expect), string(check))
+ }
+}
+
+func TestHTTPClientPrivateIP(t *testing.T) {
+ client := httpclient.New(httpclient.Config{})
+
+ for _, addr := range privateIPs {
+ // Prepare request to private IP
+ req, _ := http.NewRequest("GET", addr, nil)
+
+ // Perform the HTTP request
+ _, err := client.Do(req)
+ if !errors.Is(err, httpclient.ErrReservedAddr) {
+ t.Errorf("dialing private address did not return expected error: %v", err)
+ }
+ }
+}
diff --git a/internal/httpclient/sanitizer.go b/internal/httpclient/sanitizer.go
new file mode 100644
index 000000000..6eef6898a
--- /dev/null
+++ b/internal/httpclient/sanitizer.go
@@ -0,0 +1,64 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package httpclient
+
+import (
+ "net/netip"
+ "syscall"
+
+ "github.com/superseriousbusiness/gotosocial/internal/netutil"
+)
+
+type sanitizer struct {
+ allow []netip.Prefix
+ block []netip.Prefix
+}
+
+// Sanitize implements the required net.Dialer.Control function signature.
+func (s *sanitizer) Sanitize(ntwrk, addr string, _ syscall.RawConn) error {
+ // Parse IP+port from addr
+ ipport, err := netip.ParseAddrPort(addr)
+ if err != nil {
+ return err
+ }
+
+ // Seperate the IP
+ ip := ipport.Addr()
+
+ // Check if this is explicitly allowed
+ for i := 0; i < len(s.allow); i++ {
+ if s.allow[i].Contains(ip) {
+ return nil
+ }
+ }
+
+ // Now check if explicity blocked
+ for i := 0; i < len(s.block); i++ {
+ if s.block[i].Contains(ip) {
+ return ErrReservedAddr
+ }
+ }
+
+ // Validate this is a safe IP
+ if !netutil.ValidateIP(ip) {
+ return ErrReservedAddr
+ }
+
+ return nil
+}
diff --git a/internal/media/manager.go b/internal/media/manager.go
index 174fca8e2..5b4a01021 100644
--- a/internal/media/manager.go
+++ b/internal/media/manager.go
@@ -27,9 +27,9 @@ import (
"github.com/robfig/cron/v3"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Manager provides an interface for managing media: parsing, storing, and retrieving media objects like photos, videos, and gifs.
@@ -79,8 +79,8 @@ type Manager interface {
type manager struct {
db db.DB
storage *kv.KVStore
- emojiWorker *worker.Worker[*ProcessingEmoji]
- mediaWorker *worker.Worker[*ProcessingMedia]
+ emojiWorker *concurrency.WorkerPool[*ProcessingEmoji]
+ mediaWorker *concurrency.WorkerPool[*ProcessingMedia]
stopCronJobs func() error
}
@@ -89,7 +89,7 @@ type manager struct {
// A worker pool will also be initialized for the manager, to ensure that only
// a limited number of media will be processed in parallel. The numbers of workers
// is determined from the $GOMAXPROCS environment variable (usually no. CPU cores).
-// See internal/worker.New() documentation for further information.
+// See internal/concurrency.NewWorkerPool() documentation for further information.
func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
m := &manager{
db: database,
@@ -97,7 +97,7 @@ func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
}
// Prepare the media worker pool
- m.mediaWorker = worker.New[*ProcessingMedia](-1, 10)
+ m.mediaWorker = concurrency.NewWorkerPool[*ProcessingMedia](-1, 10)
m.mediaWorker.SetProcessor(func(ctx context.Context, media *ProcessingMedia) error {
if err := ctx.Err(); err != nil {
return err
@@ -109,7 +109,7 @@ func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) {
})
// Prepare the emoji worker pool
- m.emojiWorker = worker.New[*ProcessingEmoji](-1, 10)
+ m.emojiWorker = concurrency.NewWorkerPool[*ProcessingEmoji](-1, 10)
m.emojiWorker.SetProcessor(func(ctx context.Context, emoji *ProcessingEmoji) error {
if err := ctx.Err(); err != nil {
return err
diff --git a/internal/netutil/validate.go b/internal/netutil/validate.go
new file mode 100644
index 000000000..27cc9ba4a
--- /dev/null
+++ b/internal/netutil/validate.go
@@ -0,0 +1,78 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package netutil
+
+import (
+ "net/netip"
+)
+
+var (
+ // IPv6GlobalUnicast is the global IPv6 unicast IP prefix.
+ IPv6GlobalUnicast = netip.MustParsePrefix("ff00::/8")
+
+ // IPvReserved contains IPv4 reserved IP prefixes.
+ IPv4Reserved = [...]netip.Prefix{
+ netip.MustParsePrefix("0.0.0.0/8"), // Current network
+ netip.MustParsePrefix("10.0.0.0/8"), // Private
+ netip.MustParsePrefix("100.64.0.0/10"), // RFC6598
+ netip.MustParsePrefix("127.0.0.0/8"), // Loopback
+ netip.MustParsePrefix("169.254.0.0/16"), // Link-local
+ netip.MustParsePrefix("172.16.0.0/12"), // Private
+ netip.MustParsePrefix("192.0.0.0/24"), // RFC6890
+ netip.MustParsePrefix("192.0.2.0/24"), // Test, doc, examples
+ netip.MustParsePrefix("192.88.99.0/24"), // IPv6 to IPv4 relay
+ netip.MustParsePrefix("192.168.0.0/16"), // Private
+ netip.MustParsePrefix("198.18.0.0/15"), // Benchmarking tests
+ netip.MustParsePrefix("198.51.100.0/24"), // Test, doc, examples
+ netip.MustParsePrefix("203.0.113.0/24"), // Test, doc, examples
+ netip.MustParsePrefix("224.0.0.0/4"), // Multicast
+ netip.MustParsePrefix("240.0.0.0/4"), // Reserved (includes broadcast / 255.255.255.255)
+ }
+)
+
+// ValidateAddr will parse a netip.AddrPort from string, and return the result of ValidateIP() on addr.
+func ValidateAddr(s string) bool {
+ ipport, err := netip.ParseAddrPort(s)
+ if err != nil {
+ return false
+ }
+ return ValidateIP(ipport.Addr())
+}
+
+// ValidateIP returns whether IP is an IPv4/6 address in non-reserved, public ranges.
+func ValidateIP(ip netip.Addr) bool {
+ switch {
+ // IPv4: check if IPv4 in reserved nets
+ case ip.Is4():
+ for _, reserved := range IPv4Reserved {
+ if reserved.Contains(ip) {
+ return false
+ }
+ }
+ return true
+
+ // IPv6: check if in global unicast (public internet)
+ case ip.Is6():
+ return IPv6GlobalUnicast.Contains(ip)
+
+ // Assume malicious by default
+ default:
+ return false
+ }
+}
diff --git a/internal/processing/account/account.go b/internal/processing/account/account.go
index c49df1a1a..7668da02c 100644
--- a/internal/processing/account/account.go
+++ b/internal/processing/account/account.go
@@ -23,6 +23,7 @@ import (
"mime/multipart"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
@@ -33,7 +34,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/text"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/oauth2/v4"
)
@@ -84,7 +84,7 @@ type Processor interface {
type processor struct {
tc typeutils.TypeConverter
mediaManager media.Manager
- clientWorker *worker.Worker[messages.FromClientAPI]
+ clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
oauthServer oauth.Server
filter visibility.Filter
formatter text.Formatter
@@ -94,7 +94,7 @@ type processor struct {
}
// New returns a new account processor.
-func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, oauthServer oauth.Server, clientWorker *worker.Worker[messages.FromClientAPI], federator federation.Federator, parseMention gtsmodel.ParseMentionFunc) Processor {
+func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, oauthServer oauth.Server, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], federator federation.Federator, parseMention gtsmodel.ParseMentionFunc) Processor {
return &processor{
tc: tc,
mediaManager: mediaManager,
diff --git a/internal/processing/account/account_test.go b/internal/processing/account/account_test.go
index 33b744250..d9ce68cc0 100644
--- a/internal/processing/account/account_test.go
+++ b/internal/processing/account/account_test.go
@@ -24,6 +24,7 @@ import (
"codeberg.org/gruf/go-store/kv"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/pub"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@@ -35,7 +36,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing/account"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -81,8 +81,8 @@ func (suite *AccountStandardTestSuite) SetupTest() {
testrig.InitTestLog()
testrig.InitTestConfig()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
clientWorker.SetProcessor(func(_ context.Context, msg messages.FromClientAPI) error {
suite.fromClientAPIChan <- msg
return nil
diff --git a/internal/processing/admin/admin.go b/internal/processing/admin/admin.go
index 4b466a2d7..6779f59b7 100644
--- a/internal/processing/admin/admin.go
+++ b/internal/processing/admin/admin.go
@@ -23,13 +23,13 @@ import (
"mime/multipart"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Processor wraps a bunch of functions for processing admin actions.
@@ -47,12 +47,12 @@ type Processor interface {
type processor struct {
tc typeutils.TypeConverter
mediaManager media.Manager
- clientWorker *worker.Worker[messages.FromClientAPI]
+ clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
db db.DB
}
// New returns a new admin processor.
-func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, clientWorker *worker.Worker[messages.FromClientAPI]) Processor {
+func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, clientWorker *concurrency.WorkerPool[messages.FromClientAPI]) Processor {
return &processor{
tc: tc,
mediaManager: mediaManager,
diff --git a/internal/processing/media/media_test.go b/internal/processing/media/media_test.go
index af67b36b1..1149f2646 100644
--- a/internal/processing/media/media_test.go
+++ b/internal/processing/media/media_test.go
@@ -26,6 +26,7 @@ import (
"codeberg.org/gruf/go-store/kv"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media"
@@ -33,7 +34,6 @@ import (
mediaprocessing "github.com/superseriousbusiness/gotosocial/internal/processing/media"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -122,7 +122,7 @@ func (suite *MediaStandardTestSuite) mockTransportController() transport.Control
return response, nil
}
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
mockClient := testrig.NewMockHTTPClient(do)
return testrig.NewTestTransportController(mockClient, suite.db, fedWorker)
}
diff --git a/internal/processing/processor.go b/internal/processing/processor.go
index 69f3100f9..d30f2f37e 100644
--- a/internal/processing/processor.go
+++ b/internal/processing/processor.go
@@ -25,6 +25,7 @@ import (
"codeberg.org/gruf/go-store/kv"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@@ -44,7 +45,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Processor should be passed to api modules (see internal/apimodule/...). It is used for
@@ -237,8 +237,8 @@ type Processor interface {
// processor just implements the Processor interface
type processor struct {
- clientWorker *worker.Worker[messages.FromClientAPI]
- fedWorker *worker.Worker[messages.FromFederator]
+ clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
+ fedWorker *concurrency.WorkerPool[messages.FromFederator]
federator federation.Federator
tc typeutils.TypeConverter
@@ -271,8 +271,8 @@ func NewProcessor(
storage *kv.KVStore,
db db.DB,
emailSender email.Sender,
- clientWorker *worker.Worker[messages.FromClientAPI],
- fedWorker *worker.Worker[messages.FromFederator],
+ clientWorker *concurrency.WorkerPool[messages.FromClientAPI],
+ fedWorker *concurrency.WorkerPool[messages.FromFederator],
) Processor {
parseMentionFunc := GetParseMentionFunc(db, federator)
diff --git a/internal/processing/processor_test.go b/internal/processing/processor_test.go
index 7e1972366..5946e6718 100644
--- a/internal/processing/processor_test.go
+++ b/internal/processing/processor_test.go
@@ -29,6 +29,7 @@ import (
"codeberg.org/gruf/go-store/kv"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/streams"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation"
@@ -40,7 +41,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -217,8 +217,8 @@ func (suite *ProcessingStandardTestSuite) SetupTest() {
}, nil
})
- clientWorker := worker.New[messages.FromClientAPI](-1, -1)
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.transportController = testrig.NewTestTransportController(httpClient, suite.db, fedWorker)
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
diff --git a/internal/processing/status/status.go b/internal/processing/status/status.go
index 207bffb30..e8b4a8268 100644
--- a/internal/processing/status/status.go
+++ b/internal/processing/status/status.go
@@ -22,6 +22,7 @@ import (
"context"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -29,7 +30,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/text"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
)
// Processor wraps a bunch of functions for processing statuses.
@@ -74,12 +74,12 @@ type processor struct {
db db.DB
filter visibility.Filter
formatter text.Formatter
- clientWorker *worker.Worker[messages.FromClientAPI]
+ clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
parseMention gtsmodel.ParseMentionFunc
}
// New returns a new status processor.
-func New(db db.DB, tc typeutils.TypeConverter, clientWorker *worker.Worker[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor {
+func New(db db.DB, tc typeutils.TypeConverter, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor {
return &processor{
tc: tc,
db: db,
diff --git a/internal/processing/status/status_test.go b/internal/processing/status/status_test.go
index d2126f03d..17c68c0b6 100644
--- a/internal/processing/status/status_test.go
+++ b/internal/processing/status/status_test.go
@@ -21,6 +21,7 @@ package status_test
import (
"codeberg.org/gruf/go-store/kv"
"github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -30,7 +31,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing/status"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
- "github.com/superseriousbusiness/gotosocial/internal/worker"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -42,7 +42,7 @@ type StatusStandardTestSuite struct {
storage *kv.KVStore
mediaManager media.Manager
federator federation.Federator
- clientWorker *worker.Worker[messages.FromClientAPI]
+ clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -75,11 +75,11 @@ func (suite *StatusStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
- fedWorker := worker.New[messages.FromFederator](-1, -1)
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.db = testrig.NewTestDB()
suite.typeConverter = testrig.NewTestTypeConverter(suite.db)
- suite.clientWorker = worker.New[messages.FromClientAPI](-1, -1)
+ suite.clientWorker = concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.tc = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker)
suite.storage = testrig.NewTestStorage()
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
diff --git a/internal/transport/controller.go b/internal/transport/controller.go
index 56a922a8b..280d4bc0b 100644
--- a/internal/transport/controller.go
+++ b/internal/transport/controller.go
@@ -20,13 +20,17 @@ package transport
import (
"context"
- "crypto"
+ "crypto/rsa"
+ "crypto/x509"
"encoding/json"
"fmt"
"net/url"
- "sync"
+ "runtime/debug"
+ "time"
- "github.com/go-fed/httpsig"
+ "codeberg.org/gruf/go-byteutil"
+ "codeberg.org/gruf/go-cache/v2"
+ "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams"
@@ -37,109 +41,85 @@ import (
// Controller generates transports for use in making federation requests to other servers.
type Controller interface {
- NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error)
+ // NewTransport returns an http signature transport with the given public key ID (URL location of pubkey), and the given private key.
+ NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error)
+
+ // NewTransportForUsername searches for account with username, and returns result of .NewTransport().
NewTransportForUsername(ctx context.Context, username string) (Transport, error)
}
type controller struct {
- db db.DB
- clock pub.Clock
- client pub.HttpClient
- appAgent string
-
- // dereferenceFollowersShortcut is a shortcut to dereference followers of an
- // account on this instance, without making any external api/http calls.
- //
- // It is passed to new transports, and should only be invoked when the iri.Host == this host.
- dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
-
- // dereferenceUserShortcut is a shortcut to dereference followers an account on
- // this instance, without making any external api/http calls.
- //
- // It is passed to new transports, and should only be invoked when the iri.Host == this host.
- dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
+ db db.DB
+ fedDB federatingdb.DB
+ clock pub.Clock
+ client pub.HttpClient
+ cache cache.Cache[string, *transport]
+ userAgent string
}
-func dereferenceFollowersShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) {
- return func(ctx context.Context, iri *url.URL) ([]byte, error) {
- followers, err := federatingDB.Followers(ctx, iri)
- if err != nil {
- return nil, err
- }
+// NewController returns an implementation of the Controller interface for creating new transports
+func NewController(db db.DB, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller {
+ applicationName := viper.GetString(config.Keys.ApplicationName)
+ host := viper.GetString(config.Keys.Host)
- i, err := streams.Serialize(followers)
- if err != nil {
- return nil, err
- }
+ // Determine build information
+ build, _ := debug.ReadBuildInfo()
- return json.Marshal(i)
+ c := &controller{
+ db: db,
+ fedDB: federatingDB,
+ clock: clock,
+ client: client,
+ cache: cache.New[string, *transport](),
+ userAgent: fmt.Sprintf("%s; %s (gofed/activity gotosocial-%s)", applicationName, host, build.Main.Version),
}
-}
-func dereferenceUserShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) {
- return func(ctx context.Context, iri *url.URL) ([]byte, error) {
- user, err := federatingDB.Get(ctx, iri)
- if err != nil {
- return nil, err
- }
-
- i, err := streams.Serialize(user)
- if err != nil {
- return nil, err
- }
-
- return json.Marshal(i)
+ // Transport cache has TTL=1hr freq=1m
+ c.cache.SetTTL(time.Hour, false)
+ if !c.cache.Start(time.Minute) {
+ logrus.Panic("failed to start transport controller cache")
}
-}
-// NewController returns an implementation of the Controller interface for creating new transports
-func NewController(db db.DB, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller {
- applicationName := viper.GetString(config.Keys.ApplicationName)
- host := viper.GetString(config.Keys.Host)
- appAgent := fmt.Sprintf("%s %s", applicationName, host)
-
- return &controller{
- db: db,
- clock: clock,
- client: client,
- appAgent: appAgent,
- dereferenceFollowersShortcut: dereferenceFollowersShortcut(federatingDB),
- dereferenceUserShortcut: dereferenceUserShortcut(federatingDB),
- }
+ return c
}
-// NewTransport returns a new http signature transport with the given public key id (a URL), and the given private key.
-func (c *controller) NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error) {
- prefs := []httpsig.Algorithm{httpsig.RSA_SHA256}
- digestAlgo := httpsig.DigestSha256
- getHeaders := []string{httpsig.RequestTarget, "host", "date"}
- postHeaders := []string{httpsig.RequestTarget, "host", "date", "digest"}
+func (c *controller) NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error) {
+ // Generate public key string for cache key
+ //
+ // NOTE: it is safe to use the public key as the cache
+ // key here as we are generating it ourselves from the
+ // private key. If we were simply using a public key
+ // provided as argument that would absolutely NOT be safe.
+ pubStr := privkeyToPublicStr(privkey)
+
+ // First check for cached transport
+ transp, ok := c.cache.Get(pubStr)
+ if ok {
+ return transp, nil
+ }
- getSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, 120)
- if err != nil {
- return nil, fmt.Errorf("error creating get signer: %s", err)
+ // Create the transport
+ transp = &transport{
+ controller: c,
+ pubKeyID: pubKeyID,
+ privkey: privkey,
}
- postSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, 120)
- if err != nil {
- return nil, fmt.Errorf("error creating post signer: %s", err)
+ // Cache this transport under pubkey
+ if !c.cache.Put(pubStr, transp) {
+ var cached *transport
+
+ cached, ok = c.cache.Get(pubStr)
+ if !ok {
+ // Some ridiculous race cond.
+ c.cache.Set(pubStr, transp)
+ } else {
+ // Use already cached
+ transp = cached
+ }
}
- sigTransport := pub.NewHttpSigTransport(c.client, c.appAgent, c.clock, getSigner, postSigner, pubKeyID, privkey)
-
- return &transport{
- client: c.client,
- appAgent: c.appAgent,
- gofedAgent: "(go-fed/activity v1.0.0)",
- clock: c.clock,
- pubKeyID: pubKeyID,
- privkey: privkey,
- sigTransport: sigTransport,
- getSigner: getSigner,
- getSignerMu: &sync.Mutex{},
- dereferenceFollowersShortcut: c.dereferenceFollowersShortcut,
- dereferenceUserShortcut: c.dereferenceUserShortcut,
- }, nil
+ return transp, nil
}
func (c *controller) NewTransportForUsername(ctx context.Context, username string) (Transport, error) {
@@ -164,3 +144,45 @@ func (c *controller) NewTransportForUsername(ctx context.Context, username strin
}
return transport, nil
}
+
+// dereferenceLocalFollowers is a shortcut to dereference followers of an
+// account on this instance, without making any external api/http calls.
+//
+// It is passed to new transports, and should only be invoked when the iri.Host == this host.
+func (c *controller) dereferenceLocalFollowers(ctx context.Context, iri *url.URL) ([]byte, error) {
+ followers, err := c.fedDB.Followers(ctx, iri)
+ if err != nil {
+ return nil, err
+ }
+
+ i, err := streams.Serialize(followers)
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(i)
+}
+
+// dereferenceLocalUser is a shortcut to dereference followers an account on
+// this instance, without making any external api/http calls.
+//
+// It is passed to new transports, and should only be invoked when the iri.Host == this host.
+func (c *controller) dereferenceLocalUser(ctx context.Context, iri *url.URL) ([]byte, error) {
+ user, err := c.fedDB.Get(ctx, iri)
+ if err != nil {
+ return nil, err
+ }
+
+ i, err := streams.Serialize(user)
+ if err != nil {
+ return nil, err
+ }
+
+ return json.Marshal(i)
+}
+
+// privkeyToPublicStr will create a string representation of RSA public key from private.
+func privkeyToPublicStr(privkey *rsa.PrivateKey) string {
+ b := x509.MarshalPKCS1PublicKey(&privkey.PublicKey)
+ return byteutil.B2S(b)
+}
diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go
index fe17f7761..bacaa9b3a 100644
--- a/internal/transport/deliver.go
+++ b/internal/transport/deliver.go
@@ -19,13 +19,14 @@
package transport
import (
+ "bytes"
"context"
"fmt"
+ "net/http"
"net/url"
"strings"
"sync"
- "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/config"
)
@@ -72,6 +73,28 @@ func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error {
return nil
}
- logrus.Debugf("Deliver: posting as %s to %s", t.pubKeyID, to.String())
- return t.sigTransport.Deliver(ctx, b, to)
+ urlStr := to.String()
+
+ req, err := http.NewRequestWithContext(ctx, "POST", urlStr, bytes.NewReader(b))
+ if err != nil {
+ return err
+ }
+
+ req.Header.Add("Content-Type", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"")
+ req.Header.Add("Accept-Charset", "utf-8")
+ req.Header.Add("User-Agent", t.controller.userAgent)
+ req.Header.Set("Host", to.Host)
+
+ resp, err := t.POST(req, b)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ if code := resp.StatusCode; code != http.StatusOK &&
+ code != http.StatusCreated && code != http.StatusAccepted {
+ return fmt.Errorf("POST request to %s failed (%d): %s", urlStr, resp.StatusCode, resp.Status)
+ }
+
+ return nil
}
diff --git a/internal/transport/dereference.go b/internal/transport/dereference.go
index 61d99c5c5..36157b673 100644
--- a/internal/transport/dereference.go
+++ b/internal/transport/dereference.go
@@ -20,32 +20,55 @@ package transport
import (
"context"
+ "fmt"
+ "io/ioutil"
+ "net/http"
"net/url"
- "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/uris"
)
func (t *transport) Dereference(ctx context.Context, iri *url.URL) ([]byte, error) {
- l := logrus.WithField("func", "Dereference")
-
// if the request is to us, we can shortcut for certain URIs rather than going through
// the normal request flow, thereby saving time and energy
if iri.Host == viper.GetString(config.Keys.Host) {
if uris.IsFollowersPath(iri) {
// the request is for followers of one of our accounts, which we can shortcut
- return t.dereferenceFollowersShortcut(ctx, iri)
+ return t.controller.dereferenceLocalFollowers(ctx, iri)
}
if uris.IsUserPath(iri) {
// the request is for one of our accounts, which we can shortcut
- return t.dereferenceUserShortcut(ctx, iri)
+ return t.controller.dereferenceLocalUser(ctx, iri)
}
}
- // the request is either for a remote host or for us but we don't have a shortcut, so continue as normal
- l.Debugf("performing GET to %s", iri.String())
- return t.sigTransport.Dereference(ctx, iri)
+ // Build IRI just once
+ iriStr := iri.String()
+
+ // Prepare new HTTP request to endpoint
+ req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Add("Accept", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"")
+ req.Header.Add("Accept-Charset", "utf-8")
+ req.Header.Add("User-Agent", t.controller.userAgent)
+ req.Header.Set("Host", iri.Host)
+
+ // Perform the HTTP request
+ rsp, err := t.GET(req)
+ if err != nil {
+ return nil, err
+ }
+ defer rsp.Body.Close()
+
+ // Check for an expected status code
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status)
+ }
+
+ return ioutil.ReadAll(rsp.Body)
}
diff --git a/internal/transport/derefinstance.go b/internal/transport/derefinstance.go
index c64dced0f..1acbcc364 100644
--- a/internal/transport/derefinstance.go
+++ b/internal/transport/derefinstance.go
@@ -80,43 +80,38 @@ func (t *transport) DereferenceInstance(ctx context.Context, iri *url.URL) (*gts
}
func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL) (*gtsmodel.Instance, error) {
- l := logrus.WithField("func", "dereferenceByAPIV1Instance")
-
cleanIRI := &url.URL{
Scheme: iri.Scheme,
Host: iri.Host,
Path: "api/v1/instance",
}
- l.Debugf("performing GET to %s", cleanIRI.String())
- req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
+ // Build IRI just once
+ iriStr := cleanIRI.String()
+
+ req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, err
}
+
req.Header.Add("Accept", "application/json")
- req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
- req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
+ req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", cleanIRI.Host)
- t.getSignerMu.Lock()
- err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
- t.getSignerMu.Unlock()
- if err != nil {
- return nil, err
- }
- resp, err := t.client.Do(req)
+
+ resp, err := t.GET(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
+
if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status)
+ return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
}
+
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
- }
-
- if len(b) == 0 {
+ } else if len(b) == 0 {
return nil, errors.New("response bytes was len 0")
}
@@ -237,44 +232,37 @@ func dereferenceByNodeInfo(c context.Context, t *transport, iri *url.URL) (*gtsm
}
func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*url.URL, error) {
- l := logrus.WithField("func", "callNodeInfoWellKnown")
-
cleanIRI := &url.URL{
Scheme: iri.Scheme,
Host: iri.Host,
Path: ".well-known/nodeinfo",
}
- l.Debugf("performing GET to %s", cleanIRI.String())
- req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil)
+ // Build IRI just once
+ iriStr := cleanIRI.String()
+
+ req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, err
}
-
req.Header.Add("Accept", "application/json")
- req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
- req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
+ req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", cleanIRI.Host)
- t.getSignerMu.Lock()
- err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
- t.getSignerMu.Unlock()
- if err != nil {
- return nil, err
- }
- resp, err := t.client.Do(req)
+
+ resp, err := t.GET(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
+
if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status)
+ return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
}
+
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
- }
-
- if len(b) == 0 {
+ } else if len(b) == 0 {
return nil, errors.New("callNodeInfoWellKnown: response bytes was len 0")
}
@@ -302,38 +290,31 @@ func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*ur
}
func callNodeInfo(ctx context.Context, t *transport, iri *url.URL) (*apimodel.Nodeinfo, error) {
- l := logrus.WithField("func", "callNodeInfo")
+ // Build IRI just once
+ iriStr := iri.String()
- l.Debugf("performing GET to %s", iri.String())
- req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, err
}
-
req.Header.Add("Accept", "application/json")
- req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
- req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
+ req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", iri.Host)
- t.getSignerMu.Lock()
- err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
- t.getSignerMu.Unlock()
- if err != nil {
- return nil, err
- }
- resp, err := t.client.Do(req)
+
+ resp, err := t.GET(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
+
if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
+ return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status)
}
+
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
- }
-
- if len(b) == 0 {
+ } else if len(b) == 0 {
return nil, errors.New("callNodeInfo: response bytes was len 0")
}
diff --git a/internal/transport/derefmedia.go b/internal/transport/derefmedia.go
index e3c86ce1e..8feb7ed20 100644
--- a/internal/transport/derefmedia.go
+++ b/internal/transport/derefmedia.go
@@ -24,34 +24,31 @@ import (
"io"
"net/http"
"net/url"
-
- "github.com/sirupsen/logrus"
)
func (t *transport) DereferenceMedia(ctx context.Context, iri *url.URL) (io.ReadCloser, int, error) {
- l := logrus.WithField("func", "DereferenceMedia")
- l.Debugf("performing GET to %s", iri.String())
- req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
+ // Build IRI just once
+ iriStr := iri.String()
+
+ // Prepare HTTP request to this media's IRI
+ req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil)
if err != nil {
return nil, 0, err
}
-
req.Header.Add("Accept", "*/*") // we don't know what kind of media we're going to get here
- req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
- req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
+ req.Header.Add("User-Agent", t.controller.userAgent)
req.Header.Set("Host", iri.Host)
- t.getSignerMu.Lock()
- err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
- t.getSignerMu.Unlock()
- if err != nil {
- return nil, 0, err
- }
- resp, err := t.client.Do(req)
+
+ // Perform the HTTP request
+ rsp, err := t.GET(req)
if err != nil {
return nil, 0, err
}
- if resp.StatusCode != http.StatusOK {
- return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
+
+ // Check for an expected status code
+ if rsp.StatusCode != http.StatusOK {
+ return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status)
}
- return resp.Body, int(resp.ContentLength), nil
+
+ return rsp.Body, int(rsp.ContentLength), nil
}
diff --git a/internal/transport/finger.go b/internal/transport/finger.go
index a71bbb51e..7554a242f 100644
--- a/internal/transport/finger.go
+++ b/internal/transport/finger.go
@@ -23,46 +23,36 @@ import (
"fmt"
"io/ioutil"
"net/http"
- "net/url"
-
- "github.com/sirupsen/logrus"
)
func (t *transport) Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) {
- l := logrus.WithField("func", "Finger")
- urlString := fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s@%s", targetDomain, targetUsername, targetDomain)
- l.Debugf("performing GET to %s", urlString)
-
- iri, err := url.Parse(urlString)
- if err != nil {
- return nil, fmt.Errorf("Finger: error parsing url %s: %s", urlString, err)
- }
-
- l.Debugf("performing GET to %s", iri.String())
-
- req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil)
+ // Prepare URL string
+ urlStr := "https://" +
+ targetDomain +
+ "/.well-known/webfinger?resource=acct:" +
+ targetUsername + "@" + targetDomain
+
+ // Generate new GET request from URL string
+ req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
if err != nil {
return nil, err
}
-
req.Header.Add("Accept", "application/json")
req.Header.Add("Accept", "application/jrd+json")
- req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
- req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent))
- req.Header.Set("Host", iri.Host)
- t.getSignerMu.Lock()
- err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil)
- t.getSignerMu.Unlock()
- if err != nil {
- return nil, err
- }
- resp, err := t.client.Do(req)
+ req.Header.Add("User-Agent", t.controller.userAgent)
+ req.Header.Set("Host", req.URL.Host)
+
+ // Perform the HTTP request
+ rsp, err := t.GET(req)
if err != nil {
return nil, err
}
- defer resp.Body.Close()
- if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status)
+ defer rsp.Body.Close()
+
+ // Check for an expected status code
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("GET request to %s failed (%d): %s", urlStr, rsp.StatusCode, rsp.Status)
}
- return ioutil.ReadAll(resp.Body)
+
+ return ioutil.ReadAll(rsp.Body)
}
diff --git a/internal/transport/signing.go b/internal/transport/signing.go
new file mode 100644
index 000000000..39896a2a8
--- /dev/null
+++ b/internal/transport/signing.go
@@ -0,0 +1,43 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package transport
+
+import (
+ "github.com/go-fed/httpsig"
+)
+
+var (
+ // http signer preferences
+ prefs = []httpsig.Algorithm{httpsig.RSA_SHA256}
+ digestAlgo = httpsig.DigestSha256
+ getHeaders = []string{httpsig.RequestTarget, "host", "date"}
+ postHeaders = []string{httpsig.RequestTarget, "host", "date", "digest"}
+)
+
+// NewGETSigner returns a new httpsig.Signer instance initialized with GTS GET preferences.
+func NewGETSigner(expiresIn int64) (httpsig.Signer, error) {
+ sig, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, expiresIn)
+ return sig, err
+}
+
+// NewPOSTSigner returns a new httpsig.Signer instance initialized with GTS POST preferences.
+func NewPOSTSigner(expiresIn int64) (httpsig.Signer, error) {
+ sig, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, expiresIn)
+ return sig, err
+}
diff --git a/internal/transport/transport.go b/internal/transport/transport.go
index 40c11ca17..c52686c43 100644
--- a/internal/transport/transport.go
+++ b/internal/transport/transport.go
@@ -21,11 +21,18 @@ package transport
import (
"context"
"crypto"
+ "crypto/x509"
+ "errors"
"io"
+ "net/http"
"net/url"
+ "strings"
"sync"
+ "time"
+ errorsv2 "codeberg.org/gruf/go-errors/v2"
"github.com/go-fed/httpsig"
+ "github.com/sirupsen/logrus"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@@ -43,28 +50,148 @@ type Transport interface {
DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error)
// Finger performs a webfinger request with the given username and domain, and returns the bytes from the response body.
Finger(ctx context.Context, targetUsername string, targetDomains string) ([]byte, error)
- // SigTransport returns the underlying http signature transport wrapped by the GoToSocial transport.
- SigTransport() pub.Transport
}
// transport implements the Transport interface
type transport struct {
- client pub.HttpClient
- appAgent string
- gofedAgent string
- clock pub.Clock
- pubKeyID string
- privkey crypto.PrivateKey
- sigTransport *pub.HttpSigTransport
- getSigner httpsig.Signer
- getSignerMu *sync.Mutex
-
- // shortcuts for dereferencing things that exist on our instance without making an http call to ourself
-
- dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
- dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error)
+ controller *controller
+ pubKeyID string
+ privkey crypto.PrivateKey
+
+ signerExp time.Time
+ getSigner httpsig.Signer
+ postSigner httpsig.Signer
+ signerMu sync.Mutex
+}
+
+// GET will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn.
+func (t *transport) GET(r *http.Request, retryOn ...int) (*http.Response, error) {
+ if r.Method != http.MethodGet {
+ return nil, errors.New("must be GET request")
+ }
+ return t.do(r, func(r *http.Request) error {
+ return t.signGET(r)
+ }, retryOn...)
+}
+
+// POST will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn.
+func (t *transport) POST(r *http.Request, body []byte, retryOn ...int) (*http.Response, error) {
+ if r.Method != http.MethodPost {
+ return nil, errors.New("must be POST request")
+ }
+ return t.do(r, func(r *http.Request) error {
+ return t.signPOST(r, body)
+ }, retryOn...)
+}
+
+func (t *transport) do(r *http.Request, signer func(*http.Request) error, retryOn ...int) (*http.Response, error) {
+ const maxRetries = 5
+ backoff := time.Second * 2
+
+ // Start a log entry for this request
+ l := logrus.WithFields(logrus.Fields{
+ "pubKeyID": t.pubKeyID,
+ "method": r.Method,
+ "url": r.URL.String(),
+ })
+
+ for i := 0; i < maxRetries; i++ {
+ // Reset signing header fields
+ now := t.controller.clock.Now().UTC()
+ r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
+ r.Header.Del("Signature")
+ r.Header.Del("Digest")
+
+ // Perform request signing
+ if err := signer(r); err != nil {
+ return nil, err
+ }
+
+ l.Infof("performing request")
+
+ // Attempt to perform request
+ rsp, err := t.controller.client.Do(r)
+ if err == nil { //nolint shutup linter
+ // TooManyRequest means we need to slow
+ // down and retry our request. Codes over
+ // 500 generally indicate temp. outages.
+ if code := rsp.StatusCode; code < 500 &&
+ code != http.StatusTooManyRequests &&
+ !containsInt(retryOn, rsp.StatusCode) {
+ return rsp, nil
+ }
+
+ // Generate error from status code for logging
+ err = errors.New(`http response "` + rsp.Status + `"`)
+ } else if errorsv2.Is(err, context.DeadlineExceeded, context.Canceled) {
+ // Return early if context has cancelled
+ return nil, err
+ } else if strings.Contains(err.Error(), "stopped after 10 redirects") {
+ // Don't bother if net/http returned after too many redirects
+ return nil, err
+ } else if errors.As(err, &x509.UnknownAuthorityError{}) {
+ // Unknown authority errors we do NOT recover from
+ return nil, err
+ }
+
+ l.Errorf("backing off for %s after http request error: %v", backoff.String(), err)
+
+ select {
+ // Request ctx cancelled
+ case <-r.Context().Done():
+ return nil, r.Context().Err()
+
+ // Backoff for some time
+ case <-time.After(backoff):
+ backoff *= 2
+ }
+ }
+
+ return nil, errors.New("transport reached max retries")
+}
+
+// signGET will safely sign an HTTP GET request.
+func (t *transport) signGET(r *http.Request) (err error) {
+ t.safesign(func() {
+ err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil)
+ })
+ return
+}
+
+// signPOST will safely sign an HTTP POST request for given body.
+func (t *transport) signPOST(r *http.Request, body []byte) (err error) {
+ t.safesign(func() {
+ err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body)
+ })
+ return
+}
+
+// safesign will perform sign function within mutex protection,
+// and ensured that httpsig.Signers are up-to-date.
+func (t *transport) safesign(sign func()) {
+ // Perform within mu safety
+ t.signerMu.Lock()
+ defer t.signerMu.Unlock()
+
+ if now := time.Now(); now.After(t.signerExp) {
+ const expiry = 120
+
+ // Signers have expired and require renewal
+ t.getSigner, _ = NewGETSigner(expiry)
+ t.postSigner, _ = NewPOSTSigner(expiry)
+ t.signerExp = now.Add(time.Second * expiry)
+ }
+
+ // Perform signing
+ sign()
}
-func (t *transport) SigTransport() pub.Transport {
- return t.sigTransport
+// containsInt checks if slice contains check.
+func containsInt(slice []int, check int) bool {
+ for _, i := range slice {
+ if i == check {
+ return true
+ }
+ }
+ return false
}