diff options
author | 2022-05-15 10:16:43 +0100 | |
---|---|---|
committer | 2022-05-15 11:16:43 +0200 | |
commit | 223025fc27ef636206027b360201877848d426a4 (patch) | |
tree | d2f5f293caabdd82fbb87fed3730eb8f6f2e1c1f /internal | |
parent | [chore] Update LE server to use copy of main http.Server{} to maintain server... (diff) | |
download | gotosocial-223025fc27ef636206027b360201877848d426a4.tar.xz |
[security] transport.Controller{} and transport.Transport{} security and performance improvements (#564)
* cache transports in controller by privkey-generated pubkey, add retry logic to transport requests
Signed-off-by: kim <grufwub@gmail.com>
* update code comments, defer mutex unlocks
Signed-off-by: kim <grufwub@gmail.com>
* add count to 'performing request' log message
Signed-off-by: kim <grufwub@gmail.com>
* reduce repeated conversions of same url.URL object
Signed-off-by: kim <grufwub@gmail.com>
* move worker.Worker to concurrency subpackage, add WorkQueue type, limit transport http client use by WorkQueue
Signed-off-by: kim <grufwub@gmail.com>
* fix security advisories regarding max outgoing conns, max rsp body size
- implemented by a new httpclient.Client{} that wraps an underlying
client with a queue to limit connections, and limit reader wrapping
a response body with a configured maximum size
- update pub.HttpClient args passed around to be this new httpclient.Client{}
Signed-off-by: kim <grufwub@gmail.com>
* add httpclient tests, move ip validation to separate package + change mechanism
Signed-off-by: kim <grufwub@gmail.com>
* fix merge conflicts
Signed-off-by: kim <grufwub@gmail.com>
* use singular mutex in transport rather than separate signer mus
Signed-off-by: kim <grufwub@gmail.com>
* improved useragent string
Signed-off-by: kim <grufwub@gmail.com>
* add note regarding missing test
Signed-off-by: kim <grufwub@gmail.com>
* remove useragent field from transport (instead store in controller)
Signed-off-by: kim <grufwub@gmail.com>
* shutup linter
Signed-off-by: kim <grufwub@gmail.com>
* reset other signing headers on each loop iteration
Signed-off-by: kim <grufwub@gmail.com>
* respect request ctx during retry-backoff sleep period
Signed-off-by: kim <grufwub@gmail.com>
* use external pkg with docs explaining performance "hack"
Signed-off-by: kim <grufwub@gmail.com>
* use http package constants instead of string method literals
Signed-off-by: kim <grufwub@gmail.com>
* add license file headers
Signed-off-by: kim <grufwub@gmail.com>
* update code comment to match new func names
Signed-off-by: kim <grufwub@gmail.com>
* updates to user-agent string
Signed-off-by: kim <grufwub@gmail.com>
* update signed testrig models to fit with new transport logic (instead uses separate signer now)
Signed-off-by: kim <grufwub@gmail.com>
* fuck you linter
Signed-off-by: kim <grufwub@gmail.com>
Diffstat (limited to 'internal')
43 files changed, 1041 insertions, 340 deletions
diff --git a/internal/api/client/account/account_test.go b/internal/api/client/account/account_test.go index d65b49550..d6bb5a5c0 100644 --- a/internal/api/client/account/account_test.go +++ b/internal/api/client/account/account_test.go @@ -11,6 +11,7 @@ import ( "github.com/spf13/viper" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/client/account" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" @@ -20,7 +21,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -62,8 +62,8 @@ func (suite *AccountStandardTestSuite) SetupTest() { testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := worker.New[messages.FromFederator](-1, -1) - clientWorker := worker.New[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.db = testrig.NewTestDB() suite.storage = testrig.NewTestStorage() diff --git a/internal/api/client/admin/admin_test.go b/internal/api/client/admin/admin_test.go index 578ab167c..11e2f8354 100644 --- a/internal/api/client/admin/admin_test.go +++ b/internal/api/client/admin/admin_test.go @@ -29,6 +29,7 @@ import ( "github.com/spf13/viper" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/client/admin" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" @@ -38,7 +39,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -80,8 +80,8 @@ func (suite *AdminStandardTestSuite) SetupTest() { testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := worker.New[messages.FromFederator](-1, -1) - clientWorker := worker.New[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.db = testrig.NewTestDB() suite.storage = testrig.NewTestStorage() diff --git a/internal/api/client/fileserver/servefile_test.go b/internal/api/client/fileserver/servefile_test.go index 49d813981..d7de2f4f9 100644 --- a/internal/api/client/fileserver/servefile_test.go +++ b/internal/api/client/fileserver/servefile_test.go @@ -31,6 +31,7 @@ import ( "github.com/sirupsen/logrus" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/client/fileserver" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" @@ -40,7 +41,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/typeutils" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -77,8 +77,8 @@ func (suite *ServeFileTestSuite) SetupSuite() { testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := worker.New[messages.FromFederator](-1, -1) - clientWorker := worker.New[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.db = testrig.NewTestDB() suite.storage = testrig.NewTestStorage() diff --git a/internal/api/client/followrequest/followrequest_test.go b/internal/api/client/followrequest/followrequest_test.go index 072025931..14b5656b6 100644 --- a/internal/api/client/followrequest/followrequest_test.go +++ b/internal/api/client/followrequest/followrequest_test.go @@ -28,6 +28,7 @@ import ( "github.com/spf13/viper" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/client/followrequest" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" @@ -37,7 +38,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -77,8 +77,8 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() { testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := worker.New[messages.FromFederator](-1, -1) - clientWorker := worker.New[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.db = testrig.NewTestDB() suite.storage = testrig.NewTestStorage() diff --git a/internal/api/client/media/mediacreate_test.go b/internal/api/client/media/mediacreate_test.go index 4d08697ef..e16b9f5eb 100644 --- a/internal/api/client/media/mediacreate_test.go +++ b/internal/api/client/media/mediacreate_test.go @@ -37,6 +37,7 @@ import ( "github.com/stretchr/testify/suite" mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media" "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" @@ -47,7 +48,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/typeutils" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -84,8 +84,8 @@ func (suite *MediaCreateTestSuite) SetupSuite() { testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := worker.New[messages.FromFederator](-1, -1) - clientWorker := worker.New[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.db = testrig.NewTestDB() suite.storage = testrig.NewTestStorage() diff --git a/internal/api/client/media/mediaupdate_test.go b/internal/api/client/media/mediaupdate_test.go index b87e6ec8d..a87718438 100644 --- a/internal/api/client/media/mediaupdate_test.go +++ b/internal/api/client/media/mediaupdate_test.go @@ -35,6 +35,7 @@ import ( "github.com/stretchr/testify/suite" mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media" "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" @@ -45,7 +46,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/typeutils" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -82,8 +82,8 @@ func (suite *MediaUpdateTestSuite) SetupSuite() { testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := worker.New[messages.FromFederator](-1, -1) - clientWorker := worker.New[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.db = testrig.NewTestDB() suite.storage = testrig.NewTestStorage() diff --git a/internal/api/client/status/status_test.go b/internal/api/client/status/status_test.go index a4a56aa0b..e2e2819b5 100644 --- a/internal/api/client/status/status_test.go +++ b/internal/api/client/status/status_test.go @@ -32,6 +32,7 @@ import ( "github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/gotosocial/internal/api/client/status" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" @@ -40,7 +41,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/typeutils" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -90,8 +90,8 @@ func (suite *StatusStandardTestSuite) SetupTest() { testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") - fedWorker := worker.New[messages.FromFederator](-1, -1) - clientWorker := worker.New[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(suite.testHttpClient(), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) diff --git a/internal/api/client/user/user_test.go b/internal/api/client/user/user_test.go index b0fd2b2e9..6e9c46525 100644 --- a/internal/api/client/user/user_test.go +++ b/internal/api/client/user/user_test.go @@ -22,6 +22,7 @@ import ( "codeberg.org/gruf/go-store/kv" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/client/user" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" @@ -30,7 +31,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/typeutils" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -58,8 +58,8 @@ type UserStandardTestSuite struct { func (suite *UserStandardTestSuite) SetupTest() { testrig.InitTestLog() testrig.InitTestConfig() - fedWorker := worker.New[messages.FromFederator](-1, -1) - clientWorker := worker.New[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.testTokens = testrig.NewTestTokens() suite.testClients = testrig.NewTestClients() suite.testApplications = testrig.NewTestApplications() diff --git a/internal/api/s2s/user/inboxpost_test.go b/internal/api/s2s/user/inboxpost_test.go index 6f2909430..388a9fbbb 100644 --- a/internal/api/s2s/user/inboxpost_test.go +++ b/internal/api/s2s/user/inboxpost_test.go @@ -33,11 +33,11 @@ import ( "github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/gotosocial/internal/api/s2s/user" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/messages" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -85,8 +85,8 @@ func (suite *InboxPostTestSuite) TestPostBlock() { suite.NoError(err) body := bytes.NewReader(bodyJson) - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) @@ -188,8 +188,8 @@ func (suite *InboxPostTestSuite) TestPostUnblock() { suite.NoError(err) body := bytes.NewReader(bodyJson) - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) @@ -281,8 +281,8 @@ func (suite *InboxPostTestSuite) TestPostUpdate() { suite.NoError(err) body := bytes.NewReader(bodyJson) - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) @@ -403,8 +403,8 @@ func (suite *InboxPostTestSuite) TestPostDelete() { suite.NoError(err) body := bytes.NewReader(bodyJson) - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) diff --git a/internal/api/s2s/user/outboxget_test.go b/internal/api/s2s/user/outboxget_test.go index ea9259b0f..79122731f 100644 --- a/internal/api/s2s/user/outboxget_test.go +++ b/internal/api/s2s/user/outboxget_test.go @@ -31,8 +31,8 @@ import ( "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/gotosocial/internal/api/s2s/user" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/messages" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -46,8 +46,8 @@ func (suite *OutboxGetTestSuite) TestGetOutbox() { signedRequest := derefRequests["foss_satan_dereference_zork_outbox"] targetAccount := suite.testAccounts["local_account_1"] - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) @@ -104,8 +104,8 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() { signedRequest := derefRequests["foss_satan_dereference_zork_outbox_first"] targetAccount := suite.testAccounts["local_account_1"] - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) @@ -162,8 +162,8 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() { signedRequest := derefRequests["foss_satan_dereference_zork_outbox_next"] targetAccount := suite.testAccounts["local_account_1"] - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) diff --git a/internal/api/s2s/user/repliesget_test.go b/internal/api/s2s/user/repliesget_test.go index 4b8364318..845c07bdb 100644 --- a/internal/api/s2s/user/repliesget_test.go +++ b/internal/api/s2s/user/repliesget_test.go @@ -33,8 +33,8 @@ import ( "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/gotosocial/internal/api/s2s/user" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/messages" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -49,8 +49,8 @@ func (suite *RepliesGetTestSuite) TestGetReplies() { targetAccount := suite.testAccounts["local_account_1"] targetStatus := suite.testStatuses["local_account_1_status_1"] - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) @@ -113,8 +113,8 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() { targetAccount := suite.testAccounts["local_account_1"] targetStatus := suite.testStatuses["local_account_1_status_1"] - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) @@ -180,8 +180,8 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() { targetAccount := suite.testAccounts["local_account_1"] targetStatus := suite.testStatuses["local_account_1_status_1"] - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) diff --git a/internal/api/s2s/user/statusget_test.go b/internal/api/s2s/user/statusget_test.go index c28e4e567..6696bd7e9 100644 --- a/internal/api/s2s/user/statusget_test.go +++ b/internal/api/s2s/user/statusget_test.go @@ -32,8 +32,8 @@ import ( "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/gotosocial/internal/api/s2s/user" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/messages" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -48,8 +48,8 @@ func (suite *StatusGetTestSuite) TestGetStatus() { targetAccount := suite.testAccounts["local_account_1"] targetStatus := suite.testStatuses["local_account_1_status_1"] - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) @@ -116,8 +116,8 @@ func (suite *StatusGetTestSuite) TestGetStatusLowercase() { targetAccount := suite.testAccounts["local_account_1"] targetStatus := suite.testStatuses["local_account_1_status_1"] - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) diff --git a/internal/api/s2s/user/user_test.go b/internal/api/s2s/user/user_test.go index 1ed960544..e8d305d06 100644 --- a/internal/api/s2s/user/user_test.go +++ b/internal/api/s2s/user/user_test.go @@ -23,6 +23,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/s2s/user" "github.com/superseriousbusiness/gotosocial/internal/api/security" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" @@ -32,7 +33,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/typeutils" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -78,8 +78,8 @@ func (suite *UserStandardTestSuite) SetupTest() { testrig.InitTestConfig() testrig.InitTestLog() - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.db = testrig.NewTestDB() suite.tc = testrig.NewTestTypeConverter(suite.db) diff --git a/internal/api/s2s/user/userget_test.go b/internal/api/s2s/user/userget_test.go index 5c9e4f0d8..5ac2197ff 100644 --- a/internal/api/s2s/user/userget_test.go +++ b/internal/api/s2s/user/userget_test.go @@ -33,9 +33,9 @@ import ( "github.com/superseriousbusiness/activity/streams/vocab" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/api/s2s/user" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -49,8 +49,8 @@ func (suite *UserGetTestSuite) TestGetUser() { signedRequest := derefRequests["foss_satan_dereference_zork"] targetAccount := suite.testAccounts["local_account_1"] - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) @@ -130,8 +130,8 @@ func (suite *UserGetTestSuite) TestGetUserPublicKeyDeleted() { derefRequests := testrig.NewTestDereferenceRequests(suite.testAccounts) signedRequest := derefRequests["foss_satan_dereference_zork_public_key"] - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker) federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) diff --git a/internal/api/s2s/webfinger/webfinger_test.go b/internal/api/s2s/webfinger/webfinger_test.go index 1f597d3f9..0df50c503 100644 --- a/internal/api/s2s/webfinger/webfinger_test.go +++ b/internal/api/s2s/webfinger/webfinger_test.go @@ -28,6 +28,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger" "github.com/superseriousbusiness/gotosocial/internal/api/security" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" @@ -37,7 +38,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/typeutils" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -81,8 +81,8 @@ func (suite *WebfingerStandardTestSuite) SetupTest() { testrig.InitTestLog() testrig.InitTestConfig() - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.db = testrig.NewTestDB() suite.tc = testrig.NewTestTypeConverter(suite.db) diff --git a/internal/api/s2s/webfinger/webfingerget_test.go b/internal/api/s2s/webfinger/webfingerget_test.go index 55de30f34..7871b6a3f 100644 --- a/internal/api/s2s/webfinger/webfingerget_test.go +++ b/internal/api/s2s/webfinger/webfingerget_test.go @@ -31,10 +31,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/processing" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -71,8 +71,8 @@ func (suite *WebfingerGetTestSuite) TestFingerUser() { func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHost() { viper.Set(config.Keys.Host, "gts.example.org") viper.Set(config.Keys.AccountDomain, "example.org") - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker) suite.webfingerModule = webfinger.New(suite.processor).(*webfinger.Module) @@ -107,8 +107,8 @@ func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHo func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByAccountDomain() { viper.Set(config.Keys.Host, "gts.example.org") viper.Set(config.Keys.AccountDomain, "example.org") - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker) suite.webfingerModule = webfinger.New(suite.processor).(*webfinger.Module) diff --git a/internal/worker/workers.go b/internal/concurrency/workers.go index 6adf9ad30..2e344aece 100644 --- a/internal/worker/workers.go +++ b/internal/concurrency/workers.go @@ -1,4 +1,4 @@ -package worker +package concurrency import ( "context" @@ -12,17 +12,17 @@ import ( "github.com/sirupsen/logrus" ) -// Worker represents a proccessor for MsgType objects, using a worker pool to allocate resources. -type Worker[MsgType any] struct { +// WorkerPool represents a proccessor for MsgType objects, using a worker pool to allocate resources. +type WorkerPool[MsgType any] struct { workers runners.WorkerPool process func(context.Context, MsgType) error prefix string // contains type prefix for logging } -// New returns a new Worker[MsgType] with given number of workers and queue ratio, +// New returns a new WorkerPool[MsgType] with given number of workers and queue ratio, // where the queue ratio is multiplied by no. workers to get queue size. If args < 1 // then suitable defaults are determined from the runtime's GOMAXPROCS variable. -func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] { +func NewWorkerPool[MsgType any](workers int, queueRatio int) *WorkerPool[MsgType] { var zero MsgType if workers < 1 { @@ -38,7 +38,7 @@ func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] { msgType := reflect.TypeOf(zero).String() _, msgType = path.Split(msgType) - w := &Worker[MsgType]{ + w := &WorkerPool[MsgType]{ workers: runners.NewWorkerPool(workers, workers*queueRatio), process: nil, prefix: fmt.Sprintf("worker.Worker[%s]", msgType), @@ -55,7 +55,7 @@ func New[MsgType any](workers int, queueRatio int) *Worker[MsgType] { } // Start will attempt to start the underlying worker pool, or return error. -func (w *Worker[MsgType]) Start() error { +func (w *WorkerPool[MsgType]) Start() error { logrus.Infof("%s starting", w.prefix) // Check processor was set @@ -72,7 +72,7 @@ func (w *Worker[MsgType]) Start() error { } // Stop will attempt to stop the underlying worker pool, or return error. -func (w *Worker[MsgType]) Stop() error { +func (w *WorkerPool[MsgType]) Stop() error { logrus.Infof("%s stopping", w.prefix) // Attempt to stop pool @@ -84,7 +84,7 @@ func (w *Worker[MsgType]) Stop() error { } // SetProcessor will set the Worker's processor function, which is called for each queued message. -func (w *Worker[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) { +func (w *WorkerPool[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) { if w.process != nil { logrus.Panicf("%s Worker.process is already set", w.prefix) } @@ -92,7 +92,7 @@ func (w *Worker[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) } // Queue will queue provided message to be processed with there's a free worker. -func (w *Worker[MsgType]) Queue(msg MsgType) { +func (w *WorkerPool[MsgType]) Queue(msg MsgType) { logrus.Tracef("%s queueing message (workers=%d queue=%d): %+v", w.prefix, w.workers.Workers(), w.workers.Queue(), msg, ) diff --git a/internal/federation/dereferencing/dereferencer_test.go b/internal/federation/dereferencing/dereferencer_test.go index 441019866..339490e5d 100644 --- a/internal/federation/dereferencing/dereferencer_test.go +++ b/internal/federation/dereferencing/dereferencer_test.go @@ -29,12 +29,12 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams/vocab" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/transport" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -150,7 +150,7 @@ func (suite *DereferencerStandardTestSuite) mockTransportController() transport. return response, nil } - fedWorker := worker.New[messages.FromFederator](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) mockClient := testrig.NewMockHTTPClient(do) return testrig.NewTestTransportController(mockClient, suite.db, fedWorker) } diff --git a/internal/federation/federatingactor_test.go b/internal/federation/federatingactor_test.go index 4039783a4..fdf907030 100644 --- a/internal/federation/federatingactor_test.go +++ b/internal/federation/federatingactor_test.go @@ -28,10 +28,10 @@ import ( "time" "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/messages" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -57,7 +57,7 @@ func (suite *FederatingActorTestSuite) TestSendNoRemoteFollowers() { ) testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote) - fedWorker := worker.New[messages.FromFederator](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) // setup transport controller with a no-op client so we don't make external calls sentMessages := []*url.URL{} @@ -112,7 +112,7 @@ func (suite *FederatingActorTestSuite) TestSendRemoteFollower() { ) testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote) - fedWorker := worker.New[messages.FromFederator](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) // setup transport controller with a no-op client so we don't make external calls sentMessages := []*url.URL{} diff --git a/internal/federation/federatingdb/db.go b/internal/federation/federatingdb/db.go index 60f09b909..cbe65e922 100644 --- a/internal/federation/federatingdb/db.go +++ b/internal/federation/federatingdb/db.go @@ -24,10 +24,10 @@ import ( "codeberg.org/gruf/go-mutexes" "github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/activity/streams/vocab" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/typeutils" - "github.com/superseriousbusiness/gotosocial/internal/worker" ) // DB wraps the pub.Database interface with a couple of custom functions for GoToSocial. @@ -44,12 +44,12 @@ type DB interface { type federatingDB struct { locks mutexes.MutexMap db db.DB - fedWorker *worker.Worker[messages.FromFederator] + fedWorker *concurrency.WorkerPool[messages.FromFederator] typeConverter typeutils.TypeConverter } // New returns a DB interface using the given database and config -func New(db db.DB, fedWorker *worker.Worker[messages.FromFederator]) DB { +func New(db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) DB { fdb := federatingDB{ locks: mutexes.NewMap(-1, -1), // use defaults db: db, diff --git a/internal/federation/federatingdb/federatingdb_test.go b/internal/federation/federatingdb/federatingdb_test.go index d53294c1c..8e6c1802d 100644 --- a/internal/federation/federatingdb/federatingdb_test.go +++ b/internal/federation/federatingdb/federatingdb_test.go @@ -23,12 +23,12 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/ap" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/typeutils" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -36,7 +36,7 @@ type FederatingDBTestSuite struct { suite.Suite db db.DB tc typeutils.TypeConverter - fedWorker *worker.Worker[messages.FromFederator] + fedWorker *concurrency.WorkerPool[messages.FromFederator] fromFederator chan messages.FromFederator federatingDB federatingdb.DB @@ -65,7 +65,7 @@ func (suite *FederatingDBTestSuite) SetupSuite() { func (suite *FederatingDBTestSuite) SetupTest() { testrig.InitTestLog() testrig.InitTestConfig() - suite.fedWorker = worker.New[messages.FromFederator](-1, -1) + suite.fedWorker = concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.fromFederator = make(chan messages.FromFederator, 10) suite.fedWorker.SetProcessor(func(ctx context.Context, msg messages.FromFederator) error { suite.fromFederator <- msg diff --git a/internal/federation/federatingprotocol_test.go b/internal/federation/federatingprotocol_test.go index 09817cff3..b4769a70f 100644 --- a/internal/federation/federatingprotocol_test.go +++ b/internal/federation/federatingprotocol_test.go @@ -28,10 +28,10 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/gotosocial/internal/ap" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/messages" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -44,7 +44,7 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook() { // the activity we're gonna use activity := suite.testActivities["dm_for_zork"] - fedWorker := worker.New[messages.FromFederator](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) // setup transport controller with a no-op client so we don't make external calls tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) { @@ -78,7 +78,7 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostInbox() { sendingAccount := suite.testAccounts["remote_account_1"] inboxAccount := suite.testAccounts["local_account_1"] - fedWorker := worker.New[messages.FromFederator](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker) // now setup module being tested, with the mock transport controller diff --git a/internal/httpclient/client.go b/internal/httpclient/client.go new file mode 100644 index 000000000..1a1f5e53b --- /dev/null +++ b/internal/httpclient/client.go @@ -0,0 +1,199 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package httpclient + +import ( + "errors" + "io" + "net" + "net/http" + "net/netip" + "runtime" + "time" +) + +// ErrReservedAddr is returned if a dialed address resolves to an IP within a blocked or reserved net. +var ErrReservedAddr = errors.New("dial within blocked / reserved IP range") + +// ErrBodyTooLarge is returned when a received response body is above predefined limit (default 40MB). +var ErrBodyTooLarge = errors.New("body size too large") + +// dialer is the base net.Dialer used by all package-created http.Transports. +var dialer = &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + Resolver: &net.Resolver{Dial: nil}, +} + +// Config provides configuration details for setting up a new +// instance of httpclient.Client{}. Within are a subset of the +// configuration values passed to initialized http.Transport{} +// and http.Client{}, along with httpclient.Client{} specific. +type Config struct { + // MaxOpenConns limits the max number of concurrent open connections. + MaxOpenConns int + + // MaxIdleConns: see http.Transport{}.MaxIdleConns. + MaxIdleConns int + + // ReadBufferSize: see http.Transport{}.ReadBufferSize. + ReadBufferSize int + + // WriteBufferSize: see http.Transport{}.WriteBufferSize. + WriteBufferSize int + + // MaxBodySize determines the maximum fetchable body size. + MaxBodySize int64 + + // Timeout: see http.Client{}.Timeout. + Timeout time.Duration + + // DisableCompression: see http.Transport{}.DisableCompression. + DisableCompression bool + + // AllowRanges allows outgoing communications to given IP nets. + AllowRanges []netip.Prefix + + // BlockRanges blocks outgoing communiciations to given IP nets. + BlockRanges []netip.Prefix +} + +// Client wraps an underlying http.Client{} to provide the following: +// - setting a maximum received request body size, returning error on +// large content lengths, and using a limited reader in all other +// cases to protect against forged / unknown content-lengths +// - protection from server side request forgery (SSRF) by only dialing +// out to known public IP prefixes, configurable with allows/blocks +// - limit number of concurrent requests, else blocking until a slot +// is available (context channels still respected) +type Client struct { + client http.Client + queue chan struct{} + bmax int64 +} + +// New returns a new instance of Client initialized using configuration. +func New(cfg Config) *Client { + var c Client + + // Copy global + d := dialer + + if cfg.MaxOpenConns <= 0 { + // By default base this value on GOMAXPROCS. + maxprocs := runtime.GOMAXPROCS(0) + cfg.MaxOpenConns = maxprocs * 10 + } + + if cfg.MaxIdleConns <= 0 { + // By default base this value on MaxOpenConns + cfg.MaxIdleConns = cfg.MaxOpenConns * 10 + } + + if cfg.MaxBodySize <= 0 { + // By default set this to a reasonable 40MB + cfg.MaxBodySize = 40 * 1024 * 1024 + } + + // Protect dialer with IP range sanitizer + d.Control = (&sanitizer{ + allow: cfg.AllowRanges, + block: cfg.BlockRanges, + }).Sanitize + + // Prepare client fields + c.bmax = cfg.MaxBodySize + c.queue = make(chan struct{}, cfg.MaxOpenConns) + c.client.Timeout = cfg.Timeout + + // Set underlying HTTP client roundtripper + c.client.Transport = &http.Transport{ + Proxy: http.ProxyFromEnvironment, + ForceAttemptHTTP2: true, + DialContext: d.DialContext, + MaxIdleConns: cfg.MaxIdleConns, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ReadBufferSize: cfg.ReadBufferSize, + WriteBufferSize: cfg.WriteBufferSize, + DisableCompression: cfg.DisableCompression, + } + + return &c +} + +// Do will perform given request when an available slot in the queue is available, +// and block until this time. For returned values, this follows the same semantics +// as the standard http.Client{}.Do() implementation except that response body will +// be wrapped by an io.LimitReader() to limit response body sizes. +func (c *Client) Do(req *http.Request) (*http.Response, error) { + select { + // Request context cancelled + case <-req.Context().Done(): + return nil, req.Context().Err() + + // Slot in queue acquired + case c.queue <- struct{}{}: + // NOTE: + // Ideally here we would set the slot release to happen either + // on error return, or via callback from the response body closer. + // However when implementing this, there appear deadlocks between + // the channel queue here and the media manager worker pool. So + // currently we only place a limit on connections dialing out, but + // there may still be more connections open than len(c.queue) given + // that connections may not be closed until response body is closed. + // The current implementation will reduce the viability of denial of + // service attacks, but if there are future issues heed this advice :] + defer func() { <-c.queue }() + } + + // Perform the HTTP request + rsp, err := c.client.Do(req) + if err != nil { + return nil, err + } + + // Check response body not too large + if rsp.ContentLength > c.bmax { + return nil, ErrBodyTooLarge + } + + // Seperate the body implementers + rbody := (io.Reader)(rsp.Body) + cbody := (io.Closer)(rsp.Body) + + var limit int64 + + if limit = rsp.ContentLength; limit < 0 { + // If unknown, use max as reader limit + limit = c.bmax + } + + // Don't trust them, limit body reads + rbody = io.LimitReader(rbody, limit) + + // Wrap body with limit + rsp.Body = &struct { + io.Reader + io.Closer + }{rbody, cbody} + + return rsp, nil +} diff --git a/internal/httpclient/client_test.go b/internal/httpclient/client_test.go new file mode 100644 index 000000000..dc190d430 --- /dev/null +++ b/internal/httpclient/client_test.go @@ -0,0 +1,154 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package httpclient_test + +import ( + "bytes" + "errors" + "io" + "net/http" + "net/http/httptest" + "net/netip" + "testing" + + "github.com/superseriousbusiness/gotosocial/internal/httpclient" +) + +var privateIPs = []string{ + "http://127.0.0.1:80", + "http://0.0.0.0:80", + "http://192.168.0.1:80", + "http://192.168.1.0:80", + "http://10.0.0.0:80", + "http://172.16.0.0:80", + "http://10.255.255.255:80", + "http://172.31.255.255:80", + "http://255.255.255.255:80", +} + +var bodies = []string{ + "hello world!", + "{}", + `{"key": "value", "some": "kinda bullshit"}`, + "body with\r\nnewlines", +} + +// Note: +// There is no test for the .MaxOpenConns implementation +// in the httpclient.Client{}, due to the difficult to test +// this. The block is only held for the actual dial out to +// the connection, so the usual test of blocking and holding +// open this queue slot to check we can't open another isn't +// an easy test here. + +func TestHTTPClientSmallBody(t *testing.T) { + for _, body := range bodies { + _TestHTTPClientWithBody(t, []byte(body), int(^uint16(0))) + } +} + +func TestHTTPClientExactBody(t *testing.T) { + for _, body := range bodies { + _TestHTTPClientWithBody(t, []byte(body), len(body)) + } +} + +func TestHTTPClientLargeBody(t *testing.T) { + for _, body := range bodies { + _TestHTTPClientWithBody(t, []byte(body), len(body)-1) + } +} + +func _TestHTTPClientWithBody(t *testing.T, body []byte, max int) { + var ( + handler http.HandlerFunc + + expect []byte + + expectErr error + ) + + // If this is a larger body, reslice and + // set error so we know what to expect + expect = body + if max < len(body) { + expect = expect[:max] + expectErr = httpclient.ErrBodyTooLarge + } + + // Create new HTTP client with maximum body size + client := httpclient.New(httpclient.Config{ + MaxBodySize: int64(max), + DisableCompression: true, + AllowRanges: []netip.Prefix{ + // Loopback (used by server) + netip.MustParsePrefix("127.0.0.1/8"), + }, + }) + + // Set simple body-writing test handler + handler = func(rw http.ResponseWriter, r *http.Request) { + _, _ = rw.Write(body) + } + + // Start the test server + srv := httptest.NewServer(handler) + defer srv.Close() + + // Wrap body to provide reader iface + rbody := bytes.NewReader(body) + + // Create the test HTTP request + req, _ := http.NewRequest("POST", srv.URL, rbody) + + // Perform the test request + rsp, err := client.Do(req) + if !errors.Is(err, expectErr) { + t.Fatalf("error performing client request: %v", err) + } else if err != nil { + return // expected error + } + defer rsp.Body.Close() + + // Read response body into memory + check, err := io.ReadAll(rsp.Body) + if err != nil { + t.Fatalf("error reading response body: %v", err) + } + + // Check actual response body matches expected + if !bytes.Equal(expect, check) { + t.Errorf("response body did not match expected: expect=%q actual=%q", string(expect), string(check)) + } +} + +func TestHTTPClientPrivateIP(t *testing.T) { + client := httpclient.New(httpclient.Config{}) + + for _, addr := range privateIPs { + // Prepare request to private IP + req, _ := http.NewRequest("GET", addr, nil) + + // Perform the HTTP request + _, err := client.Do(req) + if !errors.Is(err, httpclient.ErrReservedAddr) { + t.Errorf("dialing private address did not return expected error: %v", err) + } + } +} diff --git a/internal/httpclient/sanitizer.go b/internal/httpclient/sanitizer.go new file mode 100644 index 000000000..6eef6898a --- /dev/null +++ b/internal/httpclient/sanitizer.go @@ -0,0 +1,64 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package httpclient + +import ( + "net/netip" + "syscall" + + "github.com/superseriousbusiness/gotosocial/internal/netutil" +) + +type sanitizer struct { + allow []netip.Prefix + block []netip.Prefix +} + +// Sanitize implements the required net.Dialer.Control function signature. +func (s *sanitizer) Sanitize(ntwrk, addr string, _ syscall.RawConn) error { + // Parse IP+port from addr + ipport, err := netip.ParseAddrPort(addr) + if err != nil { + return err + } + + // Seperate the IP + ip := ipport.Addr() + + // Check if this is explicitly allowed + for i := 0; i < len(s.allow); i++ { + if s.allow[i].Contains(ip) { + return nil + } + } + + // Now check if explicity blocked + for i := 0; i < len(s.block); i++ { + if s.block[i].Contains(ip) { + return ErrReservedAddr + } + } + + // Validate this is a safe IP + if !netutil.ValidateIP(ip) { + return ErrReservedAddr + } + + return nil +} diff --git a/internal/media/manager.go b/internal/media/manager.go index 174fca8e2..5b4a01021 100644 --- a/internal/media/manager.go +++ b/internal/media/manager.go @@ -27,9 +27,9 @@ import ( "github.com/robfig/cron/v3" "github.com/sirupsen/logrus" "github.com/spf13/viper" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/worker" ) // Manager provides an interface for managing media: parsing, storing, and retrieving media objects like photos, videos, and gifs. @@ -79,8 +79,8 @@ type Manager interface { type manager struct { db db.DB storage *kv.KVStore - emojiWorker *worker.Worker[*ProcessingEmoji] - mediaWorker *worker.Worker[*ProcessingMedia] + emojiWorker *concurrency.WorkerPool[*ProcessingEmoji] + mediaWorker *concurrency.WorkerPool[*ProcessingMedia] stopCronJobs func() error } @@ -89,7 +89,7 @@ type manager struct { // A worker pool will also be initialized for the manager, to ensure that only // a limited number of media will be processed in parallel. The numbers of workers // is determined from the $GOMAXPROCS environment variable (usually no. CPU cores). -// See internal/worker.New() documentation for further information. +// See internal/concurrency.NewWorkerPool() documentation for further information. func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) { m := &manager{ db: database, @@ -97,7 +97,7 @@ func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) { } // Prepare the media worker pool - m.mediaWorker = worker.New[*ProcessingMedia](-1, 10) + m.mediaWorker = concurrency.NewWorkerPool[*ProcessingMedia](-1, 10) m.mediaWorker.SetProcessor(func(ctx context.Context, media *ProcessingMedia) error { if err := ctx.Err(); err != nil { return err @@ -109,7 +109,7 @@ func NewManager(database db.DB, storage *kv.KVStore) (Manager, error) { }) // Prepare the emoji worker pool - m.emojiWorker = worker.New[*ProcessingEmoji](-1, 10) + m.emojiWorker = concurrency.NewWorkerPool[*ProcessingEmoji](-1, 10) m.emojiWorker.SetProcessor(func(ctx context.Context, emoji *ProcessingEmoji) error { if err := ctx.Err(); err != nil { return err diff --git a/internal/netutil/validate.go b/internal/netutil/validate.go new file mode 100644 index 000000000..27cc9ba4a --- /dev/null +++ b/internal/netutil/validate.go @@ -0,0 +1,78 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package netutil + +import ( + "net/netip" +) + +var ( + // IPv6GlobalUnicast is the global IPv6 unicast IP prefix. + IPv6GlobalUnicast = netip.MustParsePrefix("ff00::/8") + + // IPvReserved contains IPv4 reserved IP prefixes. + IPv4Reserved = [...]netip.Prefix{ + netip.MustParsePrefix("0.0.0.0/8"), // Current network + netip.MustParsePrefix("10.0.0.0/8"), // Private + netip.MustParsePrefix("100.64.0.0/10"), // RFC6598 + netip.MustParsePrefix("127.0.0.0/8"), // Loopback + netip.MustParsePrefix("169.254.0.0/16"), // Link-local + netip.MustParsePrefix("172.16.0.0/12"), // Private + netip.MustParsePrefix("192.0.0.0/24"), // RFC6890 + netip.MustParsePrefix("192.0.2.0/24"), // Test, doc, examples + netip.MustParsePrefix("192.88.99.0/24"), // IPv6 to IPv4 relay + netip.MustParsePrefix("192.168.0.0/16"), // Private + netip.MustParsePrefix("198.18.0.0/15"), // Benchmarking tests + netip.MustParsePrefix("198.51.100.0/24"), // Test, doc, examples + netip.MustParsePrefix("203.0.113.0/24"), // Test, doc, examples + netip.MustParsePrefix("224.0.0.0/4"), // Multicast + netip.MustParsePrefix("240.0.0.0/4"), // Reserved (includes broadcast / 255.255.255.255) + } +) + +// ValidateAddr will parse a netip.AddrPort from string, and return the result of ValidateIP() on addr. +func ValidateAddr(s string) bool { + ipport, err := netip.ParseAddrPort(s) + if err != nil { + return false + } + return ValidateIP(ipport.Addr()) +} + +// ValidateIP returns whether IP is an IPv4/6 address in non-reserved, public ranges. +func ValidateIP(ip netip.Addr) bool { + switch { + // IPv4: check if IPv4 in reserved nets + case ip.Is4(): + for _, reserved := range IPv4Reserved { + if reserved.Contains(ip) { + return false + } + } + return true + + // IPv6: check if in global unicast (public internet) + case ip.Is6(): + return IPv6GlobalUnicast.Contains(ip) + + // Assume malicious by default + default: + return false + } +} diff --git a/internal/processing/account/account.go b/internal/processing/account/account.go index c49df1a1a..7668da02c 100644 --- a/internal/processing/account/account.go +++ b/internal/processing/account/account.go @@ -23,6 +23,7 @@ import ( "mime/multipart" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtserror" @@ -33,7 +34,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/text" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/visibility" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/oauth2/v4" ) @@ -84,7 +84,7 @@ type Processor interface { type processor struct { tc typeutils.TypeConverter mediaManager media.Manager - clientWorker *worker.Worker[messages.FromClientAPI] + clientWorker *concurrency.WorkerPool[messages.FromClientAPI] oauthServer oauth.Server filter visibility.Filter formatter text.Formatter @@ -94,7 +94,7 @@ type processor struct { } // New returns a new account processor. -func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, oauthServer oauth.Server, clientWorker *worker.Worker[messages.FromClientAPI], federator federation.Federator, parseMention gtsmodel.ParseMentionFunc) Processor { +func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, oauthServer oauth.Server, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], federator federation.Federator, parseMention gtsmodel.ParseMentionFunc) Processor { return &processor{ tc: tc, mediaManager: mediaManager, diff --git a/internal/processing/account/account_test.go b/internal/processing/account/account_test.go index 33b744250..d9ce68cc0 100644 --- a/internal/processing/account/account_test.go +++ b/internal/processing/account/account_test.go @@ -24,6 +24,7 @@ import ( "codeberg.org/gruf/go-store/kv" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/activity/pub" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" @@ -35,7 +36,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/processing/account" "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/typeutils" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -81,8 +81,8 @@ func (suite *AccountStandardTestSuite) SetupTest() { testrig.InitTestLog() testrig.InitTestConfig() - fedWorker := worker.New[messages.FromFederator](-1, -1) - clientWorker := worker.New[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) clientWorker.SetProcessor(func(_ context.Context, msg messages.FromClientAPI) error { suite.fromClientAPIChan <- msg return nil diff --git a/internal/processing/admin/admin.go b/internal/processing/admin/admin.go index 4b466a2d7..6779f59b7 100644 --- a/internal/processing/admin/admin.go +++ b/internal/processing/admin/admin.go @@ -23,13 +23,13 @@ import ( "mime/multipart" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/typeutils" - "github.com/superseriousbusiness/gotosocial/internal/worker" ) // Processor wraps a bunch of functions for processing admin actions. @@ -47,12 +47,12 @@ type Processor interface { type processor struct { tc typeutils.TypeConverter mediaManager media.Manager - clientWorker *worker.Worker[messages.FromClientAPI] + clientWorker *concurrency.WorkerPool[messages.FromClientAPI] db db.DB } // New returns a new admin processor. -func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, clientWorker *worker.Worker[messages.FromClientAPI]) Processor { +func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, clientWorker *concurrency.WorkerPool[messages.FromClientAPI]) Processor { return &processor{ tc: tc, mediaManager: mediaManager, diff --git a/internal/processing/media/media_test.go b/internal/processing/media/media_test.go index af67b36b1..1149f2646 100644 --- a/internal/processing/media/media_test.go +++ b/internal/processing/media/media_test.go @@ -26,6 +26,7 @@ import ( "codeberg.org/gruf/go-store/kv" "github.com/sirupsen/logrus" "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" @@ -33,7 +34,6 @@ import ( mediaprocessing "github.com/superseriousbusiness/gotosocial/internal/processing/media" "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/typeutils" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -122,7 +122,7 @@ func (suite *MediaStandardTestSuite) mockTransportController() transport.Control return response, nil } - fedWorker := worker.New[messages.FromFederator](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) mockClient := testrig.NewMockHTTPClient(do) return testrig.NewTestTransportController(mockClient, suite.db, fedWorker) } diff --git a/internal/processing/processor.go b/internal/processing/processor.go index 69f3100f9..d30f2f37e 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -25,6 +25,7 @@ import ( "codeberg.org/gruf/go-store/kv" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" @@ -44,7 +45,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/visibility" - "github.com/superseriousbusiness/gotosocial/internal/worker" ) // Processor should be passed to api modules (see internal/apimodule/...). It is used for @@ -237,8 +237,8 @@ type Processor interface { // processor just implements the Processor interface type processor struct { - clientWorker *worker.Worker[messages.FromClientAPI] - fedWorker *worker.Worker[messages.FromFederator] + clientWorker *concurrency.WorkerPool[messages.FromClientAPI] + fedWorker *concurrency.WorkerPool[messages.FromFederator] federator federation.Federator tc typeutils.TypeConverter @@ -271,8 +271,8 @@ func NewProcessor( storage *kv.KVStore, db db.DB, emailSender email.Sender, - clientWorker *worker.Worker[messages.FromClientAPI], - fedWorker *worker.Worker[messages.FromFederator], + clientWorker *concurrency.WorkerPool[messages.FromClientAPI], + fedWorker *concurrency.WorkerPool[messages.FromFederator], ) Processor { parseMentionFunc := GetParseMentionFunc(db, federator) diff --git a/internal/processing/processor_test.go b/internal/processing/processor_test.go index 7e1972366..5946e6718 100644 --- a/internal/processing/processor_test.go +++ b/internal/processing/processor_test.go @@ -29,6 +29,7 @@ import ( "codeberg.org/gruf/go-store/kv" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/activity/streams" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" @@ -40,7 +41,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/typeutils" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -217,8 +217,8 @@ func (suite *ProcessingStandardTestSuite) SetupTest() { }, nil }) - clientWorker := worker.New[messages.FromClientAPI](-1, -1) - fedWorker := worker.New[messages.FromFederator](-1, -1) + clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.transportController = testrig.NewTestTransportController(httpClient, suite.db, fedWorker) suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) diff --git a/internal/processing/status/status.go b/internal/processing/status/status.go index 207bffb30..e8b4a8268 100644 --- a/internal/processing/status/status.go +++ b/internal/processing/status/status.go @@ -22,6 +22,7 @@ import ( "context" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -29,7 +30,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/text" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/visibility" - "github.com/superseriousbusiness/gotosocial/internal/worker" ) // Processor wraps a bunch of functions for processing statuses. @@ -74,12 +74,12 @@ type processor struct { db db.DB filter visibility.Filter formatter text.Formatter - clientWorker *worker.Worker[messages.FromClientAPI] + clientWorker *concurrency.WorkerPool[messages.FromClientAPI] parseMention gtsmodel.ParseMentionFunc } // New returns a new status processor. -func New(db db.DB, tc typeutils.TypeConverter, clientWorker *worker.Worker[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor { +func New(db db.DB, tc typeutils.TypeConverter, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor { return &processor{ tc: tc, db: db, diff --git a/internal/processing/status/status_test.go b/internal/processing/status/status_test.go index d2126f03d..17c68c0b6 100644 --- a/internal/processing/status/status_test.go +++ b/internal/processing/status/status_test.go @@ -21,6 +21,7 @@ package status_test import ( "codeberg.org/gruf/go-store/kv" "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -30,7 +31,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/processing/status" "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/typeutils" - "github.com/superseriousbusiness/gotosocial/internal/worker" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -42,7 +42,7 @@ type StatusStandardTestSuite struct { storage *kv.KVStore mediaManager media.Manager federator federation.Federator - clientWorker *worker.Worker[messages.FromClientAPI] + clientWorker *concurrency.WorkerPool[messages.FromClientAPI] // standard suite models testTokens map[string]*gtsmodel.Token @@ -75,11 +75,11 @@ func (suite *StatusStandardTestSuite) SetupTest() { testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := worker.New[messages.FromFederator](-1, -1) + fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.db = testrig.NewTestDB() suite.typeConverter = testrig.NewTestTypeConverter(suite.db) - suite.clientWorker = worker.New[messages.FromClientAPI](-1, -1) + suite.clientWorker = concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.tc = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil), suite.db, fedWorker) suite.storage = testrig.NewTestStorage() suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) diff --git a/internal/transport/controller.go b/internal/transport/controller.go index 56a922a8b..280d4bc0b 100644 --- a/internal/transport/controller.go +++ b/internal/transport/controller.go @@ -20,13 +20,17 @@ package transport import ( "context" - "crypto" + "crypto/rsa" + "crypto/x509" "encoding/json" "fmt" "net/url" - "sync" + "runtime/debug" + "time" - "github.com/go-fed/httpsig" + "codeberg.org/gruf/go-byteutil" + "codeberg.org/gruf/go-cache/v2" + "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/activity/streams" @@ -37,109 +41,85 @@ import ( // Controller generates transports for use in making federation requests to other servers. type Controller interface { - NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error) + // NewTransport returns an http signature transport with the given public key ID (URL location of pubkey), and the given private key. + NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error) + + // NewTransportForUsername searches for account with username, and returns result of .NewTransport(). NewTransportForUsername(ctx context.Context, username string) (Transport, error) } type controller struct { - db db.DB - clock pub.Clock - client pub.HttpClient - appAgent string - - // dereferenceFollowersShortcut is a shortcut to dereference followers of an - // account on this instance, without making any external api/http calls. - // - // It is passed to new transports, and should only be invoked when the iri.Host == this host. - dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error) - - // dereferenceUserShortcut is a shortcut to dereference followers an account on - // this instance, without making any external api/http calls. - // - // It is passed to new transports, and should only be invoked when the iri.Host == this host. - dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error) + db db.DB + fedDB federatingdb.DB + clock pub.Clock + client pub.HttpClient + cache cache.Cache[string, *transport] + userAgent string } -func dereferenceFollowersShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) { - return func(ctx context.Context, iri *url.URL) ([]byte, error) { - followers, err := federatingDB.Followers(ctx, iri) - if err != nil { - return nil, err - } +// NewController returns an implementation of the Controller interface for creating new transports +func NewController(db db.DB, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller { + applicationName := viper.GetString(config.Keys.ApplicationName) + host := viper.GetString(config.Keys.Host) - i, err := streams.Serialize(followers) - if err != nil { - return nil, err - } + // Determine build information + build, _ := debug.ReadBuildInfo() - return json.Marshal(i) + c := &controller{ + db: db, + fedDB: federatingDB, + clock: clock, + client: client, + cache: cache.New[string, *transport](), + userAgent: fmt.Sprintf("%s; %s (gofed/activity gotosocial-%s)", applicationName, host, build.Main.Version), } -} -func dereferenceUserShortcut(federatingDB federatingdb.DB) func(context.Context, *url.URL) ([]byte, error) { - return func(ctx context.Context, iri *url.URL) ([]byte, error) { - user, err := federatingDB.Get(ctx, iri) - if err != nil { - return nil, err - } - - i, err := streams.Serialize(user) - if err != nil { - return nil, err - } - - return json.Marshal(i) + // Transport cache has TTL=1hr freq=1m + c.cache.SetTTL(time.Hour, false) + if !c.cache.Start(time.Minute) { + logrus.Panic("failed to start transport controller cache") } -} -// NewController returns an implementation of the Controller interface for creating new transports -func NewController(db db.DB, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller { - applicationName := viper.GetString(config.Keys.ApplicationName) - host := viper.GetString(config.Keys.Host) - appAgent := fmt.Sprintf("%s %s", applicationName, host) - - return &controller{ - db: db, - clock: clock, - client: client, - appAgent: appAgent, - dereferenceFollowersShortcut: dereferenceFollowersShortcut(federatingDB), - dereferenceUserShortcut: dereferenceUserShortcut(federatingDB), - } + return c } -// NewTransport returns a new http signature transport with the given public key id (a URL), and the given private key. -func (c *controller) NewTransport(pubKeyID string, privkey crypto.PrivateKey) (Transport, error) { - prefs := []httpsig.Algorithm{httpsig.RSA_SHA256} - digestAlgo := httpsig.DigestSha256 - getHeaders := []string{httpsig.RequestTarget, "host", "date"} - postHeaders := []string{httpsig.RequestTarget, "host", "date", "digest"} +func (c *controller) NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Transport, error) { + // Generate public key string for cache key + // + // NOTE: it is safe to use the public key as the cache + // key here as we are generating it ourselves from the + // private key. If we were simply using a public key + // provided as argument that would absolutely NOT be safe. + pubStr := privkeyToPublicStr(privkey) + + // First check for cached transport + transp, ok := c.cache.Get(pubStr) + if ok { + return transp, nil + } - getSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, 120) - if err != nil { - return nil, fmt.Errorf("error creating get signer: %s", err) + // Create the transport + transp = &transport{ + controller: c, + pubKeyID: pubKeyID, + privkey: privkey, } - postSigner, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, 120) - if err != nil { - return nil, fmt.Errorf("error creating post signer: %s", err) + // Cache this transport under pubkey + if !c.cache.Put(pubStr, transp) { + var cached *transport + + cached, ok = c.cache.Get(pubStr) + if !ok { + // Some ridiculous race cond. + c.cache.Set(pubStr, transp) + } else { + // Use already cached + transp = cached + } } - sigTransport := pub.NewHttpSigTransport(c.client, c.appAgent, c.clock, getSigner, postSigner, pubKeyID, privkey) - - return &transport{ - client: c.client, - appAgent: c.appAgent, - gofedAgent: "(go-fed/activity v1.0.0)", - clock: c.clock, - pubKeyID: pubKeyID, - privkey: privkey, - sigTransport: sigTransport, - getSigner: getSigner, - getSignerMu: &sync.Mutex{}, - dereferenceFollowersShortcut: c.dereferenceFollowersShortcut, - dereferenceUserShortcut: c.dereferenceUserShortcut, - }, nil + return transp, nil } func (c *controller) NewTransportForUsername(ctx context.Context, username string) (Transport, error) { @@ -164,3 +144,45 @@ func (c *controller) NewTransportForUsername(ctx context.Context, username strin } return transport, nil } + +// dereferenceLocalFollowers is a shortcut to dereference followers of an +// account on this instance, without making any external api/http calls. +// +// It is passed to new transports, and should only be invoked when the iri.Host == this host. +func (c *controller) dereferenceLocalFollowers(ctx context.Context, iri *url.URL) ([]byte, error) { + followers, err := c.fedDB.Followers(ctx, iri) + if err != nil { + return nil, err + } + + i, err := streams.Serialize(followers) + if err != nil { + return nil, err + } + + return json.Marshal(i) +} + +// dereferenceLocalUser is a shortcut to dereference followers an account on +// this instance, without making any external api/http calls. +// +// It is passed to new transports, and should only be invoked when the iri.Host == this host. +func (c *controller) dereferenceLocalUser(ctx context.Context, iri *url.URL) ([]byte, error) { + user, err := c.fedDB.Get(ctx, iri) + if err != nil { + return nil, err + } + + i, err := streams.Serialize(user) + if err != nil { + return nil, err + } + + return json.Marshal(i) +} + +// privkeyToPublicStr will create a string representation of RSA public key from private. +func privkeyToPublicStr(privkey *rsa.PrivateKey) string { + b := x509.MarshalPKCS1PublicKey(&privkey.PublicKey) + return byteutil.B2S(b) +} diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go index fe17f7761..bacaa9b3a 100644 --- a/internal/transport/deliver.go +++ b/internal/transport/deliver.go @@ -19,13 +19,14 @@ package transport import ( + "bytes" "context" "fmt" + "net/http" "net/url" "strings" "sync" - "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/superseriousbusiness/gotosocial/internal/config" ) @@ -72,6 +73,28 @@ func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error { return nil } - logrus.Debugf("Deliver: posting as %s to %s", t.pubKeyID, to.String()) - return t.sigTransport.Deliver(ctx, b, to) + urlStr := to.String() + + req, err := http.NewRequestWithContext(ctx, "POST", urlStr, bytes.NewReader(b)) + if err != nil { + return err + } + + req.Header.Add("Content-Type", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"") + req.Header.Add("Accept-Charset", "utf-8") + req.Header.Add("User-Agent", t.controller.userAgent) + req.Header.Set("Host", to.Host) + + resp, err := t.POST(req, b) + if err != nil { + return err + } + defer resp.Body.Close() + + if code := resp.StatusCode; code != http.StatusOK && + code != http.StatusCreated && code != http.StatusAccepted { + return fmt.Errorf("POST request to %s failed (%d): %s", urlStr, resp.StatusCode, resp.Status) + } + + return nil } diff --git a/internal/transport/dereference.go b/internal/transport/dereference.go index 61d99c5c5..36157b673 100644 --- a/internal/transport/dereference.go +++ b/internal/transport/dereference.go @@ -20,32 +20,55 @@ package transport import ( "context" + "fmt" + "io/ioutil" + "net/http" "net/url" - "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/uris" ) func (t *transport) Dereference(ctx context.Context, iri *url.URL) ([]byte, error) { - l := logrus.WithField("func", "Dereference") - // if the request is to us, we can shortcut for certain URIs rather than going through // the normal request flow, thereby saving time and energy if iri.Host == viper.GetString(config.Keys.Host) { if uris.IsFollowersPath(iri) { // the request is for followers of one of our accounts, which we can shortcut - return t.dereferenceFollowersShortcut(ctx, iri) + return t.controller.dereferenceLocalFollowers(ctx, iri) } if uris.IsUserPath(iri) { // the request is for one of our accounts, which we can shortcut - return t.dereferenceUserShortcut(ctx, iri) + return t.controller.dereferenceLocalUser(ctx, iri) } } - // the request is either for a remote host or for us but we don't have a shortcut, so continue as normal - l.Debugf("performing GET to %s", iri.String()) - return t.sigTransport.Dereference(ctx, iri) + // Build IRI just once + iriStr := iri.String() + + // Prepare new HTTP request to endpoint + req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil) + if err != nil { + return nil, err + } + req.Header.Add("Accept", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"") + req.Header.Add("Accept-Charset", "utf-8") + req.Header.Add("User-Agent", t.controller.userAgent) + req.Header.Set("Host", iri.Host) + + // Perform the HTTP request + rsp, err := t.GET(req) + if err != nil { + return nil, err + } + defer rsp.Body.Close() + + // Check for an expected status code + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status) + } + + return ioutil.ReadAll(rsp.Body) } diff --git a/internal/transport/derefinstance.go b/internal/transport/derefinstance.go index c64dced0f..1acbcc364 100644 --- a/internal/transport/derefinstance.go +++ b/internal/transport/derefinstance.go @@ -80,43 +80,38 @@ func (t *transport) DereferenceInstance(ctx context.Context, iri *url.URL) (*gts } func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL) (*gtsmodel.Instance, error) { - l := logrus.WithField("func", "dereferenceByAPIV1Instance") - cleanIRI := &url.URL{ Scheme: iri.Scheme, Host: iri.Host, Path: "api/v1/instance", } - l.Debugf("performing GET to %s", cleanIRI.String()) - req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil) + // Build IRI just once + iriStr := cleanIRI.String() + + req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil) if err != nil { return nil, err } + req.Header.Add("Accept", "application/json") - req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") - req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent)) + req.Header.Add("User-Agent", t.controller.userAgent) req.Header.Set("Host", cleanIRI.Host) - t.getSignerMu.Lock() - err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil) - t.getSignerMu.Unlock() - if err != nil { - return nil, err - } - resp, err := t.client.Do(req) + + resp, err := t.GET(req) if err != nil { return nil, err } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status) + return nil, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status) } + b, err := ioutil.ReadAll(resp.Body) if err != nil { return nil, err - } - - if len(b) == 0 { + } else if len(b) == 0 { return nil, errors.New("response bytes was len 0") } @@ -237,44 +232,37 @@ func dereferenceByNodeInfo(c context.Context, t *transport, iri *url.URL) (*gtsm } func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*url.URL, error) { - l := logrus.WithField("func", "callNodeInfoWellKnown") - cleanIRI := &url.URL{ Scheme: iri.Scheme, Host: iri.Host, Path: ".well-known/nodeinfo", } - l.Debugf("performing GET to %s", cleanIRI.String()) - req, err := http.NewRequestWithContext(ctx, "GET", cleanIRI.String(), nil) + // Build IRI just once + iriStr := cleanIRI.String() + + req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil) if err != nil { return nil, err } - req.Header.Add("Accept", "application/json") - req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") - req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent)) + req.Header.Add("User-Agent", t.controller.userAgent) req.Header.Set("Host", cleanIRI.Host) - t.getSignerMu.Lock() - err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil) - t.getSignerMu.Unlock() - if err != nil { - return nil, err - } - resp, err := t.client.Do(req) + + resp, err := t.GET(req) if err != nil { return nil, err } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", cleanIRI.String(), resp.StatusCode, resp.Status) + return nil, fmt.Errorf("callNodeInfoWellKnown: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status) } + b, err := ioutil.ReadAll(resp.Body) if err != nil { return nil, err - } - - if len(b) == 0 { + } else if len(b) == 0 { return nil, errors.New("callNodeInfoWellKnown: response bytes was len 0") } @@ -302,38 +290,31 @@ func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*ur } func callNodeInfo(ctx context.Context, t *transport, iri *url.URL) (*apimodel.Nodeinfo, error) { - l := logrus.WithField("func", "callNodeInfo") + // Build IRI just once + iriStr := iri.String() - l.Debugf("performing GET to %s", iri.String()) - req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil) if err != nil { return nil, err } - req.Header.Add("Accept", "application/json") - req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") - req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent)) + req.Header.Add("User-Agent", t.controller.userAgent) req.Header.Set("Host", iri.Host) - t.getSignerMu.Lock() - err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil) - t.getSignerMu.Unlock() - if err != nil { - return nil, err - } - resp, err := t.client.Do(req) + + resp, err := t.GET(req) if err != nil { return nil, err } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status) + return nil, fmt.Errorf("callNodeInfo: GET request to %s failed (%d): %s", iriStr, resp.StatusCode, resp.Status) } + b, err := ioutil.ReadAll(resp.Body) if err != nil { return nil, err - } - - if len(b) == 0 { + } else if len(b) == 0 { return nil, errors.New("callNodeInfo: response bytes was len 0") } diff --git a/internal/transport/derefmedia.go b/internal/transport/derefmedia.go index e3c86ce1e..8feb7ed20 100644 --- a/internal/transport/derefmedia.go +++ b/internal/transport/derefmedia.go @@ -24,34 +24,31 @@ import ( "io" "net/http" "net/url" - - "github.com/sirupsen/logrus" ) func (t *transport) DereferenceMedia(ctx context.Context, iri *url.URL) (io.ReadCloser, int, error) { - l := logrus.WithField("func", "DereferenceMedia") - l.Debugf("performing GET to %s", iri.String()) - req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil) + // Build IRI just once + iriStr := iri.String() + + // Prepare HTTP request to this media's IRI + req, err := http.NewRequestWithContext(ctx, "GET", iriStr, nil) if err != nil { return nil, 0, err } - req.Header.Add("Accept", "*/*") // we don't know what kind of media we're going to get here - req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") - req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent)) + req.Header.Add("User-Agent", t.controller.userAgent) req.Header.Set("Host", iri.Host) - t.getSignerMu.Lock() - err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil) - t.getSignerMu.Unlock() - if err != nil { - return nil, 0, err - } - resp, err := t.client.Do(req) + + // Perform the HTTP request + rsp, err := t.GET(req) if err != nil { return nil, 0, err } - if resp.StatusCode != http.StatusOK { - return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status) + + // Check for an expected status code + if rsp.StatusCode != http.StatusOK { + return nil, 0, fmt.Errorf("GET request to %s failed (%d): %s", iriStr, rsp.StatusCode, rsp.Status) } - return resp.Body, int(resp.ContentLength), nil + + return rsp.Body, int(rsp.ContentLength), nil } diff --git a/internal/transport/finger.go b/internal/transport/finger.go index a71bbb51e..7554a242f 100644 --- a/internal/transport/finger.go +++ b/internal/transport/finger.go @@ -23,46 +23,36 @@ import ( "fmt" "io/ioutil" "net/http" - "net/url" - - "github.com/sirupsen/logrus" ) func (t *transport) Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) { - l := logrus.WithField("func", "Finger") - urlString := fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s@%s", targetDomain, targetUsername, targetDomain) - l.Debugf("performing GET to %s", urlString) - - iri, err := url.Parse(urlString) - if err != nil { - return nil, fmt.Errorf("Finger: error parsing url %s: %s", urlString, err) - } - - l.Debugf("performing GET to %s", iri.String()) - - req, err := http.NewRequestWithContext(ctx, "GET", iri.String(), nil) + // Prepare URL string + urlStr := "https://" + + targetDomain + + "/.well-known/webfinger?resource=acct:" + + targetUsername + "@" + targetDomain + + // Generate new GET request from URL string + req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) if err != nil { return nil, err } - req.Header.Add("Accept", "application/json") req.Header.Add("Accept", "application/jrd+json") - req.Header.Add("Date", t.clock.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05")+" GMT") - req.Header.Add("User-Agent", fmt.Sprintf("%s %s", t.appAgent, t.gofedAgent)) - req.Header.Set("Host", iri.Host) - t.getSignerMu.Lock() - err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, req, nil) - t.getSignerMu.Unlock() - if err != nil { - return nil, err - } - resp, err := t.client.Do(req) + req.Header.Add("User-Agent", t.controller.userAgent) + req.Header.Set("Host", req.URL.Host) + + // Perform the HTTP request + rsp, err := t.GET(req) if err != nil { return nil, err } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("GET request to %s failed (%d): %s", iri.String(), resp.StatusCode, resp.Status) + defer rsp.Body.Close() + + // Check for an expected status code + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("GET request to %s failed (%d): %s", urlStr, rsp.StatusCode, rsp.Status) } - return ioutil.ReadAll(resp.Body) + + return ioutil.ReadAll(rsp.Body) } diff --git a/internal/transport/signing.go b/internal/transport/signing.go new file mode 100644 index 000000000..39896a2a8 --- /dev/null +++ b/internal/transport/signing.go @@ -0,0 +1,43 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package transport + +import ( + "github.com/go-fed/httpsig" +) + +var ( + // http signer preferences + prefs = []httpsig.Algorithm{httpsig.RSA_SHA256} + digestAlgo = httpsig.DigestSha256 + getHeaders = []string{httpsig.RequestTarget, "host", "date"} + postHeaders = []string{httpsig.RequestTarget, "host", "date", "digest"} +) + +// NewGETSigner returns a new httpsig.Signer instance initialized with GTS GET preferences. +func NewGETSigner(expiresIn int64) (httpsig.Signer, error) { + sig, _, err := httpsig.NewSigner(prefs, digestAlgo, getHeaders, httpsig.Signature, expiresIn) + return sig, err +} + +// NewPOSTSigner returns a new httpsig.Signer instance initialized with GTS POST preferences. +func NewPOSTSigner(expiresIn int64) (httpsig.Signer, error) { + sig, _, err := httpsig.NewSigner(prefs, digestAlgo, postHeaders, httpsig.Signature, expiresIn) + return sig, err +} diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 40c11ca17..c52686c43 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -21,11 +21,18 @@ package transport import ( "context" "crypto" + "crypto/x509" + "errors" "io" + "net/http" "net/url" + "strings" "sync" + "time" + errorsv2 "codeberg.org/gruf/go-errors/v2" "github.com/go-fed/httpsig" + "github.com/sirupsen/logrus" "github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) @@ -43,28 +50,148 @@ type Transport interface { DereferenceInstance(ctx context.Context, iri *url.URL) (*gtsmodel.Instance, error) // Finger performs a webfinger request with the given username and domain, and returns the bytes from the response body. Finger(ctx context.Context, targetUsername string, targetDomains string) ([]byte, error) - // SigTransport returns the underlying http signature transport wrapped by the GoToSocial transport. - SigTransport() pub.Transport } // transport implements the Transport interface type transport struct { - client pub.HttpClient - appAgent string - gofedAgent string - clock pub.Clock - pubKeyID string - privkey crypto.PrivateKey - sigTransport *pub.HttpSigTransport - getSigner httpsig.Signer - getSignerMu *sync.Mutex - - // shortcuts for dereferencing things that exist on our instance without making an http call to ourself - - dereferenceFollowersShortcut func(ctx context.Context, iri *url.URL) ([]byte, error) - dereferenceUserShortcut func(ctx context.Context, iri *url.URL) ([]byte, error) + controller *controller + pubKeyID string + privkey crypto.PrivateKey + + signerExp time.Time + getSigner httpsig.Signer + postSigner httpsig.Signer + signerMu sync.Mutex +} + +// GET will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn. +func (t *transport) GET(r *http.Request, retryOn ...int) (*http.Response, error) { + if r.Method != http.MethodGet { + return nil, errors.New("must be GET request") + } + return t.do(r, func(r *http.Request) error { + return t.signGET(r) + }, retryOn...) +} + +// POST will perform given http request using transport client, retrying on certain preset errors, or if status code is among retryOn. +func (t *transport) POST(r *http.Request, body []byte, retryOn ...int) (*http.Response, error) { + if r.Method != http.MethodPost { + return nil, errors.New("must be POST request") + } + return t.do(r, func(r *http.Request) error { + return t.signPOST(r, body) + }, retryOn...) +} + +func (t *transport) do(r *http.Request, signer func(*http.Request) error, retryOn ...int) (*http.Response, error) { + const maxRetries = 5 + backoff := time.Second * 2 + + // Start a log entry for this request + l := logrus.WithFields(logrus.Fields{ + "pubKeyID": t.pubKeyID, + "method": r.Method, + "url": r.URL.String(), + }) + + for i := 0; i < maxRetries; i++ { + // Reset signing header fields + now := t.controller.clock.Now().UTC() + r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT") + r.Header.Del("Signature") + r.Header.Del("Digest") + + // Perform request signing + if err := signer(r); err != nil { + return nil, err + } + + l.Infof("performing request") + + // Attempt to perform request + rsp, err := t.controller.client.Do(r) + if err == nil { //nolint shutup linter + // TooManyRequest means we need to slow + // down and retry our request. Codes over + // 500 generally indicate temp. outages. + if code := rsp.StatusCode; code < 500 && + code != http.StatusTooManyRequests && + !containsInt(retryOn, rsp.StatusCode) { + return rsp, nil + } + + // Generate error from status code for logging + err = errors.New(`http response "` + rsp.Status + `"`) + } else if errorsv2.Is(err, context.DeadlineExceeded, context.Canceled) { + // Return early if context has cancelled + return nil, err + } else if strings.Contains(err.Error(), "stopped after 10 redirects") { + // Don't bother if net/http returned after too many redirects + return nil, err + } else if errors.As(err, &x509.UnknownAuthorityError{}) { + // Unknown authority errors we do NOT recover from + return nil, err + } + + l.Errorf("backing off for %s after http request error: %v", backoff.String(), err) + + select { + // Request ctx cancelled + case <-r.Context().Done(): + return nil, r.Context().Err() + + // Backoff for some time + case <-time.After(backoff): + backoff *= 2 + } + } + + return nil, errors.New("transport reached max retries") +} + +// signGET will safely sign an HTTP GET request. +func (t *transport) signGET(r *http.Request) (err error) { + t.safesign(func() { + err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil) + }) + return +} + +// signPOST will safely sign an HTTP POST request for given body. +func (t *transport) signPOST(r *http.Request, body []byte) (err error) { + t.safesign(func() { + err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body) + }) + return +} + +// safesign will perform sign function within mutex protection, +// and ensured that httpsig.Signers are up-to-date. +func (t *transport) safesign(sign func()) { + // Perform within mu safety + t.signerMu.Lock() + defer t.signerMu.Unlock() + + if now := time.Now(); now.After(t.signerExp) { + const expiry = 120 + + // Signers have expired and require renewal + t.getSigner, _ = NewGETSigner(expiry) + t.postSigner, _ = NewPOSTSigner(expiry) + t.signerExp = now.Add(time.Second * expiry) + } + + // Perform signing + sign() } -func (t *transport) SigTransport() pub.Transport { - return t.sigTransport +// containsInt checks if slice contains check. +func containsInt(slice []int, check int) bool { + for _, i := range slice { + if i == check { + return true + } + } + return false } |