diff options
130 files changed, 1034 insertions, 1080 deletions
| diff --git a/cmd/gotosocial/action/server/server.go b/cmd/gotosocial/action/server/server.go index d43599b05..bc56e21f0 100644 --- a/cmd/gotosocial/action/server/server.go +++ b/cmd/gotosocial/action/server/server.go @@ -35,7 +35,6 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/middleware"  	"go.uber.org/automaxprocs/maxprocs" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db/bundb"  	"github.com/superseriousbusiness/gotosocial/internal/email" @@ -45,7 +44,6 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/httpclient"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/oauth"  	"github.com/superseriousbusiness/gotosocial/internal/oidc"  	"github.com/superseriousbusiness/gotosocial/internal/processing" @@ -107,19 +105,11 @@ var Start action.GTSAction = func(ctx context.Context) error {  	state.Workers.Start()  	defer state.Workers.Stop() -	// Create the client API and federator worker pools -	// NOTE: these MUST NOT be used until they are passed to the -	// processor and it is started. The reason being that the processor -	// sets the Worker process functions and start the underlying pools -	// TODO: move these into state.Workers (and maybe reformat worker pools). -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -  	// build backend handlers  	mediaManager := media.NewManager(&state)  	oauthServer := oauth.New(ctx, dbService)  	typeConverter := typeutils.NewConverter(dbService) -	federatingDB := federatingdb.New(dbService, fedWorker, typeConverter) +	federatingDB := federatingdb.New(&state, typeConverter)  	transportController := transport.NewController(dbService, federatingDB, &federation.Clock{}, client)  	federator := federation.NewFederator(dbService, federatingDB, transportController, typeConverter, mediaManager) @@ -140,11 +130,15 @@ var Start action.GTSAction = func(ctx context.Context) error {  	}  	// create the message processor using the other services we've created so far -	processor := processing.NewProcessor(typeConverter, federator, oauthServer, mediaManager, storage, dbService, emailSender, clientWorker, fedWorker) +	processor := processing.NewProcessor(typeConverter, federator, oauthServer, mediaManager, &state, emailSender)  	if err := processor.Start(); err != nil {  		return fmt.Errorf("error creating processor: %s", err)  	} +	// Set state client / federator worker enqueue functions +	state.Workers.EnqueueClientAPI = processor.EnqueueClientAPI +	state.Workers.EnqueueFederator = processor.EnqueueFederator +  	/*  		HTTP router initialization  	*/ diff --git a/cmd/gotosocial/action/testrig/testrig.go b/cmd/gotosocial/action/testrig/testrig.go index 3be7907fe..68bb94ec3 100644 --- a/cmd/gotosocial/action/testrig/testrig.go +++ b/cmd/gotosocial/action/testrig/testrig.go @@ -33,14 +33,13 @@ import (  	"github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action"  	"github.com/superseriousbusiness/gotosocial/internal/api"  	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/gotosocial"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/log" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/middleware"  	"github.com/superseriousbusiness/gotosocial/internal/oidc" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/web"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -48,37 +47,44 @@ import (  // Start creates and starts a gotosocial testrig server  var Start action.GTSAction = func(ctx context.Context) error { +	var state state.State +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	dbService := testrig.NewTestDB() -	testrig.StandardDBSetup(dbService, nil) -	var storageBackend *storage.Driver +	// Initialize caches +	state.Caches.Init() +	state.Caches.Start() +	defer state.Caches.Stop() + +	state.DB = testrig.NewTestDB(&state) +	testrig.StandardDBSetup(state.DB, nil) +  	if os.Getenv("GTS_STORAGE_BACKEND") == "s3" { -		storageBackend, _ = storage.NewS3Storage() +		state.Storage, _ = storage.NewS3Storage()  	} else { -		storageBackend = testrig.NewInMemoryStorage() +		state.Storage = testrig.NewInMemoryStorage()  	} -	testrig.StandardStorageSetup(storageBackend, "./testrig/media") +	testrig.StandardStorageSetup(state.Storage, "./testrig/media") -	// Create client API and federator worker pools -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) +	// Initialize workers. +	state.Workers.Start() +	defer state.Workers.Stop()  	// build backend handlers -	transportController := testrig.NewTestTransportController(testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) { +	transportController := testrig.NewTestTransportController(&state, testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) {  		r := io.NopCloser(bytes.NewReader([]byte{}))  		return &http.Response{  			StatusCode: 200,  			Body:       r,  		}, nil -	}, ""), dbService, fedWorker) -	mediaManager := testrig.NewTestMediaManager(dbService, storageBackend) -	federator := testrig.NewTestFederator(dbService, transportController, storageBackend, mediaManager, fedWorker) +	}, "")) +	mediaManager := testrig.NewTestMediaManager(&state) +	federator := testrig.NewTestFederator(&state, transportController, mediaManager)  	emailSender := testrig.NewEmailSender("./web/template/", nil) -	processor := testrig.NewTestProcessor(dbService, storageBackend, federator, emailSender, mediaManager, clientWorker, fedWorker) +	processor := testrig.NewTestProcessor(&state, federator, emailSender, mediaManager)  	if err := processor.Start(); err != nil {  		return fmt.Errorf("error starting processor: %s", err)  	} @@ -87,7 +93,7 @@ var Start action.GTSAction = func(ctx context.Context) error {  		HTTP router initialization  	*/ -	router := testrig.NewTestRouter(dbService) +	router := testrig.NewTestRouter(state.DB)  	// attach global middlewares which are used for every request  	router.AttachGlobalMiddleware( @@ -112,7 +118,7 @@ var Start action.GTSAction = func(ctx context.Context) error {  		}  	} -	routerSession, err := dbService.GetSession(ctx) +	routerSession, err := state.DB.GetSession(ctx)  	if err != nil {  		return fmt.Errorf("error retrieving router session for session middleware: %w", err)  	} @@ -123,13 +129,13 @@ var Start action.GTSAction = func(ctx context.Context) error {  	}  	var ( -		authModule        = api.NewAuth(dbService, processor, idp, routerSession, sessionName) // auth/oauth paths -		clientModule      = api.NewClient(dbService, processor)                                // api client endpoints -		fileserverModule  = api.NewFileserver(processor)                                       // fileserver endpoints -		wellKnownModule   = api.NewWellKnown(processor)                                        // .well-known endpoints -		nodeInfoModule    = api.NewNodeInfo(processor)                                         // nodeinfo endpoint -		activityPubModule = api.NewActivityPub(dbService, processor)                           // ActivityPub endpoints -		webModule         = web.New(dbService, processor)                                      // web pages + user profiles + settings panels etc +		authModule        = api.NewAuth(state.DB, processor, idp, routerSession, sessionName) // auth/oauth paths +		clientModule      = api.NewClient(state.DB, processor)                                // api client endpoints +		fileserverModule  = api.NewFileserver(processor)                                      // fileserver endpoints +		wellKnownModule   = api.NewWellKnown(processor)                                       // .well-known endpoints +		nodeInfoModule    = api.NewNodeInfo(processor)                                        // nodeinfo endpoint +		activityPubModule = api.NewActivityPub(state.DB, processor)                           // ActivityPub endpoints +		webModule         = web.New(state.DB, processor)                                      // web pages + user profiles + settings panels etc  	)  	// these should be routed in order @@ -142,7 +148,7 @@ var Start action.GTSAction = func(ctx context.Context) error {  	activityPubModule.RoutePublicKey(router)  	webModule.Route(router) -	gts, err := gotosocial.NewServer(dbService, router, federator, mediaManager) +	gts, err := gotosocial.NewServer(state.DB, router, federator, mediaManager)  	if err != nil {  		return fmt.Errorf("error creating gotosocial service: %s", err)  	} @@ -157,8 +163,8 @@ var Start action.GTSAction = func(ctx context.Context) error {  	sig := <-sigs  	log.Infof(ctx, "received signal %s, shutting down", sig) -	testrig.StandardDBTeardown(dbService) -	testrig.StandardStorageTeardown(storageBackend) +	testrig.StandardDBTeardown(state.DB) +	testrig.StandardStorageTeardown(state.Storage)  	// close down all running services in order  	if err := gts.Stop(ctx); err != nil { diff --git a/internal/api/activitypub/emoji/emojiget_test.go b/internal/api/activitypub/emoji/emojiget_test.go index cd7333955..8f99efdfc 100644 --- a/internal/api/activitypub/emoji/emojiget_test.go +++ b/internal/api/activitypub/emoji/emojiget_test.go @@ -27,15 +27,14 @@ import (  	"github.com/gin-gonic/gin"  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/emoji" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/middleware"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -50,6 +49,7 @@ type EmojiGetTestSuite struct {  	emailSender  email.Sender  	processor    *processing.Processor  	storage      *storage.Driver +	state        state.State  	testEmojis   map[string]*gtsmodel.Emoji  	testAccounts map[string]*gtsmodel.Account @@ -65,19 +65,23 @@ func (suite *EmojiGetTestSuite) SetupSuite() {  }  func (suite *EmojiGetTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db +	suite.storage = testrig.NewInMemoryStorage() +	suite.state.Storage = suite.storage -	suite.db = testrig.NewTestDB()  	suite.tc = testrig.NewTestTypeConverter(suite.db) -	suite.storage = testrig.NewInMemoryStorage() -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)  	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)  	suite.emojiModule = emoji.New(suite.processor)  	testrig.StandardDBSetup(suite.db, suite.testAccounts)  	testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -90,6 +94,7 @@ func (suite *EmojiGetTestSuite) SetupTest() {  func (suite *EmojiGetTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  }  func (suite *EmojiGetTestSuite) TestGetEmoji() { diff --git a/internal/api/activitypub/users/inboxpost_test.go b/internal/api/activitypub/users/inboxpost_test.go index 0ad63abf7..fa23204c9 100644 --- a/internal/api/activitypub/users/inboxpost_test.go +++ b/internal/api/activitypub/users/inboxpost_test.go @@ -34,11 +34,9 @@ import (  	"github.com/superseriousbusiness/activity/streams"  	"github.com/superseriousbusiness/activity/streams/vocab"  	"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/id" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -86,13 +84,10 @@ func (suite *InboxPostTestSuite) TestPostBlock() {  	suite.NoError(err)  	body := bytes.NewReader(bodyJson) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - -	tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker) -	federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) +	tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) +	federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)  	emailSender := testrig.NewEmailSender("../../../../web/template/", nil) -	processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) +	processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)  	userModule := users.New(processor)  	suite.NoError(processor.Start()) @@ -190,13 +185,10 @@ func (suite *InboxPostTestSuite) TestPostUnblock() {  	suite.NoError(err)  	body := bytes.NewReader(bodyJson) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - -	tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker) -	federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) +	tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) +	federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)  	emailSender := testrig.NewEmailSender("../../../../web/template/", nil) -	processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) +	processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)  	userModule := users.New(processor)  	suite.NoError(processor.Start()) @@ -291,9 +283,6 @@ func (suite *InboxPostTestSuite) TestPostUpdate() {  	suite.NoError(err)  	body := bytes.NewReader(bodyJson) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -  	// use a different version of the mock http client which serves the updated  	// version of the remote account, as though it had been updated there too;  	// this is needed so it can be dereferenced + updated properly @@ -301,10 +290,11 @@ func (suite *InboxPostTestSuite) TestPostUpdate() {  	mockHTTPClient.TestRemotePeople = map[string]vocab.ActivityStreamsPerson{  		updatedAccount.URI: asAccount,  	} -	tc := testrig.NewTestTransportController(mockHTTPClient, suite.db, fedWorker) -	federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) + +	tc := testrig.NewTestTransportController(&suite.state, mockHTTPClient) +	federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)  	emailSender := testrig.NewEmailSender("../../../../web/template/", nil) -	processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) +	processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)  	userModule := users.New(processor)  	suite.NoError(processor.Start()) @@ -430,15 +420,12 @@ func (suite *InboxPostTestSuite) TestPostDelete() {  	suite.NoError(err)  	body := bytes.NewReader(bodyJson) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - -	tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker) -	federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) +	tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) +	federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)  	emailSender := testrig.NewEmailSender("../../../../web/template/", nil) -	processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) -	suite.NoError(processor.Start()) +	processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)  	userModule := users.New(processor) +	suite.NoError(processor.Start())  	// setup request  	recorder := httptest.NewRecorder() diff --git a/internal/api/activitypub/users/outboxget_test.go b/internal/api/activitypub/users/outboxget_test.go index 6e5c4e1e0..8f3306a25 100644 --- a/internal/api/activitypub/users/outboxget_test.go +++ b/internal/api/activitypub/users/outboxget_test.go @@ -32,8 +32,6 @@ import (  	"github.com/superseriousbusiness/activity/streams"  	"github.com/superseriousbusiness/activity/streams/vocab"  	"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -104,13 +102,10 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() {  	signedRequest := derefRequests["foss_satan_dereference_zork_outbox_first"]  	targetAccount := suite.testAccounts["local_account_1"] -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - -	tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker) -	federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) +	tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) +	federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)  	emailSender := testrig.NewEmailSender("../../../../web/template/", nil) -	processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) +	processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)  	userModule := users.New(processor)  	suite.NoError(processor.Start()) @@ -182,13 +177,10 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() {  	signedRequest := derefRequests["foss_satan_dereference_zork_outbox_next"]  	targetAccount := suite.testAccounts["local_account_1"] -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - -	tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker) -	federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) +	tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) +	federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)  	emailSender := testrig.NewEmailSender("../../../../web/template/", nil) -	processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) +	processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)  	userModule := users.New(processor)  	suite.NoError(processor.Start()) diff --git a/internal/api/activitypub/users/repliesget_test.go b/internal/api/activitypub/users/repliesget_test.go index 4e985a0a1..92e5cddfa 100644 --- a/internal/api/activitypub/users/repliesget_test.go +++ b/internal/api/activitypub/users/repliesget_test.go @@ -33,8 +33,6 @@ import (  	"github.com/superseriousbusiness/activity/streams"  	"github.com/superseriousbusiness/activity/streams/vocab"  	"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -104,13 +102,10 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() {  	targetAccount := suite.testAccounts["local_account_1"]  	targetStatus := suite.testStatuses["local_account_1_status_1"] -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - -	tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker) -	federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) +	tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) +	federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)  	emailSender := testrig.NewEmailSender("../../../../web/template/", nil) -	processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) +	processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)  	userModule := users.New(processor)  	suite.NoError(processor.Start()) @@ -172,13 +167,10 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() {  	targetAccount := suite.testAccounts["local_account_1"]  	targetStatus := suite.testStatuses["local_account_1_status_1"] -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - -	tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker) -	federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) +	tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) +	federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)  	emailSender := testrig.NewEmailSender("../../../../web/template/", nil) -	processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) +	processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)  	userModule := users.New(processor)  	suite.NoError(processor.Start()) diff --git a/internal/api/activitypub/users/user_test.go b/internal/api/activitypub/users/user_test.go index 0124925b9..d025eada0 100644 --- a/internal/api/activitypub/users/user_test.go +++ b/internal/api/activitypub/users/user_test.go @@ -22,15 +22,14 @@ import (  	"github.com/gin-gonic/gin"  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/middleware"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -46,6 +45,7 @@ type UserStandardTestSuite struct {  	emailSender  email.Sender  	processor    *processing.Processor  	storage      *storage.Driver +	state        state.State  	// standard suite models  	testTokens       map[string]*gtsmodel.Token @@ -75,19 +75,21 @@ func (suite *UserStandardTestSuite) SetupSuite() {  }  func (suite *UserStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.tc = testrig.NewTestTypeConverter(suite.db)  	suite.storage = testrig.NewInMemoryStorage() -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) +	suite.state.Storage = suite.storage +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)  	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)  	suite.userModule = users.New(suite.processor)  	testrig.StandardDBSetup(suite.db, suite.testAccounts)  	testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -100,4 +102,5 @@ func (suite *UserStandardTestSuite) SetupTest() {  func (suite *UserStandardTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  } diff --git a/internal/api/auth/auth_test.go b/internal/api/auth/auth_test.go index a5e518cda..1a15155bd 100644 --- a/internal/api/auth/auth_test.go +++ b/internal/api/auth/auth_test.go @@ -28,17 +28,16 @@ import (  	"github.com/gin-gonic/gin"  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/api/auth" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/middleware"  	"github.com/superseriousbusiness/gotosocial/internal/oidc"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -47,6 +46,7 @@ type AuthStandardTestSuite struct {  	suite.Suite  	db           db.DB  	storage      *storage.Driver +	state        state.State  	mediaManager media.Manager  	federator    federation.Federator  	processor    *processing.Processor @@ -78,18 +78,19 @@ func (suite *AuthStandardTestSuite) SetupSuite() {  }  func (suite *AuthStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.storage = testrig.NewInMemoryStorage() -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) +	suite.state.Storage = suite.storage +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")), suite.mediaManager)  	suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)  	suite.authModule = auth.New(suite.db, suite.processor, suite.idp)  	testrig.StandardDBSetup(suite.db, suite.testAccounts)  } diff --git a/internal/api/client/accounts/account_test.go b/internal/api/client/accounts/account_test.go index 5a25c12f1..ab3f4cd1f 100644 --- a/internal/api/client/accounts/account_test.go +++ b/internal/api/client/accounts/account_test.go @@ -27,16 +27,15 @@ import (  	"github.com/gin-gonic/gin"  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/api/client/accounts" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/oauth"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -51,6 +50,7 @@ type AccountStandardTestSuite struct {  	processor    *processing.Processor  	emailSender  email.Sender  	sentEmails   map[string]string +	state        state.State  	// standard suite models  	testTokens       map[string]*gtsmodel.Token @@ -76,19 +76,22 @@ func (suite *AccountStandardTestSuite) SetupSuite() {  }  func (suite *AccountStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.storage = testrig.NewInMemoryStorage() -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) +	suite.state.Storage = suite.storage + +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)  	suite.sentEmails = make(map[string]string)  	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)  	suite.accountsModule = accounts.New(suite.processor)  	testrig.StandardDBSetup(suite.db, nil)  	testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -99,6 +102,7 @@ func (suite *AccountStandardTestSuite) SetupTest() {  func (suite *AccountStandardTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  }  func (suite *AccountStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestMethod string, requestBody []byte, requestPath string, bodyContentType string) *gin.Context { diff --git a/internal/api/client/admin/admin_test.go b/internal/api/client/admin/admin_test.go index 4f3f48904..1d19635f0 100644 --- a/internal/api/client/admin/admin_test.go +++ b/internal/api/client/admin/admin_test.go @@ -27,16 +27,15 @@ import (  	"github.com/gin-gonic/gin"  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/api/client/admin" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/oauth"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -51,6 +50,7 @@ type AdminStandardTestSuite struct {  	processor    *processing.Processor  	emailSender  email.Sender  	sentEmails   map[string]string +	state        state.State  	// standard suite models  	testTokens          map[string]*gtsmodel.Token @@ -82,19 +82,22 @@ func (suite *AdminStandardTestSuite) SetupSuite() {  }  func (suite *AdminStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.storage = testrig.NewInMemoryStorage() -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) +	suite.state.Storage = suite.storage + +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)  	suite.sentEmails = make(map[string]string)  	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)  	suite.adminModule = admin.New(suite.processor)  	testrig.StandardDBSetup(suite.db, nil)  	testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -103,6 +106,7 @@ func (suite *AdminStandardTestSuite) SetupTest() {  func (suite *AdminStandardTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  }  func (suite *AdminStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestMethod string, requestBody []byte, requestPath string, bodyContentType string) *gin.Context { diff --git a/internal/api/client/bookmarks/bookmarks_test.go b/internal/api/client/bookmarks/bookmarks_test.go index c39ad49f3..931d504f7 100644 --- a/internal/api/client/bookmarks/bookmarks_test.go +++ b/internal/api/client/bookmarks/bookmarks_test.go @@ -32,16 +32,15 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/api/client/bookmarks"  	"github.com/superseriousbusiness/gotosocial/internal/api/client/statuses"  	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/oauth"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -57,6 +56,7 @@ type BookmarkTestSuite struct {  	emailSender  email.Sender  	processor    *processing.Processor  	storage      *storage.Driver +	state        state.State  	// standard suite models  	testTokens       map[string]*gtsmodel.Token @@ -87,22 +87,25 @@ func (suite *BookmarkTestSuite) SetupSuite() {  }  func (suite *BookmarkTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	suite.db = testrig.NewTestDB() -	suite.tc = testrig.NewTestTypeConverter(suite.db) +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.storage = testrig.NewInMemoryStorage() +	suite.state.Storage = suite.storage + +	suite.tc = testrig.NewTestTypeConverter(suite.db)  	testrig.StandardDBSetup(suite.db, nil)  	testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)  	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)  	suite.statusModule = statuses.New(suite.processor)  	suite.bookmarkModule = bookmarks.New(suite.processor) @@ -112,6 +115,7 @@ func (suite *BookmarkTestSuite) SetupTest() {  func (suite *BookmarkTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  }  func (suite *BookmarkTestSuite) getBookmarks( diff --git a/internal/api/client/favourites/favourites_test.go b/internal/api/client/favourites/favourites_test.go index 7949aa38c..71c7097cc 100644 --- a/internal/api/client/favourites/favourites_test.go +++ b/internal/api/client/favourites/favourites_test.go @@ -21,14 +21,13 @@ package favourites_test  import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/api/client/favourites" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -44,6 +43,7 @@ type FavouritesStandardTestSuite struct {  	emailSender  email.Sender  	processor    *processing.Processor  	storage      *storage.Driver +	state        state.State  	// standard suite models  	testTokens       map[string]*gtsmodel.Token @@ -71,22 +71,25 @@ func (suite *FavouritesStandardTestSuite) SetupSuite() {  }  func (suite *FavouritesStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	suite.db = testrig.NewTestDB() -	suite.tc = testrig.NewTestTypeConverter(suite.db) +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.storage = testrig.NewInMemoryStorage() +	suite.state.Storage = suite.storage + +	suite.tc = testrig.NewTestTypeConverter(suite.db)  	testrig.StandardDBSetup(suite.db, nil)  	testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)  	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)  	suite.favModule = favourites.New(suite.processor)  	suite.NoError(suite.processor.Start()) @@ -95,6 +98,7 @@ func (suite *FavouritesStandardTestSuite) SetupTest() {  func (suite *FavouritesStandardTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  }  func (suite *FavouritesStandardTestSuite) TestProcessFave() {} diff --git a/internal/api/client/followrequests/followrequest_test.go b/internal/api/client/followrequests/followrequest_test.go index 7a08479ab..294dbc7ed 100644 --- a/internal/api/client/followrequests/followrequest_test.go +++ b/internal/api/client/followrequests/followrequest_test.go @@ -26,16 +26,15 @@ import (  	"github.com/gin-gonic/gin"  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/api/client/followrequests" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/oauth"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -48,6 +47,7 @@ type FollowRequestStandardTestSuite struct {  	federator    federation.Federator  	processor    *processing.Processor  	emailSender  email.Sender +	state        state.State  	// standard suite models  	testTokens       map[string]*gtsmodel.Token @@ -73,18 +73,21 @@ func (suite *FollowRequestStandardTestSuite) SetupSuite() {  }  func (suite *FollowRequestStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.storage = testrig.NewInMemoryStorage() -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) +	suite.state.Storage = suite.storage + +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)  	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)  	suite.followRequestModule = followrequests.New(suite.processor)  	testrig.StandardDBSetup(suite.db, nil)  	testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -95,6 +98,7 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() {  func (suite *FollowRequestStandardTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  }  func (suite *FollowRequestStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestMethod string, requestBody []byte, requestPath string, bodyContentType string) *gin.Context { diff --git a/internal/api/client/instance/instance_test.go b/internal/api/client/instance/instance_test.go index ff622febe..6870d2a44 100644 --- a/internal/api/client/instance/instance_test.go +++ b/internal/api/client/instance/instance_test.go @@ -26,16 +26,15 @@ import (  	"github.com/gin-gonic/gin"  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/api/client/instance" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/oauth"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -50,6 +49,7 @@ type InstanceStandardTestSuite struct {  	processor    *processing.Processor  	emailSender  email.Sender  	sentEmails   map[string]string +	state        state.State  	// standard suite models  	testTokens       map[string]*gtsmodel.Token @@ -75,19 +75,22 @@ func (suite *InstanceStandardTestSuite) SetupSuite() {  }  func (suite *InstanceStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.storage = testrig.NewInMemoryStorage() -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) +	suite.state.Storage = suite.storage + +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)  	suite.sentEmails = make(map[string]string)  	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)  	suite.instanceModule = instance.New(suite.processor)  	testrig.StandardDBSetup(suite.db, nil)  	testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -96,6 +99,7 @@ func (suite *InstanceStandardTestSuite) SetupTest() {  func (suite *InstanceStandardTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  }  func (suite *InstanceStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, method string, path string, body []byte, contentType string, auth bool) *gin.Context { diff --git a/internal/api/client/media/mediacreate_test.go b/internal/api/client/media/mediacreate_test.go index caa40b061..6439895f3 100644 --- a/internal/api/client/media/mediacreate_test.go +++ b/internal/api/client/media/mediacreate_test.go @@ -33,7 +33,6 @@ import (  	"github.com/stretchr/testify/suite"  	mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"  	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email" @@ -41,9 +40,9 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/oauth"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -60,6 +59,7 @@ type MediaCreateTestSuite struct {  	oauthServer  oauth.Server  	emailSender  email.Sender  	processor    *processing.Processor +	state        state.State  	// standard suite models  	testTokens       map[string]*gtsmodel.Token @@ -78,21 +78,24 @@ type MediaCreateTestSuite struct {  */  func (suite *MediaCreateTestSuite) SetupSuite() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	// setup standard items  	testrig.InitTestConfig()  	testrig.InitTestLog() -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.storage = testrig.NewInMemoryStorage() +	suite.state.Storage = suite.storage +  	suite.tc = testrig.NewTestTypeConverter(suite.db) -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state)  	suite.oauthServer = testrig.NewTestOauthServer(suite.db) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)  	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)  	// setup module being tested  	suite.mediaModule = mediamodule.New(suite.processor) @@ -102,11 +105,15 @@ func (suite *MediaCreateTestSuite) TearDownSuite() {  	if err := suite.db.Stop(context.Background()); err != nil {  		log.Panicf(nil, "error closing db connection: %s", err)  	} +	testrig.StopWorkers(&suite.state)  }  func (suite *MediaCreateTestSuite) SetupTest() { +	suite.state.Caches.Init() +  	testrig.StandardDBSetup(suite.db, nil)  	testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") +  	suite.testTokens = testrig.NewTestTokens()  	suite.testClients = testrig.NewTestClients()  	suite.testApplications = testrig.NewTestApplications() diff --git a/internal/api/client/media/mediaupdate_test.go b/internal/api/client/media/mediaupdate_test.go index cb96e8aa1..75657e1b5 100644 --- a/internal/api/client/media/mediaupdate_test.go +++ b/internal/api/client/media/mediaupdate_test.go @@ -31,7 +31,6 @@ import (  	"github.com/stretchr/testify/suite"  	mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"  	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email" @@ -39,9 +38,9 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/oauth"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -58,6 +57,7 @@ type MediaUpdateTestSuite struct {  	oauthServer  oauth.Server  	emailSender  email.Sender  	processor    *processing.Processor +	state        state.State  	// standard suite models  	testTokens       map[string]*gtsmodel.Token @@ -76,21 +76,23 @@ type MediaUpdateTestSuite struct {  */  func (suite *MediaUpdateTestSuite) SetupSuite() { +	testrig.StartWorkers(&suite.state) +  	// setup standard items  	testrig.InitTestConfig()  	testrig.InitTestLog() -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.storage = testrig.NewInMemoryStorage() +	suite.state.Storage = suite.storage +  	suite.tc = testrig.NewTestTypeConverter(suite.db) -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state)  	suite.oauthServer = testrig.NewTestOauthServer(suite.db) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)  	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)  	// setup module being tested  	suite.mediaModule = mediamodule.New(suite.processor) @@ -100,11 +102,15 @@ func (suite *MediaUpdateTestSuite) TearDownSuite() {  	if err := suite.db.Stop(context.Background()); err != nil {  		log.Panicf(nil, "error closing db connection: %s", err)  	} +	testrig.StopWorkers(&suite.state)  }  func (suite *MediaUpdateTestSuite) SetupTest() { +	suite.state.Caches.Init() +  	testrig.StandardDBSetup(suite.db, nil)  	testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") +  	suite.testTokens = testrig.NewTestTokens()  	suite.testClients = testrig.NewTestClients()  	suite.testApplications = testrig.NewTestApplications() diff --git a/internal/api/client/reports/reports_test.go b/internal/api/client/reports/reports_test.go index 1c5a532b9..cdab0b77b 100644 --- a/internal/api/client/reports/reports_test.go +++ b/internal/api/client/reports/reports_test.go @@ -21,14 +21,13 @@ package reports_test  import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/api/client/reports" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -42,6 +41,7 @@ type ReportsStandardTestSuite struct {  	processor    *processing.Processor  	emailSender  email.Sender  	sentEmails   map[string]string +	state        state.State  	// standard suite models  	testTokens       map[string]*gtsmodel.Token @@ -67,19 +67,22 @@ func (suite *ReportsStandardTestSuite) SetupSuite() {  }  func (suite *ReportsStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.storage = testrig.NewInMemoryStorage() -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) +	suite.state.Storage = suite.storage + +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)  	suite.sentEmails = make(map[string]string)  	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)  	suite.reportsModule = reports.New(suite.processor)  	testrig.StandardDBSetup(suite.db, nil)  	testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -90,4 +93,5 @@ func (suite *ReportsStandardTestSuite) SetupTest() {  func (suite *ReportsStandardTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  } diff --git a/internal/api/client/search/search_test.go b/internal/api/client/search/search_test.go index 4580f6f9d..153328cc3 100644 --- a/internal/api/client/search/search_test.go +++ b/internal/api/client/search/search_test.go @@ -26,16 +26,15 @@ import (  	"github.com/gin-gonic/gin"  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/api/client/search" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/oauth"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -50,6 +49,7 @@ type SearchStandardTestSuite struct {  	processor    *processing.Processor  	emailSender  email.Sender  	sentEmails   map[string]string +	state        state.State  	// standard suite models  	testTokens       map[string]*gtsmodel.Token @@ -71,19 +71,22 @@ func (suite *SearchStandardTestSuite) SetupSuite() {  }  func (suite *SearchStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.storage = testrig.NewInMemoryStorage() -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) +	suite.state.Storage = suite.storage + +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)  	suite.sentEmails = make(map[string]string)  	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)  	suite.searchModule = search.New(suite.processor)  	testrig.StandardDBSetup(suite.db, nil)  	testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -94,6 +97,7 @@ func (suite *SearchStandardTestSuite) SetupTest() {  func (suite *SearchStandardTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  }  func (suite *SearchStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestPath string) *gin.Context { diff --git a/internal/api/client/statuses/status_test.go b/internal/api/client/statuses/status_test.go index a87fd36f7..93745ffd8 100644 --- a/internal/api/client/statuses/status_test.go +++ b/internal/api/client/statuses/status_test.go @@ -21,14 +21,13 @@ package statuses_test  import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/api/client/statuses" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -44,6 +43,7 @@ type StatusStandardTestSuite struct {  	emailSender  email.Sender  	processor    *processing.Processor  	storage      *storage.Driver +	state        state.State  	// standard suite models  	testTokens       map[string]*gtsmodel.Token @@ -71,22 +71,26 @@ func (suite *StatusStandardTestSuite) SetupSuite() {  }  func (suite *StatusStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	suite.db = testrig.NewTestDB() -	suite.tc = testrig.NewTestTypeConverter(suite.db) +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.storage = testrig.NewInMemoryStorage() +	suite.state.Storage = suite.storage + +	suite.tc = testrig.NewTestTypeConverter(suite.db) +  	testrig.StandardDBSetup(suite.db, nil)  	testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)  	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)  	suite.statusModule = statuses.New(suite.processor)  	suite.NoError(suite.processor.Start()) @@ -95,4 +99,5 @@ func (suite *StatusStandardTestSuite) SetupTest() {  func (suite *StatusStandardTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  } diff --git a/internal/api/client/streaming/streaming_test.go b/internal/api/client/streaming/streaming_test.go index 5fb470af8..ac27aad8a 100644 --- a/internal/api/client/streaming/streaming_test.go +++ b/internal/api/client/streaming/streaming_test.go @@ -32,15 +32,14 @@ import (  	"github.com/gin-gonic/gin"  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/api/client/streaming" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/oauth"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -56,6 +55,7 @@ type StreamingTestSuite struct {  	emailSender  email.Sender  	processor    *processing.Processor  	storage      *storage.Driver +	state        state.State  	// standard suite models  	testTokens       map[string]*gtsmodel.Token @@ -83,22 +83,25 @@ func (suite *StreamingTestSuite) SetupSuite() {  }  func (suite *StreamingTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	suite.db = testrig.NewTestDB() -	suite.tc = testrig.NewTestTypeConverter(suite.db) +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.storage = testrig.NewInMemoryStorage() +	suite.state.Storage = suite.storage + +	suite.tc = testrig.NewTestTypeConverter(suite.db)  	testrig.StandardDBSetup(suite.db, nil)  	testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)  	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)  	suite.streamingModule = streaming.New(suite.processor, 1, 4096)  	suite.NoError(suite.processor.Start())  } @@ -106,6 +109,7 @@ func (suite *StreamingTestSuite) SetupTest() {  func (suite *StreamingTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  }  // Addr is a fake network interface which implements the net.Addr interface diff --git a/internal/api/client/user/user_test.go b/internal/api/client/user/user_test.go index c990abb56..ce117059e 100644 --- a/internal/api/client/user/user_test.go +++ b/internal/api/client/user/user_test.go @@ -21,14 +21,13 @@ package user_test  import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/api/client/user" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -43,6 +42,7 @@ type UserStandardTestSuite struct {  	emailSender  email.Sender  	processor    *processing.Processor  	storage      *storage.Driver +	state        state.State  	testTokens       map[string]*gtsmodel.Token  	testClients      map[string]*gtsmodel.Client @@ -56,23 +56,29 @@ type UserStandardTestSuite struct {  }  func (suite *UserStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) +  	suite.testTokens = testrig.NewTestTokens()  	suite.testClients = testrig.NewTestClients()  	suite.testApplications = testrig.NewTestApplications()  	suite.testUsers = testrig.NewTestUsers()  	suite.testAccounts = testrig.NewTestAccounts() -	suite.db = testrig.NewTestDB() + +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.storage = testrig.NewInMemoryStorage() +	suite.state.Storage = suite.storage +  	suite.tc = testrig.NewTestTypeConverter(suite.db) -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)  	suite.sentEmails = make(map[string]string)  	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)  	suite.userModule = user.New(suite.processor)  	testrig.StandardDBSetup(suite.db, suite.testAccounts)  	testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -83,4 +89,5 @@ func (suite *UserStandardTestSuite) SetupTest() {  func (suite *UserStandardTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  } diff --git a/internal/api/fileserver/fileserver_test.go b/internal/api/fileserver/fileserver_test.go index 0a6879e70..0e0dd9434 100644 --- a/internal/api/fileserver/fileserver_test.go +++ b/internal/api/fileserver/fileserver_test.go @@ -23,16 +23,15 @@ import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/api/fileserver" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/oauth"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -43,6 +42,7 @@ type FileserverTestSuite struct {  	suite.Suite  	db           db.DB  	storage      *storage.Driver +	state        state.State  	federator    federation.Federator  	tc           typeutils.TypeConverter  	processor    *processing.Processor @@ -67,26 +67,32 @@ type FileserverTestSuite struct {  */  func (suite *FileserverTestSuite) SetupSuite() { +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.storage = testrig.NewInMemoryStorage() -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) -	suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil) +	suite.state.Storage = suite.storage + +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")), suite.mediaManager) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, testrig.NewTestMediaManager(suite.db, suite.storage), clientWorker, fedWorker)  	suite.tc = testrig.NewTestTypeConverter(suite.db) -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state)  	suite.oauthServer = testrig.NewTestOauthServer(suite.db) +	suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil)  	suite.fileServer = fileserver.New(suite.processor)  }  func (suite *FileserverTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.StandardDBSetup(suite.db, nil)  	testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")  	suite.testTokens = testrig.NewTestTokens() @@ -101,9 +107,11 @@ func (suite *FileserverTestSuite) TearDownSuite() {  	if err := suite.db.Stop(context.Background()); err != nil {  		log.Panicf(nil, "error closing db connection: %s", err)  	} +	testrig.StopWorkers(&suite.state)  }  func (suite *FileserverTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  } diff --git a/internal/api/wellknown/webfinger/webfinger_test.go b/internal/api/wellknown/webfinger/webfinger_test.go index 38228e928..3148279c5 100644 --- a/internal/api/wellknown/webfinger/webfinger_test.go +++ b/internal/api/wellknown/webfinger/webfinger_test.go @@ -26,15 +26,14 @@ import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/ap"  	"github.com/superseriousbusiness/gotosocial/internal/api/wellknown/webfinger" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/oauth"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -44,6 +43,7 @@ type WebfingerStandardTestSuite struct {  	// standard suite interfaces  	suite.Suite  	db           db.DB +	state        state.State  	tc           typeutils.TypeConverter  	mediaManager media.Manager  	federator    federation.Federator @@ -76,19 +76,21 @@ func (suite *WebfingerStandardTestSuite) SetupSuite() {  }  func (suite *WebfingerStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestLog()  	testrig.InitTestConfig() -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.tc = testrig.NewTestTypeConverter(suite.db)  	suite.storage = testrig.NewInMemoryStorage() -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) +	suite.state.Storage = suite.storage +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)  	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) -	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) +	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)  	suite.webfingerModule = webfinger.New(suite.processor)  	suite.oauthServer = testrig.NewTestOauthServer(suite.db)  	testrig.StandardDBSetup(suite.db, suite.testAccounts) @@ -100,6 +102,7 @@ func (suite *WebfingerStandardTestSuite) SetupTest() {  func (suite *WebfingerStandardTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  }  func accountDomainAccount() *gtsmodel.Account { diff --git a/internal/api/wellknown/webfinger/webfingerget_test.go b/internal/api/wellknown/webfinger/webfingerget_test.go index 7587dfee1..a345d0602 100644 --- a/internal/api/wellknown/webfinger/webfingerget_test.go +++ b/internal/api/wellknown/webfinger/webfingerget_test.go @@ -30,9 +30,7 @@ import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/api/wellknown/webfinger" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/config" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/processing"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -91,9 +89,7 @@ func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHo  	config.SetHost("gts.example.org")  	config.SetAccountDomain("example.org") -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker) +	suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(&suite.state), &suite.state, suite.emailSender)  	suite.webfingerModule = webfinger.New(suite.processor)  	targetAccount := accountDomainAccount() @@ -148,9 +144,7 @@ func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByAc  	config.SetHost("gts.example.org")  	config.SetAccountDomain("example.org") -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker) +	suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(&suite.state), &suite.state, suite.emailSender)  	suite.webfingerModule = webfinger.New(suite.processor)  	targetAccount := accountDomainAccount() diff --git a/internal/concurrency/workers.go b/internal/concurrency/workers.go deleted file mode 100644 index ed99509cf..000000000 --- a/internal/concurrency/workers.go +++ /dev/null @@ -1,141 +0,0 @@ -/* -   GoToSocial -   Copyright (C) 2021-2023 GoToSocial Authors admin@gotosocial.org - -   This program is free software: you can redistribute it and/or modify -   it under the terms of the GNU Affero General Public License as published by -   the Free Software Foundation, either version 3 of the License, or -   (at your option) any later version. - -   This program is distributed in the hope that it will be useful, -   but WITHOUT ANY WARRANTY; without even the implied warranty of -   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the -   GNU Affero General Public License for more details. - -   You should have received a copy of the GNU Affero General Public License -   along with this program.  If not, see <http://www.gnu.org/licenses/>. -*/ - -package concurrency - -import ( -	"context" -	"errors" -	"fmt" -	"path" -	"reflect" -	"runtime" - -	"codeberg.org/gruf/go-kv" -	"codeberg.org/gruf/go-runners" -	"github.com/superseriousbusiness/gotosocial/internal/log" -) - -// WorkerPool represents a proccessor for MsgType objects, using a worker pool to allocate resources. -type WorkerPool[MsgType any] struct { -	workers runners.WorkerPool -	process func(context.Context, MsgType) error -	nw, nq  int -	wtype   string // contains worker type for logging -} - -// New returns a new WorkerPool[MsgType] with given number of workers and queue ratio, -// where the queue ratio is multiplied by no. workers to get queue size. If args < 1 -// then suitable defaults are determined from the runtime's GOMAXPROCS variable. -func NewWorkerPool[MsgType any](workers int, queueRatio int) *WorkerPool[MsgType] { -	var zero MsgType - -	if workers < 1 { -		// ensure sensible workers -		workers = runtime.GOMAXPROCS(0) * 4 -	} -	if queueRatio < 1 { -		// ensure sensible ratio -		queueRatio = 100 -	} - -	// Calculate the short type string for the msg type -	msgType := reflect.TypeOf(zero).String() -	_, msgType = path.Split(msgType) - -	w := &WorkerPool[MsgType]{ -		process: nil, -		nw:      workers, -		nq:      workers * queueRatio, -		wtype:   fmt.Sprintf("worker.Worker[%s]", msgType), -	} - -	// Log new worker creation with worker type prefix -	log.Infof(nil, "%s created with workers=%d queue=%d", -		w.wtype, -		workers, -		workers*queueRatio, -	) - -	return w -} - -// Start will attempt to start the underlying worker pool, or return error. -func (w *WorkerPool[MsgType]) Start() error { -	log.Infof(nil, "%s starting", w.wtype) - -	// Check processor was set -	if w.process == nil { -		return errors.New("nil Worker.process function") -	} - -	// Attempt to start pool -	if !w.workers.Start(w.nw, w.nq) { -		return errors.New("failed to start Worker pool") -	} - -	return nil -} - -// Stop will attempt to stop the underlying worker pool, or return error. -func (w *WorkerPool[MsgType]) Stop() error { -	log.Infof(nil, "%s stopping", w.wtype) - -	// Attempt to stop pool -	if !w.workers.Stop() { -		return errors.New("failed to stop Worker pool") -	} - -	return nil -} - -// SetProcessor will set the Worker's processor function, which is called for each queued message. -func (w *WorkerPool[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) { -	if w.process != nil { -		log.Panicf(nil, "%s Worker.process is already set", w.wtype) -	} -	w.process = fn -} - -// Queue will queue provided message to be processed with there's a free worker. -func (w *WorkerPool[MsgType]) Queue(msg MsgType) { -	log.Tracef(nil, "%s queueing message: %+v", w.wtype, msg) - -	// Create new process function for msg -	process := func(ctx context.Context) { -		if err := w.process(ctx, msg); err != nil { -			log.WithContext(ctx). -				WithFields(kv.Fields{ -					kv.Field{K: "type", V: w.wtype}, -					kv.Field{K: "error", V: err}, -				}...).Error("message processing error") -		} -	} - -	// Attempt a fast-enqueue of process -	if !w.workers.EnqueueNow(process) { -		// No spot acquired, log warning -		log.WithFields(kv.Fields{ -			kv.Field{K: "type", V: w.wtype}, -			kv.Field{K: "queue", V: w.workers.Queue()}, -		}...).Warn("full worker queue") - -		// Block on enqueuing process func -		w.workers.Enqueue(process) -	} -} diff --git a/internal/db/bundb/admin_test.go b/internal/db/bundb/admin_test.go index b0da97ef1..ce255d036 100644 --- a/internal/db/bundb/admin_test.go +++ b/internal/db/bundb/admin_test.go @@ -70,8 +70,8 @@ func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() {  }  func (suite *AdminTestSuite) TestCreateInstanceAccount() { -	// reinitialize test DB to clear caches -	suite.db = testrig.NewTestDB() +	// reinitialize db caches to clear +	suite.state.Caches.Init()  	// we need to take an empty db for this...  	testrig.StandardDBTeardown(suite.db)  	// ...with tables created but no data diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go index e050c2b5d..bad8bfc72 100644 --- a/internal/db/bundb/bundb_test.go +++ b/internal/db/bundb/bundb_test.go @@ -22,13 +22,15 @@ import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/testrig"  )  type BunDBStandardTestSuite struct {  	// standard suite interfaces  	suite.Suite -	db db.DB +	db    db.DB +	state state.State  	// standard suite models  	testTokens       map[string]*gtsmodel.Token @@ -61,9 +63,10 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {  }  func (suite *BunDBStandardTestSuite) SetupTest() { +	suite.state.Caches.Init()  	testrig.InitTestConfig()  	testrig.InitTestLog() -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state)  	testrig.StandardDBSetup(suite.db, suite.testAccounts)  } diff --git a/internal/federation/dereferencing/dereferencer_test.go b/internal/federation/dereferencing/dereferencer_test.go index daca8b7de..f5b59b0ed 100644 --- a/internal/federation/dereferencing/dereferencer_test.go +++ b/internal/federation/dereferencing/dereferencer_test.go @@ -21,11 +21,10 @@ package dereferencing_test  import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/activity/streams/vocab" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -	"github.com/superseriousbusiness/gotosocial/internal/messages" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -34,6 +33,7 @@ type DereferencerStandardTestSuite struct {  	suite.Suite  	db      db.DB  	storage *storage.Driver +	state   state.State  	testRemoteStatuses    map[string]vocab.ActivityStreamsNote  	testRemotePeople      map[string]vocab.ActivityStreamsPerson @@ -58,12 +58,19 @@ func (suite *DereferencerStandardTestSuite) SetupTest() {  	suite.testRemoteAttachments = testrig.NewTestFediAttachments("../../../testrig/media")  	suite.testEmojis = testrig.NewTestEmojis() -	suite.db = testrig.NewTestDB() +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) + +	suite.db = testrig.NewTestDB(&suite.state)  	suite.storage = testrig.NewInMemoryStorage() -	suite.dereferencer = dereferencing.NewDereferencer(suite.db, testrig.NewTestTypeConverter(suite.db), testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, concurrency.NewWorkerPool[messages.FromFederator](-1, -1)), testrig.NewTestMediaManager(suite.db, suite.storage)) +	suite.state.DB = suite.db +	suite.state.Storage = suite.storage +	media := testrig.NewTestMediaManager(&suite.state) +	suite.dereferencer = dereferencing.NewDereferencer(suite.db, testrig.NewTestTypeConverter(suite.db), testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")), media)  	testrig.StandardDBSetup(suite.db, nil)  }  func (suite *DereferencerStandardTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db) +	testrig.StopWorkers(&suite.state)  } diff --git a/internal/federation/federatingactor_test.go b/internal/federation/federatingactor_test.go index 0d1d8e37f..f63ecd827 100644 --- a/internal/federation/federatingactor_test.go +++ b/internal/federation/federatingactor_test.go @@ -27,10 +27,8 @@ import (  	"time"  	"github.com/stretchr/testify/suite" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -56,14 +54,12 @@ func (suite *FederatingActorTestSuite) TestSendNoRemoteFollowers() {  	)  	testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -  	// setup transport controller with a no-op client so we don't make external calls  	httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") -	tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) +	tc := testrig.NewTestTransportController(&suite.state, httpClient)  	// setup module being tested -	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) +	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))  	activity, err := federator.FederatingActor().Send(ctx, testrig.URLMustParse(testAccount.OutboxURI), testActivity)  	suite.NoError(err) @@ -105,12 +101,10 @@ func (suite *FederatingActorTestSuite) TestSendRemoteFollower() {  	)  	testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), testrig.TimeMustParse("2022-06-02T12:22:21+02:00"), testNote) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -  	httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") -	tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) +	tc := testrig.NewTestTransportController(&suite.state, httpClient)  	// setup module being tested -	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) +	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))  	activity, err := federator.FederatingActor().Send(ctx, testrig.URLMustParse(testAccount.OutboxURI), testActivity)  	suite.NoError(err) diff --git a/internal/federation/federatingdb/accept.go b/internal/federation/federatingdb/accept.go index d3e227a10..184d2b09d 100644 --- a/internal/federation/federatingdb/accept.go +++ b/internal/federation/federatingdb/accept.go @@ -65,7 +65,7 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA  			if uris.IsFollowPath(acceptedObjectIRI) {  				// ACCEPT FOLLOW  				gtsFollowRequest := >smodel.FollowRequest{} -				if err := f.db.GetWhere(ctx, []db.Where{{Key: "uri", Value: acceptedObjectIRI.String()}}, gtsFollowRequest); err != nil { +				if err := f.state.DB.GetWhere(ctx, []db.Where{{Key: "uri", Value: acceptedObjectIRI.String()}}, gtsFollowRequest); err != nil {  					return fmt.Errorf("ACCEPT: couldn't get follow request with id %s from the database: %s", acceptedObjectIRI.String(), err)  				} @@ -73,12 +73,12 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA  				if gtsFollowRequest.AccountID != receivingAccount.ID {  					return errors.New("ACCEPT: follow object account and inbox account were not the same")  				} -				follow, err := f.db.AcceptFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID) +				follow, err := f.state.DB.AcceptFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID)  				if err != nil {  					return err  				} -				f.fedWorker.Queue(messages.FromFederator{ +				f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{  					APObjectType:     ap.ActivityFollow,  					APActivityType:   ap.ActivityAccept,  					GTSModel:         follow, @@ -108,12 +108,12 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA  			if gtsFollow.AccountID != receivingAccount.ID {  				return errors.New("ACCEPT: follow object account and inbox account were not the same")  			} -			follow, err := f.db.AcceptFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID) +			follow, err := f.state.DB.AcceptFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID)  			if err != nil {  				return err  			} -			f.fedWorker.Queue(messages.FromFederator{ +			f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{  				APObjectType:     ap.ActivityFollow,  				APActivityType:   ap.ActivityAccept,  				GTSModel:         follow, diff --git a/internal/federation/federatingdb/announce.go b/internal/federation/federatingdb/announce.go index f4d145148..552a95ba9 100644 --- a/internal/federation/federatingdb/announce.go +++ b/internal/federation/federatingdb/announce.go @@ -59,7 +59,7 @@ func (f *federatingDB) Announce(ctx context.Context, announce vocab.ActivityStre  	}  	// it's a new announce so pass it back to the processor async for dereferencing etc -	f.fedWorker.Queue(messages.FromFederator{ +	f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{  		APObjectType:     ap.ActivityAnnounce,  		APActivityType:   ap.ActivityCreate,  		GTSModel:         boost, diff --git a/internal/federation/federatingdb/announce_test.go b/internal/federation/federatingdb/announce_test.go index 6c0d969f4..d9158f383 100644 --- a/internal/federation/federatingdb/announce_test.go +++ b/internal/federation/federatingdb/announce_test.go @@ -25,6 +25,7 @@ import (  	"github.com/superseriousbusiness/activity/streams/vocab"  	"github.com/superseriousbusiness/gotosocial/internal/ap"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/id"  )  type AnnounceTestSuite struct { @@ -74,6 +75,13 @@ func (suite *AnnounceTestSuite) TestAnnounceTwice() {  	suite.True(ok)  	suite.Equal(announcingAccount.ID, boost.AccountID) +	// Insert the boost-of status into the +	// DB cache to emulate processor handling +	boost.ID, _ = id.NewULIDFromTime(boost.CreatedAt) +	suite.state.Caches.GTS.Status().Store(boost, func() error { +		return nil +	}) +  	// only the URI will be set on the boosted status because it still needs to be dereferenced  	suite.NotEmpty(boost.BoostOf.URI) diff --git a/internal/federation/federatingdb/create.go b/internal/federation/federatingdb/create.go index bf3e7f75d..ca87131fe 100644 --- a/internal/federation/federatingdb/create.go +++ b/internal/federation/federatingdb/create.go @@ -103,11 +103,11 @@ func (f *federatingDB) activityBlock(ctx context.Context, asType vocab.Type, rec  	block.ID = id.NewULID() -	if err := f.db.PutBlock(ctx, block); err != nil { +	if err := f.state.DB.PutBlock(ctx, block); err != nil {  		return fmt.Errorf("activityBlock: database error inserting block: %s", err)  	} -	f.fedWorker.Queue(messages.FromFederator{ +	f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{  		APObjectType:     ap.ActivityBlock,  		APActivityType:   ap.ActivityCreate,  		GTSModel:         block, @@ -202,7 +202,7 @@ func (f *federatingDB) createNote(ctx context.Context, note vocab.ActivityStream  			return nil  		}  		// pass the note iri into the processor and have it do the dereferencing instead of doing it here -		f.fedWorker.Queue(messages.FromFederator{ +		f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{  			APObjectType:     ap.ObjectNote,  			APActivityType:   ap.ActivityCreate,  			APIri:            id.GetIRI(), @@ -226,7 +226,7 @@ func (f *federatingDB) createNote(ctx context.Context, note vocab.ActivityStream  	}  	status.ID = statusID -	if err := f.db.PutStatus(ctx, status); err != nil { +	if err := f.state.DB.PutStatus(ctx, status); err != nil {  		if errors.Is(err, db.ErrAlreadyExists) {  			// the status already exists in the database, which means we've already handled everything else,  			// so we can just return nil here and be done with it. @@ -236,7 +236,7 @@ func (f *federatingDB) createNote(ctx context.Context, note vocab.ActivityStream  		return fmt.Errorf("createNote: database error inserting status: %s", err)  	} -	f.fedWorker.Queue(messages.FromFederator{ +	f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{  		APObjectType:     ap.ObjectNote,  		APActivityType:   ap.ActivityCreate,  		GTSModel:         status, @@ -263,11 +263,11 @@ func (f *federatingDB) activityFollow(ctx context.Context, asType vocab.Type, re  	followRequest.ID = id.NewULID() -	if err := f.db.Put(ctx, followRequest); err != nil { +	if err := f.state.DB.Put(ctx, followRequest); err != nil {  		return fmt.Errorf("activityFollow: database error inserting follow request: %s", err)  	} -	f.fedWorker.Queue(messages.FromFederator{ +	f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{  		APObjectType:     ap.ActivityFollow,  		APActivityType:   ap.ActivityCreate,  		GTSModel:         followRequest, @@ -294,11 +294,11 @@ func (f *federatingDB) activityLike(ctx context.Context, asType vocab.Type, rece  	fave.ID = id.NewULID() -	if err := f.db.Put(ctx, fave); err != nil { +	if err := f.state.DB.Put(ctx, fave); err != nil {  		return fmt.Errorf("activityLike: database error inserting fave: %s", err)  	} -	f.fedWorker.Queue(messages.FromFederator{ +	f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{  		APObjectType:     ap.ActivityLike,  		APActivityType:   ap.ActivityCreate,  		GTSModel:         fave, @@ -325,11 +325,11 @@ func (f *federatingDB) activityFlag(ctx context.Context, asType vocab.Type, rece  	report.ID = id.NewULID() -	if err := f.db.PutReport(ctx, report); err != nil { +	if err := f.state.DB.PutReport(ctx, report); err != nil {  		return fmt.Errorf("activityFlag: database error inserting report: %w", err)  	} -	f.fedWorker.Queue(messages.FromFederator{ +	f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{  		APObjectType:     ap.ActivityFlag,  		APActivityType:   ap.ActivityCreate,  		GTSModel:         report, diff --git a/internal/federation/federatingdb/db.go b/internal/federation/federatingdb/db.go index 24455a553..af4aceeeb 100644 --- a/internal/federation/federatingdb/db.go +++ b/internal/federation/federatingdb/db.go @@ -24,9 +24,7 @@ import (  	"codeberg.org/gruf/go-mutexes"  	"github.com/superseriousbusiness/activity/pub"  	"github.com/superseriousbusiness/activity/streams/vocab" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency" -	"github.com/superseriousbusiness/gotosocial/internal/db" -	"github.com/superseriousbusiness/gotosocial/internal/messages" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  ) @@ -43,17 +41,15 @@ type DB interface {  // It doesn't care what the underlying implementation of the DB interface is, as long as it works.  type federatingDB struct {  	locks         mutexes.MutexMap -	db            db.DB -	fedWorker     *concurrency.WorkerPool[messages.FromFederator] +	state         *state.State  	typeConverter typeutils.TypeConverter  }  // New returns a DB interface using the given database and config -func New(db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator], tc typeutils.TypeConverter) DB { +func New(state *state.State, tc typeutils.TypeConverter) DB {  	fdb := federatingDB{  		locks:         mutexes.NewMap(-1, -1), // use defaults -		db:            db, -		fedWorker:     fedWorker, +		state:         state,  		typeConverter: tc,  	}  	return &fdb diff --git a/internal/federation/federatingdb/delete.go b/internal/federation/federatingdb/delete.go index a1890b56b..695f199b4 100644 --- a/internal/federation/federatingdb/delete.go +++ b/internal/federation/federatingdb/delete.go @@ -51,9 +51,9 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {  	// in a delete we only get the URI, we can't know if we have a status or a profile or something else,  	// so we have to try a few different things... -	if s, err := f.db.GetStatusByURI(ctx, id.String()); err == nil && requestingAccount.ID == s.AccountID { +	if s, err := f.state.DB.GetStatusByURI(ctx, id.String()); err == nil && requestingAccount.ID == s.AccountID {  		l.Debugf("uri is for STATUS with id: %s", s.ID) -		f.fedWorker.Queue(messages.FromFederator{ +		f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{  			APObjectType:     ap.ObjectNote,  			APActivityType:   ap.ActivityDelete,  			GTSModel:         s, @@ -61,9 +61,9 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {  		})  	} -	if a, err := f.db.GetAccountByURI(ctx, id.String()); err == nil && requestingAccount.ID == a.ID { +	if a, err := f.state.DB.GetAccountByURI(ctx, id.String()); err == nil && requestingAccount.ID == a.ID {  		l.Debugf("uri is for ACCOUNT with id %s", a.ID) -		f.fedWorker.Queue(messages.FromFederator{ +		f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{  			APObjectType:     ap.ObjectProfile,  			APActivityType:   ap.ActivityDelete,  			GTSModel:         a, diff --git a/internal/federation/federatingdb/federatingdb_test.go b/internal/federation/federatingdb/federatingdb_test.go index dd5a5f5f9..b0893f246 100644 --- a/internal/federation/federatingdb/federatingdb_test.go +++ b/internal/federation/federatingdb/federatingdb_test.go @@ -23,11 +23,11 @@ import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/ap" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/messages" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -36,9 +36,9 @@ type FederatingDBTestSuite struct {  	suite.Suite  	db            db.DB  	tc            typeutils.TypeConverter -	fedWorker     *concurrency.WorkerPool[messages.FromFederator]  	fromFederator chan messages.FromFederator  	federatingDB  federatingdb.DB +	state         state.State  	testTokens       map[string]*gtsmodel.Token  	testClients      map[string]*gtsmodel.Client @@ -66,22 +66,33 @@ func (suite *FederatingDBTestSuite) SetupTest() {  	testrig.InitTestConfig()  	testrig.InitTestLog() -	suite.fedWorker = concurrency.NewWorkerPool[messages.FromFederator](-1, -1) +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	suite.fromFederator = make(chan messages.FromFederator, 10) -	suite.fedWorker.SetProcessor(func(ctx context.Context, msg messages.FromFederator) error { +	suite.state.Workers.EnqueueFederator = func(ctx context.Context, msg messages.FromFederator) {  		suite.fromFederator <- msg -		return nil -	}) -	_ = suite.fedWorker.Start() -	suite.db = testrig.NewTestDB() +	} + +	suite.db = testrig.NewTestDB(&suite.state)  	suite.testActivities = testrig.NewTestActivities(suite.testAccounts)  	suite.tc = testrig.NewTestTypeConverter(suite.db) -	suite.federatingDB = testrig.NewTestFederatingDB(suite.db, suite.fedWorker) +	suite.federatingDB = testrig.NewTestFederatingDB(&suite.state)  	testrig.StandardDBSetup(suite.db, suite.testAccounts) + +	suite.state.DB = suite.db  }  func (suite *FederatingDBTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db) +	testrig.StopWorkers(&suite.state) +	for suite.fromFederator != nil { +		select { +		case <-suite.fromFederator: +		default: +			return +		} +	}  }  func createTestContext(receivingAccount *gtsmodel.Account, requestingAccount *gtsmodel.Account) context.Context { diff --git a/internal/federation/federatingdb/followers.go b/internal/federation/federatingdb/followers.go index c47a2b625..69746c99b 100644 --- a/internal/federation/federatingdb/followers.go +++ b/internal/federation/federatingdb/followers.go @@ -29,7 +29,7 @@ func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (follow  		return nil, err  	} -	acctFollowers, err := f.db.GetAccountFollowedBy(ctx, acct.ID, false) +	acctFollowers, err := f.state.DB.GetAccountFollowedBy(ctx, acct.ID, false)  	if err != nil {  		return nil, fmt.Errorf("Followers: db error getting followers for account id %s: %s", acct.ID, err)  	} @@ -37,7 +37,7 @@ func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (follow  	iris := []*url.URL{}  	for _, follow := range acctFollowers {  		if follow.Account == nil { -			a, err := f.db.GetAccountByID(ctx, follow.AccountID) +			a, err := f.state.DB.GetAccountByID(ctx, follow.AccountID)  			if err != nil {  				errWrapped := fmt.Errorf("Followers: db error getting account id %s: %s", follow.AccountID, err)  				if err == db.ErrNoEntries { diff --git a/internal/federation/federatingdb/following.go b/internal/federation/federatingdb/following.go index f4f07bb25..9c22c0574 100644 --- a/internal/federation/federatingdb/following.go +++ b/internal/federation/federatingdb/following.go @@ -47,7 +47,7 @@ func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (follow  		return nil, err  	} -	acctFollowing, err := f.db.GetAccountFollows(ctx, acct.ID) +	acctFollowing, err := f.state.DB.GetAccountFollows(ctx, acct.ID)  	if err != nil {  		return nil, fmt.Errorf("Following: db error getting following for account id %s: %s", acct.ID, err)  	} @@ -55,7 +55,7 @@ func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (follow  	iris := []*url.URL{}  	for _, follow := range acctFollowing {  		if follow.TargetAccount == nil { -			a, err := f.db.GetAccountByID(ctx, follow.TargetAccountID) +			a, err := f.state.DB.GetAccountByID(ctx, follow.TargetAccountID)  			if err != nil {  				errWrapped := fmt.Errorf("Following: db error getting account id %s: %s", follow.TargetAccountID, err)  				if err == db.ErrNoEntries { diff --git a/internal/federation/federatingdb/get.go b/internal/federation/federatingdb/get.go index 92a79d70f..1d687f110 100644 --- a/internal/federation/federatingdb/get.go +++ b/internal/federation/federatingdb/get.go @@ -39,13 +39,13 @@ func (f *federatingDB) Get(ctx context.Context, id *url.URL) (value vocab.Type,  	switch {  	case uris.IsUserPath(id): -		acct, err := f.db.GetAccountByURI(ctx, id.String()) +		acct, err := f.state.DB.GetAccountByURI(ctx, id.String())  		if err != nil {  			return nil, err  		}  		return f.typeConverter.AccountToAS(ctx, acct)  	case uris.IsStatusesPath(id): -		status, err := f.db.GetStatusByURI(ctx, id.String()) +		status, err := f.state.DB.GetStatusByURI(ctx, id.String())  		if err != nil {  			return nil, err  		} diff --git a/internal/federation/federatingdb/inbox.go b/internal/federation/federatingdb/inbox.go index 5ec735bd4..1a6da4ef0 100644 --- a/internal/federation/federatingdb/inbox.go +++ b/internal/federation/federatingdb/inbox.go @@ -85,12 +85,12 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs  			return nil, fmt.Errorf("couldn't extract local account username from uri %s: %s", iri, err)  		} -		account, err := f.db.GetAccountByUsernameDomain(c, localAccountUsername, "") +		account, err := f.state.DB.GetAccountByUsernameDomain(c, localAccountUsername, "")  		if err != nil {  			return nil, fmt.Errorf("couldn't find local account with username %s: %s", localAccountUsername, err)  		} -		follows, err := f.db.GetAccountFollowedBy(c, account.ID, false) +		follows, err := f.state.DB.GetAccountFollowedBy(c, account.ID, false)  		if err != nil {  			return nil, fmt.Errorf("couldn't get followers of local account %s: %s", localAccountUsername, err)  		} @@ -98,7 +98,7 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs  		for _, follow := range follows {  			// make sure we retrieved the following account from the db  			if follow.Account == nil { -				followingAccount, err := f.db.GetAccountByID(c, follow.AccountID) +				followingAccount, err := f.state.DB.GetAccountByID(c, follow.AccountID)  				if err != nil {  					if err == db.ErrNoEntries {  						continue @@ -126,7 +126,7 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs  	}  	// check if this is just an account IRI... -	if account, err := f.db.GetAccountByURI(c, iri.String()); err == nil { +	if account, err := f.state.DB.GetAccountByURI(c, iri.String()); err == nil {  		// deliver to a shared inbox if we have that option  		var inbox string  		if config.GetInstanceDeliverToSharedInboxes() && account.SharedInboxURI != nil && *account.SharedInboxURI != "" { diff --git a/internal/federation/federatingdb/owns.go b/internal/federation/federatingdb/owns.go index def0fa518..2c11e8148 100644 --- a/internal/federation/federatingdb/owns.go +++ b/internal/federation/federatingdb/owns.go @@ -54,7 +54,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {  		if err != nil {  			return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)  		} -		status, err := f.db.GetStatusByURI(ctx, uid) +		status, err := f.state.DB.GetStatusByURI(ctx, uid)  		if err != nil {  			if err == db.ErrNoEntries {  				// there are no entries for this status @@ -71,7 +71,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {  		if err != nil {  			return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)  		} -		if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil { +		if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil {  			if err == db.ErrNoEntries {  				// there are no entries for this username  				return false, nil @@ -88,7 +88,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {  		if err != nil {  			return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)  		} -		if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil { +		if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil {  			if err == db.ErrNoEntries {  				// there are no entries for this username  				return false, nil @@ -105,7 +105,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {  		if err != nil {  			return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)  		} -		if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil { +		if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil {  			if err == db.ErrNoEntries {  				// there are no entries for this username  				return false, nil @@ -122,7 +122,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {  		if err != nil {  			return false, fmt.Errorf("error parsing like path for url %s: %s", id.String(), err)  		} -		if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil { +		if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil {  			if err == db.ErrNoEntries {  				// there are no entries for this username  				return false, nil @@ -130,7 +130,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {  			// an actual error happened  			return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)  		} -		if err := f.db.GetByID(ctx, likeID, >smodel.StatusFave{}); err != nil { +		if err := f.state.DB.GetByID(ctx, likeID, >smodel.StatusFave{}); err != nil {  			if err == db.ErrNoEntries {  				// there are no entries  				return false, nil @@ -147,7 +147,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {  		if err != nil {  			return false, fmt.Errorf("error parsing block path for url %s: %s", id.String(), err)  		} -		if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil { +		if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil {  			if err == db.ErrNoEntries {  				// there are no entries for this username  				return false, nil @@ -155,7 +155,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {  			// an actual error happened  			return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)  		} -		if err := f.db.GetByID(ctx, blockID, >smodel.Block{}); err != nil { +		if err := f.state.DB.GetByID(ctx, blockID, >smodel.Block{}); err != nil {  			if err == db.ErrNoEntries {  				// there are no entries  				return false, nil diff --git a/internal/federation/federatingdb/reject.go b/internal/federation/federatingdb/reject.go index 3c3cd7c75..d443cd6cb 100644 --- a/internal/federation/federatingdb/reject.go +++ b/internal/federation/federatingdb/reject.go @@ -64,7 +64,7 @@ func (f *federatingDB) Reject(ctx context.Context, reject vocab.ActivityStreamsR  			if uris.IsFollowPath(rejectedObjectIRI) {  				// REJECT FOLLOW  				gtsFollowRequest := >smodel.FollowRequest{} -				if err := f.db.GetWhere(ctx, []db.Where{{Key: "uri", Value: rejectedObjectIRI.String()}}, gtsFollowRequest); err != nil { +				if err := f.state.DB.GetWhere(ctx, []db.Where{{Key: "uri", Value: rejectedObjectIRI.String()}}, gtsFollowRequest); err != nil {  					return fmt.Errorf("Reject: couldn't get follow request with id %s from the database: %s", rejectedObjectIRI.String(), err)  				} @@ -73,7 +73,7 @@ func (f *federatingDB) Reject(ctx context.Context, reject vocab.ActivityStreamsR  					return errors.New("Reject: follow object account and inbox account were not the same")  				} -				if _, err := f.db.RejectFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID); err != nil { +				if _, err := f.state.DB.RejectFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID); err != nil {  					return err  				} @@ -102,7 +102,7 @@ func (f *federatingDB) Reject(ctx context.Context, reject vocab.ActivityStreamsR  			if gtsFollow.AccountID != receivingAccount.ID {  				return errors.New("Reject: follow object account and inbox account were not the same")  			} -			if _, err := f.db.RejectFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID); err != nil { +			if _, err := f.state.DB.RejectFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID); err != nil {  				return err  			} diff --git a/internal/federation/federatingdb/undo.go b/internal/federation/federatingdb/undo.go index b239aabb4..e33b365fa 100644 --- a/internal/federation/federatingdb/undo.go +++ b/internal/federation/federatingdb/undo.go @@ -81,11 +81,11 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo)  				return errors.New("UNDO: follow object account and inbox account were not the same")  			}  			// delete any existing FOLLOW -			if err := f.db.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.Follow{}); err != nil { +			if err := f.state.DB.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.Follow{}); err != nil {  				return fmt.Errorf("UNDO: db error removing follow: %s", err)  			}  			// delete any existing FOLLOW REQUEST -			if err := f.db.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.FollowRequest{}); err != nil { +			if err := f.state.DB.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.FollowRequest{}); err != nil {  				return fmt.Errorf("UNDO: db error removing follow request: %s", err)  			}  			l.Debug("follow undone") @@ -114,7 +114,7 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo)  				return errors.New("UNDO: block object account and inbox account were not the same")  			}  			// delete any existing BLOCK -			if err := f.db.DeleteBlockByURI(ctx, gtsBlock.URI); err != nil { +			if err := f.state.DB.DeleteBlockByURI(ctx, gtsBlock.URI); err != nil {  				return fmt.Errorf("UNDO: db error removing block: %s", err)  			}  			l.Debug("block undone") diff --git a/internal/federation/federatingdb/update.go b/internal/federation/federatingdb/update.go index 570729a31..bed5de4db 100644 --- a/internal/federation/federatingdb/update.go +++ b/internal/federation/federatingdb/update.go @@ -138,7 +138,7 @@ func (f *federatingDB) Update(ctx context.Context, asType vocab.Type) error {  		// pass to the processor for further updating of eg., avatar/header, emojis  		// the actual db insert/update will take place a bit later -		f.fedWorker.Queue(messages.FromFederator{ +		f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{  			APObjectType:     ap.ObjectProfile,  			APActivityType:   ap.ActivityUpdate,  			GTSModel:         updatedAcct, diff --git a/internal/federation/federatingdb/util.go b/internal/federation/federatingdb/util.go index 64f32d39c..f63eb6dc9 100644 --- a/internal/federation/federatingdb/util.go +++ b/internal/federation/federatingdb/util.go @@ -95,7 +95,7 @@ func (f *federatingDB) NewID(ctx context.Context, t vocab.Type) (idURL *url.URL,  				// take the IRI of the first actor we can find (there should only be one)  				if iter.IsIRI() {  					// if there's an error here, just use the fallback behavior -- we don't need to return an error here -					if actorAccount, err := f.db.GetAccountByURI(ctx, iter.GetIRI().String()); err == nil { +					if actorAccount, err := f.state.DB.GetAccountByURI(ctx, iter.GetIRI().String()); err == nil {  						newID, err := id.NewRandomULID()  						if err != nil {  							return nil, err @@ -238,7 +238,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts  	switch {  	case uris.IsUserPath(iri): -		if acct, err = f.db.GetAccountByURI(ctx, iri.String()); err != nil { +		if acct, err = f.state.DB.GetAccountByURI(ctx, iri.String()); err != nil {  			if err == db.ErrNoEntries {  				return nil, fmt.Errorf("no actor found that corresponds to uri %s", iri.String())  			} @@ -246,7 +246,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts  		}  		return acct, nil  	case uris.IsInboxPath(iri): -		if err = f.db.GetWhere(ctx, []db.Where{{Key: "inbox_uri", Value: iri.String()}}, acct); err != nil { +		if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "inbox_uri", Value: iri.String()}}, acct); err != nil {  			if err == db.ErrNoEntries {  				return nil, fmt.Errorf("no actor found that corresponds to inbox %s", iri.String())  			} @@ -254,7 +254,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts  		}  		return acct, nil  	case uris.IsOutboxPath(iri): -		if err = f.db.GetWhere(ctx, []db.Where{{Key: "outbox_uri", Value: iri.String()}}, acct); err != nil { +		if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "outbox_uri", Value: iri.String()}}, acct); err != nil {  			if err == db.ErrNoEntries {  				return nil, fmt.Errorf("no actor found that corresponds to outbox %s", iri.String())  			} @@ -262,7 +262,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts  		}  		return acct, nil  	case uris.IsFollowersPath(iri): -		if err = f.db.GetWhere(ctx, []db.Where{{Key: "followers_uri", Value: iri.String()}}, acct); err != nil { +		if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "followers_uri", Value: iri.String()}}, acct); err != nil {  			if err == db.ErrNoEntries {  				return nil, fmt.Errorf("no actor found that corresponds to followers_uri %s", iri.String())  			} @@ -270,7 +270,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts  		}  		return acct, nil  	case uris.IsFollowingPath(iri): -		if err = f.db.GetWhere(ctx, []db.Where{{Key: "following_uri", Value: iri.String()}}, acct); err != nil { +		if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "following_uri", Value: iri.String()}}, acct); err != nil {  			if err == db.ErrNoEntries {  				return nil, fmt.Errorf("no actor found that corresponds to following_uri %s", iri.String())  			} diff --git a/internal/federation/federatingprotocol_test.go b/internal/federation/federatingprotocol_test.go index faa168a71..e66cd78cb 100644 --- a/internal/federation/federatingprotocol_test.go +++ b/internal/federation/federatingprotocol_test.go @@ -28,10 +28,8 @@ import (  	"github.com/go-fed/httpsig"  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/ap" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -43,12 +41,10 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook1() {  	// the activity we're gonna use  	activity := suite.testActivities["dm_for_zork"] -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -  	httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") -	tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) +	tc := testrig.NewTestTransportController(&suite.state, httpClient)  	// setup module being tested -	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) +	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))  	// setup request  	ctx := context.Background() @@ -74,13 +70,11 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook2() {  	// the activity we're gonna use  	activity := suite.testActivities["reply_to_turtle_for_zork"] -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -  	httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") -	tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) +	tc := testrig.NewTestTransportController(&suite.state, httpClient)  	// setup module being tested -	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) +	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))  	// setup request  	ctx := context.Background() @@ -107,13 +101,11 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook3() {  	// the activity we're gonna use  	activity := suite.testActivities["reply_to_turtle_for_turtle"] -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -  	httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") -	tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) +	tc := testrig.NewTestTransportController(&suite.state, httpClient)  	// setup module being tested -	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) +	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))  	// setup request  	ctx := context.Background() @@ -142,13 +134,11 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostInbox() {  	sendingAccount := suite.testAccounts["remote_account_1"]  	inboxAccount := suite.testAccounts["local_account_1"] -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -  	httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") -	tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) +	tc := testrig.NewTestTransportController(&suite.state, httpClient)  	// now setup module being tested, with the mock transport controller -	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) +	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))  	request := httptest.NewRequest(http.MethodPost, "http://localhost:8080/users/the_mighty_zork/inbox", nil)  	// we need these headers for the request to be validated @@ -187,13 +177,11 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostGone() {  	activity := suite.testActivities["delete_https://somewhere.mysterious/users/rest_in_piss#main-key"]  	inboxAccount := suite.testAccounts["local_account_1"] -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -  	httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") -	tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) +	tc := testrig.NewTestTransportController(&suite.state, httpClient)  	// now setup module being tested, with the mock transport controller -	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) +	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))  	request := httptest.NewRequest(http.MethodPost, "http://localhost:8080/users/the_mighty_zork/inbox", nil)  	// we need these headers for the request to be validated @@ -231,13 +219,11 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostGoneNoTombstoneYet  	activity := suite.testActivities["delete_https://somewhere.mysterious/users/rest_in_piss#main-key"]  	inboxAccount := suite.testAccounts["local_account_1"] -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -  	httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") -	tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) +	tc := testrig.NewTestTransportController(&suite.state, httpClient)  	// now setup module being tested, with the mock transport controller -	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) +	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))  	request := httptest.NewRequest(http.MethodPost, "http://localhost:8080/users/the_mighty_zork/inbox", nil)  	// we need these headers for the request to be validated @@ -271,10 +257,9 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostGoneNoTombstoneYet  }  func (suite *FederatingProtocolTestSuite) TestBlocked1() { -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)  	httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") -	tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) -	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) +	tc := testrig.NewTestTransportController(&suite.state, httpClient) +	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))  	sendingAccount := suite.testAccounts["remote_account_1"]  	inboxAccount := suite.testAccounts["local_account_1"] @@ -294,10 +279,9 @@ func (suite *FederatingProtocolTestSuite) TestBlocked1() {  }  func (suite *FederatingProtocolTestSuite) TestBlocked2() { -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)  	httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") -	tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) -	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) +	tc := testrig.NewTestTransportController(&suite.state, httpClient) +	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))  	sendingAccount := suite.testAccounts["remote_account_1"]  	inboxAccount := suite.testAccounts["local_account_1"] @@ -328,10 +312,9 @@ func (suite *FederatingProtocolTestSuite) TestBlocked2() {  }  func (suite *FederatingProtocolTestSuite) TestBlocked3() { -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)  	httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") -	tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) -	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) +	tc := testrig.NewTestTransportController(&suite.state, httpClient) +	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))  	sendingAccount := suite.testAccounts["remote_account_1"]  	inboxAccount := suite.testAccounts["local_account_1"] @@ -365,10 +348,9 @@ func (suite *FederatingProtocolTestSuite) TestBlocked3() {  }  func (suite *FederatingProtocolTestSuite) TestBlocked4() { -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)  	httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") -	tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) -	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) +	tc := testrig.NewTestTransportController(&suite.state, httpClient) +	federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))  	sendingAccount := suite.testAccounts["remote_account_1"]  	inboxAccount := suite.testAccounts["local_account_1"] diff --git a/internal/federation/federator_test.go b/internal/federation/federator_test.go index da6038ace..8a045aa1f 100644 --- a/internal/federation/federator_test.go +++ b/internal/federation/federator_test.go @@ -23,6 +23,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -32,6 +33,7 @@ type FederatorStandardTestSuite struct {  	suite.Suite  	db             db.DB  	storage        *storage.Driver +	state          state.State  	tc             typeutils.TypeConverter  	testAccounts   map[string]*gtsmodel.Account  	testStatuses   map[string]*gtsmodel.Status @@ -42,8 +44,9 @@ type FederatorStandardTestSuite struct {  // SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout  func (suite *FederatorStandardTestSuite) SetupSuite() {  	// setup standard items +	testrig.StartWorkers(&suite.state)  	suite.storage = testrig.NewInMemoryStorage() -	suite.tc = testrig.NewTestTypeConverter(suite.db) +	suite.state.Storage = suite.storage  	suite.testAccounts = testrig.NewTestAccounts()  	suite.testStatuses = testrig.NewTestStatuses()  	suite.testTombstones = testrig.NewTestTombstones() @@ -52,7 +55,10 @@ func (suite *FederatorStandardTestSuite) SetupSuite() {  func (suite *FederatorStandardTestSuite) SetupTest() {  	testrig.InitTestConfig()  	testrig.InitTestLog() -	suite.db = testrig.NewTestDB() +	suite.state.Caches.Init() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.tc = testrig.NewTestTypeConverter(suite.db) +	suite.state.DB = suite.db  	suite.testActivities = testrig.NewTestActivities(suite.testAccounts)  	testrig.StandardDBSetup(suite.db, suite.testAccounts)  } diff --git a/internal/media/media_test.go b/internal/media/media_test.go index d9f01c1ff..393126ac7 100644 --- a/internal/media/media_test.go +++ b/internal/media/media_test.go @@ -20,11 +20,10 @@ package media_test  import (  	"github.com/stretchr/testify/suite" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/transport"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -35,6 +34,7 @@ type MediaStandardTestSuite struct {  	db                  db.DB  	storage             *storage.Driver +	state               state.State  	manager             media.Manager  	transportController transport.Controller  	testAttachments     map[string]*gtsmodel.MediaAttachment @@ -46,21 +46,27 @@ func (suite *MediaStandardTestSuite) SetupSuite() {  	testrig.InitTestConfig()  	testrig.InitTestLog() -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state)  	suite.storage = testrig.NewInMemoryStorage() +	suite.state.DB = suite.db +	suite.state.Storage = suite.storage  }  func (suite *MediaStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.StandardStorageSetup(suite.storage, "../../testrig/media")  	testrig.StandardDBSetup(suite.db, nil)  	suite.testAttachments = testrig.NewTestAttachments()  	suite.testAccounts = testrig.NewTestAccounts()  	suite.testEmojis = testrig.NewTestEmojis() -	suite.manager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.transportController = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../testrig/media"), suite.db, concurrency.NewWorkerPool[messages.FromFederator](0, 0)) +	suite.manager = testrig.NewTestMediaManager(&suite.state) +	suite.transportController = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../testrig/media"))  }  func (suite *MediaStandardTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  } diff --git a/internal/oauth/clientstore_test.go b/internal/oauth/clientstore_test.go index 92c117bb3..a243383da 100644 --- a/internal/oauth/clientstore_test.go +++ b/internal/oauth/clientstore_test.go @@ -25,6 +25,7 @@ import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/oauth" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/testrig"  	"github.com/superseriousbusiness/oauth2/v4/models"  ) @@ -32,6 +33,7 @@ import (  type PgClientStoreTestSuite struct {  	suite.Suite  	db               db.DB +	state            state.State  	testClientID     string  	testClientSecret string  	testClientDomain string @@ -48,9 +50,11 @@ func (suite *PgClientStoreTestSuite) SetupSuite() {  // SetupTest creates a postgres connection and creates the oauth_clients table before each test  func (suite *PgClientStoreTestSuite) SetupTest() { +	suite.state.Caches.Init()  	testrig.InitTestLog()  	testrig.InitTestConfig() -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	testrig.StandardDBSetup(suite.db, nil)  } diff --git a/internal/processing/account/account.go b/internal/processing/account/account.go index 41315d483..62330c0dc 100644 --- a/internal/processing/account/account.go +++ b/internal/processing/account/account.go @@ -19,13 +19,11 @@  package account  import ( -	"github.com/superseriousbusiness/gotosocial/internal/concurrency" -	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/oauth" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/text"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/internal/visibility" @@ -35,35 +33,32 @@ import (  //  // It also contains logic for actions towards accounts such as following, blocking, seeing follows, etc.  type Processor struct { +	state        *state.State  	tc           typeutils.TypeConverter  	mediaManager media.Manager -	clientWorker *concurrency.WorkerPool[messages.FromClientAPI]  	oauthServer  oauth.Server  	filter       visibility.Filter  	formatter    text.Formatter -	db           db.DB  	federator    federation.Federator  	parseMention gtsmodel.ParseMentionFunc  }  // New returns a new account processor.  func New( -	db db.DB, +	state *state.State,  	tc typeutils.TypeConverter,  	mediaManager media.Manager,  	oauthServer oauth.Server, -	clientWorker *concurrency.WorkerPool[messages.FromClientAPI],  	federator federation.Federator,  	parseMention gtsmodel.ParseMentionFunc,  ) Processor {  	return Processor{ +		state:        state,  		tc:           tc,  		mediaManager: mediaManager, -		clientWorker: clientWorker,  		oauthServer:  oauthServer, -		filter:       visibility.NewFilter(db), -		formatter:    text.NewFormatter(db), -		db:           db, +		filter:       visibility.NewFilter(state.DB), +		formatter:    text.NewFormatter(state.DB),  		federator:    federator,  		parseMention: parseMention,  	} diff --git a/internal/processing/account/account_test.go b/internal/processing/account/account_test.go index 2e7cdb994..7a2e5aa8d 100644 --- a/internal/processing/account/account_test.go +++ b/internal/processing/account/account_test.go @@ -22,7 +22,6 @@ import (  	"context"  	"github.com/stretchr/testify/suite" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation" @@ -32,6 +31,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/oauth"  	"github.com/superseriousbusiness/gotosocial/internal/processing"  	"github.com/superseriousbusiness/gotosocial/internal/processing/account" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/transport"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils" @@ -44,6 +44,7 @@ type AccountStandardTestSuite struct {  	db                  db.DB  	tc                  typeutils.TypeConverter  	storage             *storage.Driver +	state               state.State  	mediaManager        media.Manager  	oauthServer         oauth.Server  	fromClientAPIChan   chan messages.FromClientAPI @@ -76,30 +77,30 @@ func (suite *AccountStandardTestSuite) SetupSuite() {  }  func (suite *AccountStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestLog()  	testrig.InitTestConfig() -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	clientWorker.SetProcessor(func(_ context.Context, msg messages.FromClientAPI) error { -		suite.fromClientAPIChan <- msg -		return nil -	}) - -	_ = fedWorker.Start() -	_ = clientWorker.Start() - -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.tc = testrig.NewTestTypeConverter(suite.db)  	suite.storage = testrig.NewInMemoryStorage() -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) +	suite.state.Storage = suite.storage +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state)  	suite.oauthServer = testrig.NewTestOauthServer(suite.db) +  	suite.fromClientAPIChan = make(chan messages.FromClientAPI, 100) -	suite.transportController = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker) -	suite.federator = testrig.NewTestFederator(suite.db, suite.transportController, suite.storage, suite.mediaManager, fedWorker) +	suite.state.Workers.EnqueueClientAPI = func(ctx context.Context, msg messages.FromClientAPI) { +		suite.fromClientAPIChan <- msg +	} + +	suite.transportController = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")) +	suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager)  	suite.sentEmails = make(map[string]string)  	suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails) -	suite.accountProcessor = account.New(suite.db, suite.tc, suite.mediaManager, suite.oauthServer, clientWorker, suite.federator, processing.GetParseMentionFunc(suite.db, suite.federator)) +	suite.accountProcessor = account.New(&suite.state, suite.tc, suite.mediaManager, suite.oauthServer, suite.federator, processing.GetParseMentionFunc(suite.db, suite.federator))  	testrig.StandardDBSetup(suite.db, nil)  	testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")  } @@ -107,4 +108,5 @@ func (suite *AccountStandardTestSuite) SetupTest() {  func (suite *AccountStandardTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  } diff --git a/internal/processing/account/block.go b/internal/processing/account/block.go index 99effd3a3..edec106b1 100644 --- a/internal/processing/account/block.go +++ b/internal/processing/account/block.go @@ -36,13 +36,13 @@ import (  // BlockCreate handles the creation of a block from requestingAccount to targetAccountID, either remote or local.  func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {  	// make sure the target account actually exists in our db -	targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID) +	targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err))  	}  	// if requestingAccount already blocks target account, we don't need to do anything -	if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, false); err != nil { +	if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, false); err != nil {  		return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error checking existence of block: %s", err))  	} else if blocked {  		return p.RelationshipGet(ctx, requestingAccount, targetAccountID) @@ -64,18 +64,18 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel  	block.URI = uris.GenerateURIForBlock(requestingAccount.Username, newBlockID)  	// whack it in the database -	if err := p.db.PutBlock(ctx, block); err != nil { +	if err := p.state.DB.PutBlock(ctx, block); err != nil {  		return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error creating block in db: %s", err))  	}  	// clear any follows or follow requests from the blocked account to the target account -- this is a simple delete -	if err := p.db.DeleteWhere(ctx, []db.Where{ +	if err := p.state.DB.DeleteWhere(ctx, []db.Where{  		{Key: "account_id", Value: targetAccountID},  		{Key: "target_account_id", Value: requestingAccount.ID},  	}, >smodel.Follow{}); err != nil {  		return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow in db: %s", err))  	} -	if err := p.db.DeleteWhere(ctx, []db.Where{ +	if err := p.state.DB.DeleteWhere(ctx, []db.Where{  		{Key: "account_id", Value: targetAccountID},  		{Key: "target_account_id", Value: requestingAccount.ID},  	}, >smodel.FollowRequest{}); err != nil { @@ -89,12 +89,12 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel  	var frChanged bool  	var frURI string  	fr := >smodel.FollowRequest{} -	if err := p.db.GetWhere(ctx, []db.Where{ +	if err := p.state.DB.GetWhere(ctx, []db.Where{  		{Key: "account_id", Value: requestingAccount.ID},  		{Key: "target_account_id", Value: targetAccountID},  	}, fr); err == nil {  		frURI = fr.URI -		if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil { +		if err := p.state.DB.DeleteByID(ctx, fr.ID, fr); err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow request from db: %s", err))  		}  		frChanged = true @@ -104,12 +104,12 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel  	var fChanged bool  	var fURI string  	f := >smodel.Follow{} -	if err := p.db.GetWhere(ctx, []db.Where{ +	if err := p.state.DB.GetWhere(ctx, []db.Where{  		{Key: "account_id", Value: requestingAccount.ID},  		{Key: "target_account_id", Value: targetAccountID},  	}, f); err == nil {  		fURI = f.URI -		if err := p.db.DeleteByID(ctx, f.ID, f); err != nil { +		if err := p.state.DB.DeleteByID(ctx, f.ID, f); err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow from db: %s", err))  		}  		fChanged = true @@ -117,7 +117,7 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel  	// follow request status changed so send the UNDO activity to the channel for async processing  	if frChanged { -		p.clientWorker.Queue(messages.FromClientAPI{ +		p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  			APObjectType:   ap.ActivityFollow,  			APActivityType: ap.ActivityUndo,  			GTSModel: >smodel.Follow{ @@ -132,7 +132,7 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel  	// follow status changed so send the UNDO activity to the channel for async processing  	if fChanged { -		p.clientWorker.Queue(messages.FromClientAPI{ +		p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  			APObjectType:   ap.ActivityFollow,  			APActivityType: ap.ActivityUndo,  			GTSModel: >smodel.Follow{ @@ -146,7 +146,7 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel  	}  	// handle the rest of the block process asynchronously -	p.clientWorker.Queue(messages.FromClientAPI{ +	p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  		APObjectType:   ap.ActivityBlock,  		APActivityType: ap.ActivityCreate,  		GTSModel:       block, @@ -160,23 +160,23 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel  // BlockRemove handles the removal of a block from requestingAccount to targetAccountID, either remote or local.  func (p *Processor) BlockRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {  	// make sure the target account actually exists in our db -	targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID) +	targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err))  	}  	// check if a block exists, and remove it if it does -	block, err := p.db.GetBlock(ctx, requestingAccount.ID, targetAccountID) +	block, err := p.state.DB.GetBlock(ctx, requestingAccount.ID, targetAccountID)  	if err == nil {  		// we got a block, remove it  		block.Account = requestingAccount  		block.TargetAccount = targetAccount -		if err := p.db.DeleteBlockByID(ctx, block.ID); err != nil { +		if err := p.state.DB.DeleteBlockByID(ctx, block.ID); err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockRemove: error removing block from db: %s", err))  		}  		// send the UNDO activity to the client worker for async processing -		p.clientWorker.Queue(messages.FromClientAPI{ +		p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  			APObjectType:   ap.ActivityBlock,  			APActivityType: ap.ActivityUndo,  			GTSModel:       block, diff --git a/internal/processing/account/bookmarks.go b/internal/processing/account/bookmarks.go index 28688c20d..cf53e63bb 100644 --- a/internal/processing/account/bookmarks.go +++ b/internal/processing/account/bookmarks.go @@ -34,7 +34,7 @@ import (  // BookmarksGet returns a pageable response of statuses that are bookmarked by requestingAccount.  // Paging for this response is done based on bookmark ID rather than status ID.  func (p *Processor) BookmarksGet(ctx context.Context, requestingAccount *gtsmodel.Account, limit int, maxID string, minID string) (*apimodel.PageableResponse, gtserror.WithCode) { -	bookmarks, err := p.db.GetBookmarks(ctx, requestingAccount.ID, limit, maxID, minID) +	bookmarks, err := p.state.DB.GetBookmarks(ctx, requestingAccount.ID, limit, maxID, minID)  	if err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} @@ -47,7 +47,7 @@ func (p *Processor) BookmarksGet(ctx context.Context, requestingAccount *gtsmode  	)  	for _, bookmark := range bookmarks { -		status, err := p.db.GetStatusByID(ctx, bookmark.StatusID) +		status, err := p.state.DB.GetStatusByID(ctx, bookmark.StatusID)  		if err != nil {  			if errors.Is(err, db.ErrNoEntries) {  				// We just don't have the status for some reason. diff --git a/internal/processing/account/create.go b/internal/processing/account/create.go index 8b82bc681..9c9cfb57f 100644 --- a/internal/processing/account/create.go +++ b/internal/processing/account/create.go @@ -35,7 +35,7 @@ import (  // Create processes the given form for creating a new account, returning an oauth token for that account if successful.  func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, gtserror.WithCode) { -	emailAvailable, err := p.db.IsEmailAvailable(ctx, form.Email) +	emailAvailable, err := p.state.DB.IsEmailAvailable(ctx, form.Email)  	if err != nil {  		return nil, gtserror.NewErrorBadRequest(err)  	} @@ -43,7 +43,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf  		return nil, gtserror.NewErrorConflict(fmt.Errorf("email address %s is not available", form.Email))  	} -	usernameAvailable, err := p.db.IsUsernameAvailable(ctx, form.Username) +	usernameAvailable, err := p.state.DB.IsUsernameAvailable(ctx, form.Username)  	if err != nil {  		return nil, gtserror.NewErrorBadRequest(err)  	} @@ -61,7 +61,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf  	}  	log.Trace(ctx, "creating new username and account") -	user, err := p.db.NewSignup(ctx, form.Username, text.SanitizePlaintext(reason), approvalRequired, form.Email, form.Password, form.IP, form.Locale, application.ID, false, "", false) +	user, err := p.state.DB.NewSignup(ctx, form.Username, text.SanitizePlaintext(reason), approvalRequired, form.Email, form.Password, form.IP, form.Locale, application.ID, false, "", false)  	if err != nil {  		return nil, gtserror.NewErrorInternalError(fmt.Errorf("error creating new signup in the database: %s", err))  	} @@ -73,7 +73,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf  	}  	if user.Account == nil { -		a, err := p.db.GetAccountByID(ctx, user.AccountID) +		a, err := p.state.DB.GetAccountByID(ctx, user.AccountID)  		if err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting new account from the database: %s", err))  		} @@ -82,7 +82,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf  	// there are side effects for creating a new account (sending confirmation emails etc)  	// so pass a message to the processor so that it can do it asynchronously -	p.clientWorker.Queue(messages.FromClientAPI{ +	p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  		APObjectType:   ap.ObjectProfile,  		APActivityType: ap.ActivityCreate,  		GTSModel:       user.Account, diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go index 58a967337..eea4a621e 100644 --- a/internal/processing/account/delete.go +++ b/internal/processing/account/delete.go @@ -54,22 +54,22 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi  	if account.Domain == "" {  		// see if we can get a user for this account  		var err error -		if user, err = p.db.GetUserByAccountID(ctx, account.ID); err == nil { +		if user, err = p.state.DB.GetUserByAccountID(ctx, account.ID); err == nil {  			// we got one! select all tokens with the user's ID  			tokens := []*gtsmodel.Token{} -			if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err == nil { +			if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err == nil {  				// we have some tokens to delete  				for _, t := range tokens {  					// delete client(s) associated with this token -					if err := p.db.DeleteByID(ctx, t.ClientID, >smodel.Client{}); err != nil { +					if err := p.state.DB.DeleteByID(ctx, t.ClientID, >smodel.Client{}); err != nil {  						l.Errorf("error deleting oauth client: %s", err)  					}  					// delete application(s) associated with this token -					if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, >smodel.Application{}); err != nil { +					if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, >smodel.Application{}); err != nil {  						l.Errorf("error deleting application: %s", err)  					}  					// delete the token itself -					if err := p.db.DeleteByID(ctx, t.ID, t); err != nil { +					if err := p.state.DB.DeleteByID(ctx, t.ID, t); err != nil {  						l.Errorf("error deleting oauth token: %s", err)  					}  				} @@ -80,12 +80,12 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi  	// 2. Delete account's blocks  	l.Trace("deleting account blocks")  	// first delete any blocks that this account created -	if err := p.db.DeleteBlocksByOriginAccountID(ctx, account.ID); err != nil { +	if err := p.state.DB.DeleteBlocksByOriginAccountID(ctx, account.ID); err != nil {  		l.Errorf("error deleting blocks created by account: %s", err)  	}  	// now delete any blocks that target this account -	if err := p.db.DeleteBlocksByTargetAccountID(ctx, account.ID); err != nil { +	if err := p.state.DB.DeleteBlocksByTargetAccountID(ctx, account.ID); err != nil {  		l.Errorf("error deleting blocks targeting account: %s", err)  	} @@ -96,12 +96,12 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi  	// TODO: federate these if necessary  	l.Trace("deleting account follow requests")  	// first delete any follow requests that this account created -	if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil { +	if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {  		l.Errorf("error deleting follow requests created by account: %s", err)  	}  	// now delete any follow requests that target this account -	if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil { +	if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {  		l.Errorf("error deleting follow requests targeting account: %s", err)  	} @@ -109,12 +109,12 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi  	// TODO: federate these if necessary  	l.Trace("deleting account follows")  	// first delete any follows that this account created -	if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil { +	if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {  		l.Errorf("error deleting follows created by account: %s", err)  	}  	// now delete any follows that target this account -	if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil { +	if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {  		l.Errorf("error deleting follows targeting account: %s", err)  	} @@ -129,7 +129,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi  	for {  		// Fetch next block of account statuses from database -		statuses, err := p.db.GetAccountStatuses(ctx, account.ID, 20, false, false, maxID, "", false, false) +		statuses, err := p.state.DB.GetAccountStatuses(ctx, account.ID, 20, false, false, maxID, "", false, false)  		if err != nil {  			if !errors.Is(err, db.ErrNoEntries) {  				// an actual error has occurred @@ -149,7 +149,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi  			l.Tracef("queue client API status delete: %s", status.ID)  			// pass the status delete through the client api channel for processing -			p.clientWorker.Queue(messages.FromClientAPI{ +			p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  				APObjectType:   ap.ObjectNote,  				APActivityType: ap.ActivityDelete,  				GTSModel:       status, @@ -158,7 +158,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi  			})  			// Look for any boosts of this status in DB -			boosts, err := p.db.GetStatusReblogs(ctx, status) +			boosts, err := p.state.DB.GetStatusReblogs(ctx, status)  			if err != nil && !errors.Is(err, db.ErrNoEntries) {  				l.Errorf("error fetching status reblogs for %q: %v", status.ID, err)  				continue @@ -167,7 +167,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi  			for _, boost := range boosts {  				if boost.Account == nil {  					// Fetch the relevant account for this status boost -					boostAcc, err := p.db.GetAccountByID(ctx, boost.AccountID) +					boostAcc, err := p.state.DB.GetAccountByID(ctx, boost.AccountID)  					if err != nil {  						l.Errorf("error fetching boosted status account for %q: %v", boost.AccountID, err)  						continue @@ -180,7 +180,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi  				l.Tracef("queue client API boost delete: %s", status.ID)  				// pass the boost delete through the client api channel for processing -				p.clientWorker.Queue(messages.FromClientAPI{ +				p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  					APObjectType:   ap.ActivityAnnounce,  					APActivityType: ap.ActivityUndo,  					GTSModel:       status, @@ -197,31 +197,31 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi  	// 10. Delete account's notifications  	l.Trace("deleting account notifications")  	// first notifications created by account -	if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil { +	if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil {  		l.Errorf("error deleting notifications created by account: %s", err)  	}  	// now notifications targeting account -	if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil { +	if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil {  		l.Errorf("error deleting notifications targeting account: %s", err)  	}  	// 11. Delete account's bookmarks  	l.Trace("deleting account bookmarks") -	if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil { +	if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil {  		l.Errorf("error deleting bookmarks created by account: %s", err)  	}  	// 12. Delete account's faves  	// TODO: federate these if necessary  	l.Trace("deleting account faves") -	if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil { +	if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil {  		l.Errorf("error deleting faves created by account: %s", err)  	}  	// 13. Delete account's mutes  	l.Trace("deleting account mutes") -	if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil { +	if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil {  		l.Errorf("error deleting status mutes created by account: %s", err)  	} @@ -234,7 +234,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi  	// 16. Delete account's user  	if user != nil {  		l.Trace("deleting account user") -		if err := p.db.DeleteUserByID(ctx, user.ID); err != nil { +		if err := p.state.DB.DeleteUserByID(ctx, user.ID); err != nil {  			return gtserror.NewErrorInternalError(err)  		}  	} @@ -261,7 +261,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi  	account.Discoverable = &discoverable  	account.SuspendedAt = time.Now()  	account.SuspensionOrigin = origin -	err := p.db.UpdateAccount(ctx, account) +	err := p.state.DB.UpdateAccount(ctx, account)  	if err != nil {  		return gtserror.NewErrorInternalError(err)  	} @@ -281,7 +281,7 @@ func (p *Processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account,  	if form.DeleteOriginID == account.ID {  		// the account owner themself has requested deletion via the API, get their user from the db -		user, err := p.db.GetUserByAccountID(ctx, account.ID) +		user, err := p.state.DB.GetUserByAccountID(ctx, account.ID)  		if err != nil {  			return gtserror.NewErrorInternalError(err)  		} @@ -301,7 +301,7 @@ func (p *Processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account,  	} else {  		// the delete has been requested by some other account, grab it;  		// if we've reached this point we know it has permission already -		requestingAccount, err := p.db.GetAccountByID(ctx, form.DeleteOriginID) +		requestingAccount, err := p.state.DB.GetAccountByID(ctx, form.DeleteOriginID)  		if err != nil {  			return gtserror.NewErrorInternalError(err)  		} @@ -310,7 +310,7 @@ func (p *Processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account,  	}  	// put the delete in the processor queue to handle the rest of it asynchronously -	p.clientWorker.Queue(fromClientAPIMessage) +	p.state.Workers.EnqueueClientAPI(ctx, fromClientAPIMessage)  	return nil  } diff --git a/internal/processing/account/follow.go b/internal/processing/account/follow.go index d4d479be7..ac65c39f2 100644 --- a/internal/processing/account/follow.go +++ b/internal/processing/account/follow.go @@ -35,14 +35,14 @@ import (  // FollowCreate handles a follow request to an account, either remote or local.  func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) {  	// if there's a block between the accounts we shouldn't create the request ofc -	if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, form.ID, true); err != nil { +	if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, form.ID, true); err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} else if blocked {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))  	}  	// make sure the target account actually exists in our db -	targetAcct, err := p.db.GetAccountByID(ctx, form.ID) +	targetAcct, err := p.state.DB.GetAccountByID(ctx, form.ID)  	if err != nil {  		if err == db.ErrNoEntries {  			return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: account %s not found in the db: %s", form.ID, err)) @@ -51,7 +51,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode  	}  	// check if a follow exists already -	if follows, err := p.db.IsFollowing(ctx, requestingAccount, targetAcct); err != nil { +	if follows, err := p.state.DB.IsFollowing(ctx, requestingAccount, targetAcct); err != nil {  		return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow in db: %s", err))  	} else if follows {  		// already follows so just return the relationship @@ -59,7 +59,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode  	}  	// check if a follow request exists already -	if followRequested, err := p.db.IsFollowRequested(ctx, requestingAccount, targetAcct); err != nil { +	if followRequested, err := p.state.DB.IsFollowRequested(ctx, requestingAccount, targetAcct); err != nil {  		return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow request in db: %s", err))  	} else if followRequested {  		// already follow requested so just return the relationship @@ -95,13 +95,13 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode  	}  	// whack it in the database -	if err := p.db.Put(ctx, fr); err != nil { +	if err := p.state.DB.Put(ctx, fr); err != nil {  		return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error creating follow request in db: %s", err))  	}  	// if it's a local account that's not locked we can just straight up accept the follow request  	if !*targetAcct.Locked && targetAcct.Domain == "" { -		if _, err := p.db.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil { +		if _, err := p.state.DB.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error accepting folow request for local unlocked account: %s", err))  		}  		// return the new relationship @@ -109,7 +109,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode  	}  	// otherwise we leave the follow request as it is and we handle the rest of the process asynchronously -	p.clientWorker.Queue(messages.FromClientAPI{ +	p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  		APObjectType:   ap.ActivityFollow,  		APActivityType: ap.ActivityCreate,  		GTSModel:       fr, @@ -124,7 +124,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode  // FollowRemove handles the removal of a follow/follow request to an account, either remote or local.  func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {  	// if there's a block between the accounts we shouldn't do anything -	blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true) +	blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true)  	if err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} @@ -133,7 +133,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode  	}  	// make sure the target account actually exists in our db -	targetAcct, err := p.db.GetAccountByID(ctx, targetAccountID) +	targetAcct, err := p.state.DB.GetAccountByID(ctx, targetAccountID)  	if err != nil {  		if err == db.ErrNoEntries {  			return nil, gtserror.NewErrorNotFound(fmt.Errorf("AccountFollowRemove: account %s not found in the db: %s", targetAccountID, err)) @@ -144,12 +144,12 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode  	var frChanged bool  	var frURI string  	fr := >smodel.FollowRequest{} -	if err := p.db.GetWhere(ctx, []db.Where{ +	if err := p.state.DB.GetWhere(ctx, []db.Where{  		{Key: "account_id", Value: requestingAccount.ID},  		{Key: "target_account_id", Value: targetAccountID},  	}, fr); err == nil {  		frURI = fr.URI -		if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil { +		if err := p.state.DB.DeleteByID(ctx, fr.ID, fr); err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow request from db: %s", err))  		}  		frChanged = true @@ -159,12 +159,12 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode  	var fChanged bool  	var fURI string  	f := >smodel.Follow{} -	if err := p.db.GetWhere(ctx, []db.Where{ +	if err := p.state.DB.GetWhere(ctx, []db.Where{  		{Key: "account_id", Value: requestingAccount.ID},  		{Key: "target_account_id", Value: targetAccountID},  	}, f); err == nil {  		fURI = f.URI -		if err := p.db.DeleteByID(ctx, f.ID, f); err != nil { +		if err := p.state.DB.DeleteByID(ctx, f.ID, f); err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow from db: %s", err))  		}  		fChanged = true @@ -172,7 +172,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode  	// follow request status changed so send the UNDO activity to the channel for async processing  	if frChanged { -		p.clientWorker.Queue(messages.FromClientAPI{ +		p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  			APObjectType:   ap.ActivityFollow,  			APActivityType: ap.ActivityUndo,  			GTSModel: >smodel.Follow{ @@ -187,7 +187,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode  	// follow status changed so send the UNDO activity to the channel for async processing  	if fChanged { -		p.clientWorker.Queue(messages.FromClientAPI{ +		p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  			APObjectType:   ap.ActivityFollow,  			APActivityType: ap.ActivityUndo,  			GTSModel: >smodel.Follow{ diff --git a/internal/processing/account/get.go b/internal/processing/account/get.go index 11de1ddac..2c650254f 100644 --- a/internal/processing/account/get.go +++ b/internal/processing/account/get.go @@ -33,7 +33,7 @@ import (  // Get processes the given request for account information.  func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, gtserror.WithCode) { -	targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID) +	targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID)  	if err != nil {  		if err == db.ErrNoEntries {  			return nil, gtserror.NewErrorNotFound(errors.New("account not found")) @@ -46,7 +46,7 @@ func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account  // GetLocalByUsername processes the given request for account information targeting a local account by username.  func (p *Processor) GetLocalByUsername(ctx context.Context, requestingAccount *gtsmodel.Account, username string) (*apimodel.Account, gtserror.WithCode) { -	targetAccount, err := p.db.GetAccountByUsernameDomain(ctx, username, "") +	targetAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, username, "")  	if err != nil {  		if err == db.ErrNoEntries {  			return nil, gtserror.NewErrorNotFound(errors.New("account not found")) @@ -59,7 +59,7 @@ func (p *Processor) GetLocalByUsername(ctx context.Context, requestingAccount *g  // GetCustomCSSForUsername returns custom css for the given local username.  func (p *Processor) GetCustomCSSForUsername(ctx context.Context, username string) (string, gtserror.WithCode) { -	customCSS, err := p.db.GetAccountCustomCSSByUsername(ctx, username) +	customCSS, err := p.state.DB.GetAccountCustomCSSByUsername(ctx, username)  	if err != nil {  		if err == db.ErrNoEntries {  			return "", gtserror.NewErrorNotFound(errors.New("account not found")) @@ -74,7 +74,7 @@ func (p *Processor) getFor(ctx context.Context, requestingAccount *gtsmodel.Acco  	var blocked bool  	var err error  	if requestingAccount != nil { -		blocked, err = p.db.IsBlocked(ctx, requestingAccount.ID, targetAccount.ID, true) +		blocked, err = p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccount.ID, true)  		if err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking account block: %s", err))  		} diff --git a/internal/processing/account/relationships.go b/internal/processing/account/relationships.go index cb2789829..f60216f95 100644 --- a/internal/processing/account/relationships.go +++ b/internal/processing/account/relationships.go @@ -31,14 +31,14 @@ import (  // FollowersGet fetches a list of the target account's followers.  func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { -	if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { +	if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} else if blocked {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))  	}  	accounts := []apimodel.Account{} -	follows, err := p.db.GetAccountFollowedBy(ctx, targetAccountID, false) +	follows, err := p.state.DB.GetAccountFollowedBy(ctx, targetAccountID, false)  	if err != nil {  		if err == db.ErrNoEntries {  			return accounts, nil @@ -47,7 +47,7 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode  	}  	for _, f := range follows { -		blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true) +		blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true)  		if err != nil {  			return nil, gtserror.NewErrorInternalError(err)  		} @@ -56,7 +56,7 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode  		}  		if f.Account == nil { -			a, err := p.db.GetAccountByID(ctx, f.AccountID) +			a, err := p.state.DB.GetAccountByID(ctx, f.AccountID)  			if err != nil {  				if err == db.ErrNoEntries {  					continue @@ -77,14 +77,14 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode  // FollowingGet fetches a list of the accounts that target account is following.  func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { -	if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { +	if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} else if blocked {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))  	}  	accounts := []apimodel.Account{} -	follows, err := p.db.GetAccountFollows(ctx, targetAccountID) +	follows, err := p.state.DB.GetAccountFollows(ctx, targetAccountID)  	if err != nil {  		if err == db.ErrNoEntries {  			return accounts, nil @@ -93,7 +93,7 @@ func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmode  	}  	for _, f := range follows { -		blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true) +		blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true)  		if err != nil {  			return nil, gtserror.NewErrorInternalError(err)  		} @@ -102,7 +102,7 @@ func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmode  		}  		if f.TargetAccount == nil { -			a, err := p.db.GetAccountByID(ctx, f.TargetAccountID) +			a, err := p.state.DB.GetAccountByID(ctx, f.TargetAccountID)  			if err != nil {  				if err == db.ErrNoEntries {  					continue @@ -127,7 +127,7 @@ func (p *Processor) RelationshipGet(ctx context.Context, requestingAccount *gtsm  		return nil, gtserror.NewErrorForbidden(errors.New("not authed"))  	} -	gtsR, err := p.db.GetRelationship(ctx, requestingAccount.ID, targetAccountID) +	gtsR, err := p.state.DB.GetRelationship(ctx, requestingAccount.ID, targetAccountID)  	if err != nil {  		return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting relationship: %s", err))  	} diff --git a/internal/processing/account/rss.go b/internal/processing/account/rss.go index 22065cf8e..61fcc1c51 100644 --- a/internal/processing/account/rss.go +++ b/internal/processing/account/rss.go @@ -34,7 +34,7 @@ const rssFeedLength = 20  // GetRSSFeedForUsername returns RSS feed for the given local username.  func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string) (func() (string, gtserror.WithCode), time.Time, gtserror.WithCode) { -	account, err := p.db.GetAccountByUsernameDomain(ctx, username, "") +	account, err := p.state.DB.GetAccountByUsernameDomain(ctx, username, "")  	if err != nil {  		if err == db.ErrNoEntries {  			return nil, time.Time{}, gtserror.NewErrorNotFound(errors.New("GetRSSFeedForUsername: account not found")) @@ -46,13 +46,13 @@ func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string)  		return nil, time.Time{}, gtserror.NewErrorNotFound(errors.New("GetRSSFeedForUsername: account RSS feed not enabled"))  	} -	lastModified, err := p.db.GetAccountLastPosted(ctx, account.ID, true) +	lastModified, err := p.state.DB.GetAccountLastPosted(ctx, account.ID, true)  	if err != nil {  		return nil, time.Time{}, gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error: %s", err))  	}  	return func() (string, gtserror.WithCode) { -		statuses, err := p.db.GetAccountWebStatuses(ctx, account.ID, rssFeedLength, "") +		statuses, err := p.state.DB.GetAccountWebStatuses(ctx, account.ID, rssFeedLength, "")  		if err != nil && err != db.ErrNoEntries {  			return "", gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error: %s", err))  		} @@ -65,7 +65,7 @@ func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string)  		var image *feeds.Image  		if account.AvatarMediaAttachmentID != "" {  			if account.AvatarMediaAttachment == nil { -				avatar, err := p.db.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID) +				avatar, err := p.state.DB.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID)  				if err != nil {  					return "", gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error fetching avatar attachment: %s", err))  				} diff --git a/internal/processing/account/statuses.go b/internal/processing/account/statuses.go index 7ff6de2ff..9961dbdbe 100644 --- a/internal/processing/account/statuses.go +++ b/internal/processing/account/statuses.go @@ -33,7 +33,7 @@ import (  // the account given in authed.  func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, pinned bool, mediaOnly bool, publicOnly bool) (*apimodel.PageableResponse, gtserror.WithCode) {  	if requestingAccount != nil { -		if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { +		if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {  			return nil, gtserror.NewErrorInternalError(err)  		} else if blocked {  			return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) @@ -46,10 +46,10 @@ func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel  	)  	if pinned {  		// Get *ONLY* pinned statuses. -		statuses, err = p.db.GetAccountPinnedStatuses(ctx, targetAccountID) +		statuses, err = p.state.DB.GetAccountPinnedStatuses(ctx, targetAccountID)  	} else {  		// Get account statuses which *may* include pinned ones. -		statuses, err = p.db.GetAccountStatuses(ctx, targetAccountID, limit, excludeReplies, excludeReblogs, maxID, minID, mediaOnly, publicOnly) +		statuses, err = p.state.DB.GetAccountStatuses(ctx, targetAccountID, limit, excludeReplies, excludeReblogs, maxID, minID, mediaOnly, publicOnly)  	}  	if err != nil {  		if err == db.ErrNoEntries { @@ -120,7 +120,7 @@ func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel  // WebStatusesGet fetches a number of statuses (in descending order) from the given account. It selects only  // statuses which are suitable for showing on the public web profile of an account.  func (p *Processor) WebStatusesGet(ctx context.Context, targetAccountID string, maxID string) (*apimodel.PageableResponse, gtserror.WithCode) { -	acct, err := p.db.GetAccountByID(ctx, targetAccountID) +	acct, err := p.state.DB.GetAccountByID(ctx, targetAccountID)  	if err != nil {  		if err == db.ErrNoEntries {  			err := fmt.Errorf("account %s not found in the db, not getting web statuses for it", targetAccountID) @@ -134,7 +134,7 @@ func (p *Processor) WebStatusesGet(ctx context.Context, targetAccountID string,  		return nil, gtserror.NewErrorNotFound(err)  	} -	statuses, err := p.db.GetAccountWebStatuses(ctx, targetAccountID, 10, maxID) +	statuses, err := p.state.DB.GetAccountWebStatuses(ctx, targetAccountID, 10, maxID)  	if err != nil {  		if err == db.ErrNoEntries {  			return util.EmptyPageableResponse(), nil diff --git a/internal/processing/account/update.go b/internal/processing/account/update.go index cffbbb0c5..537857cee 100644 --- a/internal/processing/account/update.go +++ b/internal/processing/account/update.go @@ -165,12 +165,12 @@ func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, form  		account.EnableRSS = form.EnableRSS  	} -	err := p.db.UpdateAccount(ctx, account) +	err := p.state.DB.UpdateAccount(ctx, account)  	if err != nil {  		return nil, gtserror.NewErrorInternalError(fmt.Errorf("could not update account %s: %s", account.ID, err))  	} -	p.clientWorker.Queue(messages.FromClientAPI{ +	p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  		APObjectType:   ap.ObjectProfile,  		APActivityType: ap.ActivityUpdate,  		GTSModel:       account, diff --git a/internal/processing/admin/account.go b/internal/processing/admin/account.go index d23d1fbfe..ba4c5d4eb 100644 --- a/internal/processing/admin/account.go +++ b/internal/processing/admin/account.go @@ -31,7 +31,7 @@ import (  )  func (p *Processor) AccountAction(ctx context.Context, account *gtsmodel.Account, form *apimodel.AdminAccountActionRequest) gtserror.WithCode { -	targetAccount, err := p.db.GetAccountByID(ctx, form.TargetAccountID) +	targetAccount, err := p.state.DB.GetAccountByID(ctx, form.TargetAccountID)  	if err != nil {  		return gtserror.NewErrorInternalError(err)  	} @@ -47,7 +47,7 @@ func (p *Processor) AccountAction(ctx context.Context, account *gtsmodel.Account  	case string(gtsmodel.AdminActionSuspend):  		adminAction.Type = gtsmodel.AdminActionSuspend  		// pass the account delete through the client api channel for processing -		p.clientWorker.Queue(messages.FromClientAPI{ +		p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  			APObjectType:   ap.ActorPerson,  			APActivityType: ap.ActivityDelete,  			OriginAccount:  account, @@ -57,7 +57,7 @@ func (p *Processor) AccountAction(ctx context.Context, account *gtsmodel.Account  		return gtserror.NewErrorBadRequest(fmt.Errorf("admin action type %s is not supported for this endpoint", form.Type))  	} -	if err := p.db.Put(ctx, adminAction); err != nil { +	if err := p.state.DB.Put(ctx, adminAction); err != nil {  		return gtserror.NewErrorInternalError(err)  	} diff --git a/internal/processing/admin/admin.go b/internal/processing/admin/admin.go index 54827b8fd..ba09969dc 100644 --- a/internal/processing/admin/admin.go +++ b/internal/processing/admin/admin.go @@ -19,32 +19,25 @@  package admin  import ( -	"github.com/superseriousbusiness/gotosocial/internal/concurrency" -	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages" -	"github.com/superseriousbusiness/gotosocial/internal/storage" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/transport"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  )  type Processor struct { +	state               *state.State  	tc                  typeutils.TypeConverter  	mediaManager        media.Manager  	transportController transport.Controller -	storage             *storage.Driver -	clientWorker        *concurrency.WorkerPool[messages.FromClientAPI] -	db                  db.DB  }  // New returns a new admin processor. -func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller, storage *storage.Driver, clientWorker *concurrency.WorkerPool[messages.FromClientAPI]) Processor { +func New(state *state.State, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller) Processor {  	return Processor{ +		state:               state,  		tc:                  tc,  		mediaManager:        mediaManager,  		transportController: transportController, -		storage:             storage, -		clientWorker:        clientWorker, -		db:                  db,  	}  } diff --git a/internal/processing/admin/domainblock.go b/internal/processing/admin/domainblock.go index 415ac610f..dd22f72e6 100644 --- a/internal/processing/admin/domainblock.go +++ b/internal/processing/admin/domainblock.go @@ -28,7 +28,7 @@ func (p *Processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Acc  	domain = strings.ToLower(domain)  	// first check if we already have a block -- if err == nil we already had a block so we can skip a whole lot of work -	block, err := p.db.GetDomainBlock(ctx, domain) +	block, err := p.state.DB.GetDomainBlock(ctx, domain)  	if err != nil {  		if !errors.Is(err, db.ErrNoEntries) {  			// something went wrong in the DB @@ -47,7 +47,7 @@ func (p *Processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Acc  		}  		// Insert the new block into the database -		if err := p.db.CreateDomainBlock(ctx, newBlock); err != nil { +		if err := p.state.DB.CreateDomainBlock(ctx, newBlock); err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error putting new domain block %s: %s", domain, err))  		} @@ -80,7 +80,7 @@ func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account  	// if we have an instance entry for this domain, update it with the new block ID and clear all fields  	instance := >smodel.Instance{} -	if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain}}, instance); err == nil { +	if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain}}, instance); err == nil {  		updatingColumns := []string{  			"title",  			"updated_at", @@ -105,15 +105,15 @@ func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account  		instance.ContactAccountUsername = ""  		instance.ContactAccountID = ""  		instance.Version = "" -		if err := p.db.UpdateByID(ctx, instance, instance.ID, updatingColumns...); err != nil { +		if err := p.state.DB.UpdateByID(ctx, instance, instance.ID, updatingColumns...); err != nil {  			l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err)  		}  		l.Debug("domainBlockProcessSideEffects: instance entry updated")  	}  	// if we have an instance account for this instance, delete it -	if instanceAccount, err := p.db.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil { -		if err := p.db.DeleteAccount(ctx, instanceAccount.ID); err != nil { +	if instanceAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil { +		if err := p.state.DB.DeleteAccount(ctx, instanceAccount.ID); err != nil {  			l.Errorf("domainBlockProcessSideEffects: db error deleting instance account: %s", err)  		}  	} @@ -125,7 +125,7 @@ func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account  selectAccountsLoop:  	for { -		accounts, err := p.db.GetInstanceAccounts(ctx, block.Domain, maxID, limit) +		accounts, err := p.state.DB.GetInstanceAccounts(ctx, block.Domain, maxID, limit)  		if err != nil {  			if err == db.ErrNoEntries {  				// no accounts left for this instance so we're done @@ -141,7 +141,7 @@ selectAccountsLoop:  			l.Debugf("putting delete for account %s in the clientAPI channel", a.Username)  			// pass the account delete through the client api channel for processing -			p.clientWorker.Queue(messages.FromClientAPI{ +			p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  				APObjectType:   ap.ActorPerson,  				APActivityType: ap.ActivityDelete,  				GTSModel:       block, @@ -195,7 +195,7 @@ func (p *Processor) DomainBlocksImport(ctx context.Context, account *gtsmodel.Ac  func (p *Processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) {  	domainBlocks := []*gtsmodel.DomainBlock{} -	if err := p.db.GetAll(ctx, &domainBlocks); err != nil { +	if err := p.state.DB.GetAll(ctx, &domainBlocks); err != nil {  		if !errors.Is(err, db.ErrNoEntries) {  			// something has gone really wrong  			return nil, gtserror.NewErrorInternalError(err) @@ -219,7 +219,7 @@ func (p *Processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Accou  func (p *Processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) {  	domainBlock := >smodel.DomainBlock{} -	if err := p.db.GetByID(ctx, id, domainBlock); err != nil { +	if err := p.state.DB.GetByID(ctx, id, domainBlock); err != nil {  		if !errors.Is(err, db.ErrNoEntries) {  			// something has gone really wrong  			return nil, gtserror.NewErrorInternalError(err) @@ -240,7 +240,7 @@ func (p *Processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Accoun  func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) {  	domainBlock := >smodel.DomainBlock{} -	if err := p.db.GetByID(ctx, id, domainBlock); err != nil { +	if err := p.state.DB.GetByID(ctx, id, domainBlock); err != nil {  		if !errors.Is(err, db.ErrNoEntries) {  			// something has gone really wrong  			return nil, gtserror.NewErrorInternalError(err) @@ -256,13 +256,13 @@ func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc  	}  	// Delete the domain block -	if err := p.db.DeleteDomainBlock(ctx, domainBlock.Domain); err != nil { +	if err := p.state.DB.DeleteDomainBlock(ctx, domainBlock.Domain); err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	}  	// remove the domain block reference from the instance, if we have an entry for it  	i := >smodel.Instance{} -	if err := p.db.GetWhere(ctx, []db.Where{ +	if err := p.state.DB.GetWhere(ctx, []db.Where{  		{Key: "domain", Value: domainBlock.Domain},  		{Key: "domain_block_id", Value: id},  	}, i); err == nil { @@ -270,21 +270,21 @@ func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc  		i.SuspendedAt = time.Time{}  		i.DomainBlockID = ""  		i.UpdatedAt = time.Now() -		if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil { +		if err := p.state.DB.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err))  		}  	}  	// unsuspend all accounts whose suspension origin was this domain block  	// 1. remove the 'suspended_at' entry from their accounts -	if err := p.db.UpdateWhere(ctx, []db.Where{ +	if err := p.state.DB.UpdateWhere(ctx, []db.Where{  		{Key: "suspension_origin", Value: domainBlock.ID},  	}, "suspended_at", nil, &[]*gtsmodel.Account{}); err != nil {  		return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspended_at from accounts: %s", err))  	}  	// 2. remove the 'suspension_origin' entry from their accounts -	if err := p.db.UpdateWhere(ctx, []db.Where{ +	if err := p.state.DB.UpdateWhere(ctx, []db.Where{  		{Key: "suspension_origin", Value: domainBlock.ID},  	}, "suspension_origin", nil, &[]*gtsmodel.Account{}); err != nil {  		return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspension_origin from accounts: %s", err)) diff --git a/internal/processing/admin/emoji.go b/internal/processing/admin/emoji.go index 391d18525..3eacbf888 100644 --- a/internal/processing/admin/emoji.go +++ b/internal/processing/admin/emoji.go @@ -42,7 +42,7 @@ func (p *Processor) EmojiCreate(ctx context.Context, account *gtsmodel.Account,  		return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin")  	} -	maybeExisting, err := p.db.GetEmojiByShortcodeDomain(ctx, form.Shortcode, "") +	maybeExisting, err := p.state.DB.GetEmojiByShortcodeDomain(ctx, form.Shortcode, "")  	if maybeExisting != nil {  		return nil, gtserror.NewErrorConflict(fmt.Errorf("emoji with shortcode %s already exists", form.Shortcode), fmt.Sprintf("emoji with shortcode %s already exists", form.Shortcode))  	} @@ -110,7 +110,7 @@ func (p *Processor) EmojisGet(  		return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin")  	} -	emojis, err := p.db.GetEmojis(ctx, domain, includeDisabled, includeEnabled, shortcode, maxShortcodeDomain, minShortcodeDomain, limit) +	emojis, err := p.state.DB.GetEmojis(ctx, domain, includeDisabled, includeEnabled, shortcode, maxShortcodeDomain, minShortcodeDomain, limit)  	if err != nil && !errors.Is(err, db.ErrNoEntries) {  		err := fmt.Errorf("EmojisGet: db error: %s", err)  		return nil, gtserror.NewErrorInternalError(err) @@ -176,7 +176,7 @@ func (p *Processor) EmojiGet(ctx context.Context, account *gtsmodel.Account, use  		return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin")  	} -	emoji, err := p.db.GetEmojiByID(ctx, id) +	emoji, err := p.state.DB.GetEmojiByID(ctx, id)  	if err != nil {  		if errors.Is(err, db.ErrNoEntries) {  			err = fmt.Errorf("EmojiGet: no emoji with id %s found in the db", id) @@ -197,7 +197,7 @@ func (p *Processor) EmojiGet(ctx context.Context, account *gtsmodel.Account, use  // EmojiDelete deletes one emoji from the database, with the given id.  func (p *Processor) EmojiDelete(ctx context.Context, id string) (*apimodel.AdminEmoji, gtserror.WithCode) { -	emoji, err := p.db.GetEmojiByID(ctx, id) +	emoji, err := p.state.DB.GetEmojiByID(ctx, id)  	if err != nil {  		if errors.Is(err, db.ErrNoEntries) {  			err = fmt.Errorf("EmojiDelete: no emoji with id %s found in the db", id) @@ -218,7 +218,7 @@ func (p *Processor) EmojiDelete(ctx context.Context, id string) (*apimodel.Admin  		return nil, gtserror.NewErrorInternalError(err)  	} -	if err := p.db.DeleteEmojiByID(ctx, id); err != nil { +	if err := p.state.DB.DeleteEmojiByID(ctx, id); err != nil {  		err := fmt.Errorf("EmojiDelete: db error: %s", err)  		return nil, gtserror.NewErrorInternalError(err)  	} @@ -228,7 +228,7 @@ func (p *Processor) EmojiDelete(ctx context.Context, id string) (*apimodel.Admin  // EmojiUpdate updates one emoji with the given id, using the provided form parameters.  func (p *Processor) EmojiUpdate(ctx context.Context, id string, form *apimodel.EmojiUpdateRequest) (*apimodel.AdminEmoji, gtserror.WithCode) { -	emoji, err := p.db.GetEmojiByID(ctx, id) +	emoji, err := p.state.DB.GetEmojiByID(ctx, id)  	if err != nil {  		if errors.Is(err, db.ErrNoEntries) {  			err = fmt.Errorf("EmojiUpdate: no emoji with id %s found in the db", id) @@ -253,7 +253,7 @@ func (p *Processor) EmojiUpdate(ctx context.Context, id string, form *apimodel.E  // EmojiCategoriesGet returns all custom emoji categories that exist on this instance.  func (p *Processor) EmojiCategoriesGet(ctx context.Context) ([]*apimodel.EmojiCategory, gtserror.WithCode) { -	categories, err := p.db.GetEmojiCategories(ctx) +	categories, err := p.state.DB.GetEmojiCategories(ctx)  	if err != nil {  		err := fmt.Errorf("EmojiCategoriesGet: db error: %s", err)  		return nil, gtserror.NewErrorInternalError(err) @@ -277,7 +277,7 @@ func (p *Processor) EmojiCategoriesGet(ctx context.Context) ([]*apimodel.EmojiCa  */  func (p *Processor) getOrCreateEmojiCategory(ctx context.Context, name string) (*gtsmodel.EmojiCategory, error) { -	category, err := p.db.GetEmojiCategoryByName(ctx, name) +	category, err := p.state.DB.GetEmojiCategoryByName(ctx, name)  	if err == nil {  		return category, nil  	} @@ -299,7 +299,7 @@ func (p *Processor) getOrCreateEmojiCategory(ctx context.Context, name string) (  		Name: name,  	} -	if err := p.db.PutEmojiCategory(ctx, category); err != nil { +	if err := p.state.DB.PutEmojiCategory(ctx, category); err != nil {  		err = fmt.Errorf("GetOrCreateEmojiCategory: error putting new emoji category in the database: %s", err)  		return nil, err  	} @@ -319,7 +319,7 @@ func (p *Processor) emojiUpdateCopy(ctx context.Context, emoji *gtsmodel.Emoji,  		return nil, gtserror.NewErrorBadRequest(err, err.Error())  	} -	maybeExisting, err := p.db.GetEmojiByShortcodeDomain(ctx, *shortcode, "") +	maybeExisting, err := p.state.DB.GetEmojiByShortcodeDomain(ctx, *shortcode, "")  	if maybeExisting != nil {  		err := fmt.Errorf("emojiUpdateCopy: emoji %s could not be copied, emoji with shortcode %s already exists on this instance", emoji.ID, *shortcode)  		return nil, gtserror.NewErrorConflict(err, err.Error()) @@ -339,7 +339,7 @@ func (p *Processor) emojiUpdateCopy(ctx context.Context, emoji *gtsmodel.Emoji,  	newEmojiURI := uris.GenerateURIForEmoji(newEmojiID)  	data := func(ctx context.Context) (reader io.ReadCloser, fileSize int64, err error) { -		rc, err := p.storage.GetStream(ctx, emoji.ImagePath) +		rc, err := p.state.Storage.GetStream(ctx, emoji.ImagePath)  		return rc, int64(emoji.ImageFileSize), err  	} @@ -386,7 +386,7 @@ func (p *Processor) emojiUpdateDisable(ctx context.Context, emoji *gtsmodel.Emoj  	emojiDisabled := true  	emoji.Disabled = &emojiDisabled -	updatedEmoji, err := p.db.UpdateEmoji(ctx, emoji, "updated_at", "disabled") +	updatedEmoji, err := p.state.DB.UpdateEmoji(ctx, emoji, "updated_at", "disabled")  	if err != nil {  		err = fmt.Errorf("emojiUpdateDisable: error updating emoji %s: %s", emoji.ID, err)  		return nil, gtserror.NewErrorInternalError(err) @@ -443,7 +443,7 @@ func (p *Processor) emojiUpdateModify(ctx context.Context, emoji *gtsmodel.Emoji  		}  		var err error -		updatedEmoji, err = p.db.UpdateEmoji(ctx, emoji, columns...) +		updatedEmoji, err = p.state.DB.UpdateEmoji(ctx, emoji, columns...)  		if err != nil {  			err = fmt.Errorf("emojiUpdateModify: error updating emoji %s: %s", emoji.ID, err)  			return nil, gtserror.NewErrorInternalError(err) diff --git a/internal/processing/admin/report.go b/internal/processing/admin/report.go index 3a6028bca..bed97e204 100644 --- a/internal/processing/admin/report.go +++ b/internal/processing/admin/report.go @@ -43,7 +43,7 @@ func (p *Processor) ReportsGet(  	minID string,  	limit int,  ) (*apimodel.PageableResponse, gtserror.WithCode) { -	reports, err := p.db.GetReports(ctx, resolved, accountID, targetAccountID, maxID, sinceID, minID, limit) +	reports, err := p.state.DB.GetReports(ctx, resolved, accountID, targetAccountID, maxID, sinceID, minID, limit)  	if err != nil {  		if err == db.ErrNoEntries {  			return util.EmptyPageableResponse(), nil @@ -95,7 +95,7 @@ func (p *Processor) ReportsGet(  // ReportGet returns one report, with the given ID.  func (p *Processor) ReportGet(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.AdminReport, gtserror.WithCode) { -	report, err := p.db.GetReportByID(ctx, id) +	report, err := p.state.DB.GetReportByID(ctx, id)  	if err != nil {  		if err == db.ErrNoEntries {  			return nil, gtserror.NewErrorNotFound(err) @@ -113,7 +113,7 @@ func (p *Processor) ReportGet(ctx context.Context, account *gtsmodel.Account, id  // ReportResolve marks a report with the given id as resolved, and stores the provided actionTakenComment (if not null).  func (p *Processor) ReportResolve(ctx context.Context, account *gtsmodel.Account, id string, actionTakenComment *string) (*apimodel.AdminReport, gtserror.WithCode) { -	report, err := p.db.GetReportByID(ctx, id) +	report, err := p.state.DB.GetReportByID(ctx, id)  	if err != nil {  		if err == db.ErrNoEntries {  			return nil, gtserror.NewErrorNotFound(err) @@ -134,7 +134,7 @@ func (p *Processor) ReportResolve(ctx context.Context, account *gtsmodel.Account  		columns = append(columns, "action_taken")  	} -	updatedReport, err := p.db.UpdateReport(ctx, report, columns...) +	updatedReport, err := p.state.DB.UpdateReport(ctx, report, columns...)  	if err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} diff --git a/internal/processing/app.go b/internal/processing/app.go index f2a938b22..e4cda5a43 100644 --- a/internal/processing/app.go +++ b/internal/processing/app.go @@ -62,7 +62,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api  	}  	// chuck it in the db -	if err := p.db.Put(ctx, app); err != nil { +	if err := p.state.DB.Put(ctx, app); err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} @@ -76,7 +76,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api  	}  	// chuck it in the db -	if err := p.db.Put(ctx, oc); err != nil { +	if err := p.state.DB.Put(ctx, oc); err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} diff --git a/internal/processing/blocks.go b/internal/processing/blocks.go index 6dd9c3de9..754954f02 100644 --- a/internal/processing/blocks.go +++ b/internal/processing/blocks.go @@ -31,7 +31,7 @@ import (  )  func (p *Processor) BlocksGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) { -	accounts, nextMaxID, prevMinID, err := p.db.GetAccountBlocks(ctx, authed.Account.ID, maxID, sinceID, limit) +	accounts, nextMaxID, prevMinID, err := p.state.DB.GetAccountBlocks(ctx, authed.Account.ID, maxID, sinceID, limit)  	if err != nil {  		if err == db.ErrNoEntries {  			// there are just no entries diff --git a/internal/processing/fedi/collections.go b/internal/processing/fedi/collections.go index 78a65bebe..627511c3b 100644 --- a/internal/processing/fedi/collections.go +++ b/internal/processing/fedi/collections.go @@ -84,8 +84,8 @@ func (p *Processor) OutboxGet(ctx context.Context, requestedUsername string, pag  	// scenario 2 -- get the requested page  	// limit pages to 30 entries per page -	publicStatuses, err := p.db.GetAccountStatuses(ctx, requestedAccount.ID, 30, true, true, maxID, minID, false, true) -	if err != nil && err != db.ErrNoEntries { +	publicStatuses, err := p.state.DB.GetAccountStatuses(ctx, requestedAccount.ID, 30, true, true, maxID, minID, false, true) +	if err != nil && !errors.Is(err, db.ErrNoEntries) {  		return nil, gtserror.NewErrorInternalError(err)  	} @@ -161,7 +161,7 @@ func (p *Processor) FeaturedCollectionGet(ctx context.Context, requestedUsername  		return nil, errWithCode  	} -	statuses, err := p.db.GetAccountPinnedStatuses(ctx, requestedAccount.ID) +	statuses, err := p.state.DB.GetAccountPinnedStatuses(ctx, requestedAccount.ID)  	if err != nil {  		if !errors.Is(err, db.ErrNoEntries) {  			return nil, gtserror.NewErrorInternalError(err) diff --git a/internal/processing/fedi/common.go b/internal/processing/fedi/common.go index 37c604ded..a2c7f9b37 100644 --- a/internal/processing/fedi/common.go +++ b/internal/processing/fedi/common.go @@ -29,7 +29,7 @@ import (  )  func (p *Processor) authenticate(ctx context.Context, requestedUsername string) (requestedAccount, requestingAccount *gtsmodel.Account, errWithCode gtserror.WithCode) { -	requestedAccount, err := p.db.GetAccountByUsernameDomain(ctx, requestedUsername, "") +	requestedAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUsername, "")  	if err != nil {  		errWithCode = gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))  		return @@ -46,7 +46,7 @@ func (p *Processor) authenticate(ctx context.Context, requestedUsername string)  		return  	} -	blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) +	blocked, err := p.state.DB.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)  	if err != nil {  		errWithCode = gtserror.NewErrorInternalError(err)  		return diff --git a/internal/processing/fedi/emoji.go b/internal/processing/fedi/emoji.go index 0b1dd3440..b2618ca13 100644 --- a/internal/processing/fedi/emoji.go +++ b/internal/processing/fedi/emoji.go @@ -32,7 +32,7 @@ func (p *Processor) EmojiGet(ctx context.Context, requestedEmojiID string) (inte  		return nil, errWithCode  	} -	requestedEmoji, err := p.db.GetEmojiByID(ctx, requestedEmojiID) +	requestedEmoji, err := p.state.DB.GetEmojiByID(ctx, requestedEmojiID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting emoji with id %s: %s", requestedEmojiID, err))  	} diff --git a/internal/processing/fedi/fedi.go b/internal/processing/fedi/fedi.go index e72d037f5..c8f78c5a6 100644 --- a/internal/processing/fedi/fedi.go +++ b/internal/processing/fedi/fedi.go @@ -19,25 +19,25 @@  package fedi  import ( -	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/federation" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/internal/visibility"  )  type Processor struct { -	db        db.DB +	state     *state.State  	federator federation.Federator  	tc        typeutils.TypeConverter  	filter    visibility.Filter  }  // New returns a new fedi processor. -func New(db db.DB, tc typeutils.TypeConverter, federator federation.Federator) Processor { +func New(state *state.State, tc typeutils.TypeConverter, federator federation.Federator) Processor {  	return Processor{ -		db:        db, +		state:     state,  		federator: federator,  		tc:        tc, -		filter:    visibility.NewFilter(db), +		filter:    visibility.NewFilter(state.DB),  	}  } diff --git a/internal/processing/fedi/status.go b/internal/processing/fedi/status.go index fbadcb290..60ebb3c84 100644 --- a/internal/processing/fedi/status.go +++ b/internal/processing/fedi/status.go @@ -36,7 +36,7 @@ func (p *Processor) StatusGet(ctx context.Context, requestedUsername string, req  		return nil, errWithCode  	} -	status, err := p.db.GetStatusByID(ctx, requestedStatusID) +	status, err := p.state.DB.GetStatusByID(ctx, requestedStatusID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(err)  	} @@ -74,7 +74,7 @@ func (p *Processor) StatusRepliesGet(ctx context.Context, requestedUsername stri  		return nil, errWithCode  	} -	status, err := p.db.GetStatusByID(ctx, requestedStatusID) +	status, err := p.state.DB.GetStatusByID(ctx, requestedStatusID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(err)  	} @@ -125,7 +125,7 @@ func (p *Processor) StatusRepliesGet(ctx context.Context, requestedUsername stri  	default:  		// scenario 3  		// get immediate children -		replies, err := p.db.GetStatusChildren(ctx, status, true, minID) +		replies, err := p.state.DB.GetStatusChildren(ctx, status, true, minID)  		if err != nil {  			return nil, gtserror.NewErrorInternalError(err)  		} diff --git a/internal/processing/fedi/user.go b/internal/processing/fedi/user.go index 899d063d1..35e756e57 100644 --- a/internal/processing/fedi/user.go +++ b/internal/processing/fedi/user.go @@ -34,7 +34,7 @@ import (  // before returning a JSON serializable interface to the caller.  func (p *Processor) UserGet(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) {  	// Get the instance-local account the request is referring to. -	requestedAccount, err := p.db.GetAccountByUsernameDomain(ctx, requestedUsername, "") +	requestedAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUsername, "")  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))  	} @@ -63,7 +63,7 @@ func (p *Processor) UserGet(ctx context.Context, requestedUsername string, reque  				return nil, gtserror.NewErrorUnauthorized(err)  			} -			blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) +			blocked, err := p.state.DB.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)  			if err != nil {  				return nil, gtserror.NewErrorInternalError(err)  			} diff --git a/internal/processing/fedi/wellknown.go b/internal/processing/fedi/wellknown.go index 75ed34ec2..6f113ac5d 100644 --- a/internal/processing/fedi/wellknown.go +++ b/internal/processing/fedi/wellknown.go @@ -64,12 +64,12 @@ func (p *Processor) NodeInfoRelGet(ctx context.Context) (*apimodel.WellKnownResp  func (p *Processor) NodeInfoGet(ctx context.Context) (*apimodel.Nodeinfo, gtserror.WithCode) {  	host := config.GetHost() -	userCount, err := p.db.CountInstanceUsers(ctx, host) +	userCount, err := p.state.DB.CountInstanceUsers(ctx, host)  	if err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} -	postCount, err := p.db.CountInstanceStatuses(ctx, host) +	postCount, err := p.state.DB.CountInstanceStatuses(ctx, host)  	if err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} @@ -99,7 +99,7 @@ func (p *Processor) NodeInfoGet(ctx context.Context) (*apimodel.Nodeinfo, gtserr  // WebfingerGet handles the GET for a webfinger resource. Most commonly, it will be used for returning account lookups.  func (p *Processor) WebfingerGet(ctx context.Context, requestedUsername string) (*apimodel.WellKnownResponse, gtserror.WithCode) {  	// Get the local account the request is referring to. -	requestedAccount, err := p.db.GetAccountByUsernameDomain(ctx, requestedUsername, "") +	requestedAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUsername, "")  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))  	} diff --git a/internal/processing/followrequest.go b/internal/processing/followrequest.go index 1f1b7f3c2..9bd13cc0b 100644 --- a/internal/processing/followrequest.go +++ b/internal/processing/followrequest.go @@ -30,7 +30,7 @@ import (  )  func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) { -	frs, err := p.db.GetAccountFollowRequests(ctx, auth.Account.ID) +	frs, err := p.state.DB.GetAccountFollowRequests(ctx, auth.Account.ID)  	if err != nil {  		if err != db.ErrNoEntries {  			return nil, gtserror.NewErrorInternalError(err) @@ -40,7 +40,7 @@ func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]  	accts := []apimodel.Account{}  	for _, fr := range frs {  		if fr.Account == nil { -			frAcct, err := p.db.GetAccountByID(ctx, fr.AccountID) +			frAcct, err := p.state.DB.GetAccountByID(ctx, fr.AccountID)  			if err != nil {  				return nil, gtserror.NewErrorInternalError(err)  			} @@ -57,13 +57,13 @@ func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]  }  func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) { -	follow, err := p.db.AcceptFollowRequest(ctx, accountID, auth.Account.ID) +	follow, err := p.state.DB.AcceptFollowRequest(ctx, accountID, auth.Account.ID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(err)  	}  	if follow.Account == nil { -		followAccount, err := p.db.GetAccountByID(ctx, follow.AccountID) +		followAccount, err := p.state.DB.GetAccountByID(ctx, follow.AccountID)  		if err != nil {  			return nil, gtserror.NewErrorInternalError(err)  		} @@ -71,14 +71,14 @@ func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a  	}  	if follow.TargetAccount == nil { -		followTargetAccount, err := p.db.GetAccountByID(ctx, follow.TargetAccountID) +		followTargetAccount, err := p.state.DB.GetAccountByID(ctx, follow.TargetAccountID)  		if err != nil {  			return nil, gtserror.NewErrorInternalError(err)  		}  		follow.TargetAccount = followTargetAccount  	} -	p.clientWorker.Queue(messages.FromClientAPI{ +	p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  		APObjectType:   ap.ActivityFollow,  		APActivityType: ap.ActivityAccept,  		GTSModel:       follow, @@ -86,7 +86,7 @@ func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a  		TargetAccount:  follow.TargetAccount,  	}) -	gtsR, err := p.db.GetRelationship(ctx, auth.Account.ID, accountID) +	gtsR, err := p.state.DB.GetRelationship(ctx, auth.Account.ID, accountID)  	if err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} @@ -100,13 +100,13 @@ func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a  }  func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) { -	followRequest, err := p.db.RejectFollowRequest(ctx, accountID, auth.Account.ID) +	followRequest, err := p.state.DB.RejectFollowRequest(ctx, accountID, auth.Account.ID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(err)  	}  	if followRequest.Account == nil { -		a, err := p.db.GetAccountByID(ctx, followRequest.AccountID) +		a, err := p.state.DB.GetAccountByID(ctx, followRequest.AccountID)  		if err != nil {  			return nil, gtserror.NewErrorInternalError(err)  		} @@ -114,14 +114,14 @@ func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, a  	}  	if followRequest.TargetAccount == nil { -		a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID) +		a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)  		if err != nil {  			return nil, gtserror.NewErrorInternalError(err)  		}  		followRequest.TargetAccount = a  	} -	p.clientWorker.Queue(messages.FromClientAPI{ +	p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  		APObjectType:   ap.ActivityFollow,  		APActivityType: ap.ActivityReject,  		GTSModel:       followRequest, @@ -129,7 +129,7 @@ func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, a  		TargetAccount:  followRequest.TargetAccount,  	}) -	gtsR, err := p.db.GetRelationship(ctx, auth.Account.ID, accountID) +	gtsR, err := p.state.DB.GetRelationship(ctx, auth.Account.ID, accountID)  	if err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} diff --git a/internal/processing/fromclientapi.go b/internal/processing/fromclientapi.go index 701f425f6..209a27105 100644 --- a/internal/processing/fromclientapi.go +++ b/internal/processing/fromclientapi.go @@ -143,7 +143,7 @@ func (p *Processor) processCreateAccountFromClientAPI(ctx context.Context, clien  	}  	// get the user this account belongs to -	user, err := p.db.GetUserByAccountID(ctx, account.ID) +	user, err := p.state.DB.GetUserByAccountID(ctx, account.ID)  	if err != nil {  		return err  	} @@ -293,7 +293,7 @@ func (p *Processor) processUndoAnnounceFromClientAPI(ctx context.Context, client  		return errors.New("undo was not parseable as *gtsmodel.Status")  	} -	if err := p.db.DeleteStatusByID(ctx, boost.ID); err != nil { +	if err := p.state.DB.DeleteStatusByID(ctx, boost.ID); err != nil {  		return err  	} @@ -422,7 +422,7 @@ func (p *Processor) federateStatus(ctx context.Context, status *gtsmodel.Status)  	}  	if status.Account == nil { -		statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID) +		statusAccount, err := p.state.DB.GetAccountByID(ctx, status.AccountID)  		if err != nil {  			return fmt.Errorf("federateStatus: error fetching status author account: %s", err)  		} @@ -455,7 +455,7 @@ func (p *Processor) federateStatus(ctx context.Context, status *gtsmodel.Status)  func (p *Processor) federateStatusDelete(ctx context.Context, status *gtsmodel.Status) error {  	if status.Account == nil { -		statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID) +		statusAccount, err := p.state.DB.GetAccountByID(ctx, status.AccountID)  		if err != nil {  			return fmt.Errorf("federateStatusDelete: error fetching status author account: %s", err)  		} @@ -642,7 +642,7 @@ func (p *Processor) federateUnannounce(ctx context.Context, boost *gtsmodel.Stat  func (p *Processor) federateAcceptFollowRequest(ctx context.Context, follow *gtsmodel.Follow) error {  	if follow.Account == nil { -		a, err := p.db.GetAccountByID(ctx, follow.AccountID) +		a, err := p.state.DB.GetAccountByID(ctx, follow.AccountID)  		if err != nil {  			return err  		} @@ -651,7 +651,7 @@ func (p *Processor) federateAcceptFollowRequest(ctx context.Context, follow *gts  	originAccount := follow.Account  	if follow.TargetAccount == nil { -		a, err := p.db.GetAccountByID(ctx, follow.TargetAccountID) +		a, err := p.state.DB.GetAccountByID(ctx, follow.TargetAccountID)  		if err != nil {  			return err  		} @@ -715,7 +715,7 @@ func (p *Processor) federateAcceptFollowRequest(ctx context.Context, follow *gts  func (p *Processor) federateRejectFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest) error {  	if followRequest.Account == nil { -		a, err := p.db.GetAccountByID(ctx, followRequest.AccountID) +		a, err := p.state.DB.GetAccountByID(ctx, followRequest.AccountID)  		if err != nil {  			return err  		} @@ -724,7 +724,7 @@ func (p *Processor) federateRejectFollowRequest(ctx context.Context, followReque  	originAccount := followRequest.Account  	if followRequest.TargetAccount == nil { -		a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID) +		a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)  		if err != nil {  			return err  		} @@ -844,7 +844,7 @@ func (p *Processor) federateAccountUpdate(ctx context.Context, updatedAccount *g  func (p *Processor) federateBlock(ctx context.Context, block *gtsmodel.Block) error {  	if block.Account == nil { -		blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID) +		blockAccount, err := p.state.DB.GetAccountByID(ctx, block.AccountID)  		if err != nil {  			return fmt.Errorf("federateBlock: error getting block account from database: %s", err)  		} @@ -852,7 +852,7 @@ func (p *Processor) federateBlock(ctx context.Context, block *gtsmodel.Block) er  	}  	if block.TargetAccount == nil { -		blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID) +		blockTargetAccount, err := p.state.DB.GetAccountByID(ctx, block.TargetAccountID)  		if err != nil {  			return fmt.Errorf("federateBlock: error getting block target account from database: %s", err)  		} @@ -880,7 +880,7 @@ func (p *Processor) federateBlock(ctx context.Context, block *gtsmodel.Block) er  func (p *Processor) federateUnblock(ctx context.Context, block *gtsmodel.Block) error {  	if block.Account == nil { -		blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID) +		blockAccount, err := p.state.DB.GetAccountByID(ctx, block.AccountID)  		if err != nil {  			return fmt.Errorf("federateUnblock: error getting block account from database: %s", err)  		} @@ -888,7 +888,7 @@ func (p *Processor) federateUnblock(ctx context.Context, block *gtsmodel.Block)  	}  	if block.TargetAccount == nil { -		blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID) +		blockTargetAccount, err := p.state.DB.GetAccountByID(ctx, block.TargetAccountID)  		if err != nil {  			return fmt.Errorf("federateUnblock: error getting block target account from database: %s", err)  		} @@ -934,7 +934,7 @@ func (p *Processor) federateUnblock(ctx context.Context, block *gtsmodel.Block)  func (p *Processor) federateReport(ctx context.Context, report *gtsmodel.Report) error {  	if report.TargetAccount == nil { -		reportTargetAccount, err := p.db.GetAccountByID(ctx, report.TargetAccountID) +		reportTargetAccount, err := p.state.DB.GetAccountByID(ctx, report.TargetAccountID)  		if err != nil {  			return fmt.Errorf("federateReport: error getting report target account from database: %w", err)  		} @@ -942,7 +942,7 @@ func (p *Processor) federateReport(ctx context.Context, report *gtsmodel.Report)  	}  	if len(report.StatusIDs) > 0 && len(report.Statuses) == 0 { -		statuses, err := p.db.GetStatuses(ctx, report.StatusIDs) +		statuses, err := p.state.DB.GetStatuses(ctx, report.StatusIDs)  		if err != nil {  			return fmt.Errorf("federateReport: error getting report statuses from database: %w", err)  		} @@ -966,7 +966,7 @@ func (p *Processor) federateReport(ctx context.Context, report *gtsmodel.Report)  	// deliver the flag using the outbox of the  	// instance account to anonymize the report -	instanceAccount, err := p.db.GetInstanceAccount(ctx, "") +	instanceAccount, err := p.state.DB.GetInstanceAccount(ctx, "")  	if err != nil {  		return fmt.Errorf("federateReport: error getting instance account: %w", err)  	} diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go index 3e4c62c6c..f9e732732 100644 --- a/internal/processing/fromcommon.go +++ b/internal/processing/fromcommon.go @@ -38,7 +38,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e  	if status.Mentions == nil {  		// there are mentions but they're not fully populated on the status yet so do this -		menchies, err := p.db.GetMentions(ctx, status.MentionIDs) +		menchies, err := p.state.DB.GetMentions(ctx, status.MentionIDs)  		if err != nil {  			return fmt.Errorf("notifyStatus: error getting mentions for status %s from the db: %s", status.ID, err)  		} @@ -49,7 +49,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e  	for _, m := range status.Mentions {  		// make sure this is a local account, otherwise we don't need to create a notification for it  		if m.TargetAccount == nil { -			a, err := p.db.GetAccountByID(ctx, m.TargetAccountID) +			a, err := p.state.DB.GetAccountByID(ctx, m.TargetAccountID)  			if err != nil {  				// we don't have the account or there's been an error  				return fmt.Errorf("notifyStatus: error getting account with id %s from the db: %s", m.TargetAccountID, err) @@ -62,7 +62,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e  		}  		// make sure a notif doesn't already exist for this mention -		if err := p.db.GetWhere(ctx, []db.Where{ +		if err := p.state.DB.GetWhere(ctx, []db.Where{  			{Key: "notification_type", Value: gtsmodel.NotificationMention},  			{Key: "target_account_id", Value: m.TargetAccountID},  			{Key: "origin_account_id", Value: m.OriginAccountID}, @@ -87,7 +87,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e  			Status:           status,  		} -		if err := p.db.Put(ctx, notif); err != nil { +		if err := p.state.DB.Put(ctx, notif); err != nil {  			return fmt.Errorf("notifyStatus: error putting notification in database: %s", err)  		} @@ -108,7 +108,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e  func (p *Processor) notifyFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest) error {  	// make sure we have the target account pinned on the follow request  	if followRequest.TargetAccount == nil { -		a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID) +		a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)  		if err != nil {  			return err  		} @@ -129,7 +129,7 @@ func (p *Processor) notifyFollowRequest(ctx context.Context, followRequest *gtsm  		OriginAccountID:  followRequest.AccountID,  	} -	if err := p.db.Put(ctx, notif); err != nil { +	if err := p.state.DB.Put(ctx, notif); err != nil {  		return fmt.Errorf("notifyFollowRequest: error putting notification in database: %s", err)  	} @@ -153,7 +153,7 @@ func (p *Processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, t  	}  	// first remove the follow request notification -	if err := p.db.DeleteWhere(ctx, []db.Where{ +	if err := p.state.DB.DeleteWhere(ctx, []db.Where{  		{Key: "notification_type", Value: gtsmodel.NotificationFollowRequest},  		{Key: "target_account_id", Value: follow.TargetAccountID},  		{Key: "origin_account_id", Value: follow.AccountID}, @@ -170,7 +170,7 @@ func (p *Processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, t  		OriginAccountID:  follow.AccountID,  		OriginAccount:    follow.Account,  	} -	if err := p.db.Put(ctx, notif); err != nil { +	if err := p.state.DB.Put(ctx, notif); err != nil {  		return fmt.Errorf("notifyFollow: error putting notification in database: %s", err)  	} @@ -194,7 +194,7 @@ func (p *Processor) notifyFave(ctx context.Context, fave *gtsmodel.StatusFave) e  	}  	if fave.TargetAccount == nil { -		a, err := p.db.GetAccountByID(ctx, fave.TargetAccountID) +		a, err := p.state.DB.GetAccountByID(ctx, fave.TargetAccountID)  		if err != nil {  			return err  		} @@ -218,7 +218,7 @@ func (p *Processor) notifyFave(ctx context.Context, fave *gtsmodel.StatusFave) e  		Status:           fave.Status,  	} -	if err := p.db.Put(ctx, notif); err != nil { +	if err := p.state.DB.Put(ctx, notif); err != nil {  		return fmt.Errorf("notifyFave: error putting notification in database: %s", err)  	} @@ -242,7 +242,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)  	}  	if status.BoostOf == nil { -		boostedStatus, err := p.db.GetStatusByID(ctx, status.BoostOfID) +		boostedStatus, err := p.state.DB.GetStatusByID(ctx, status.BoostOfID)  		if err != nil {  			return fmt.Errorf("notifyAnnounce: error getting status with id %s: %s", status.BoostOfID, err)  		} @@ -250,7 +250,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)  	}  	if status.BoostOfAccount == nil { -		boostedAcct, err := p.db.GetAccountByID(ctx, status.BoostOfAccountID) +		boostedAcct, err := p.state.DB.GetAccountByID(ctx, status.BoostOfAccountID)  		if err != nil {  			return fmt.Errorf("notifyAnnounce: error getting account with id %s: %s", status.BoostOfAccountID, err)  		} @@ -269,7 +269,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)  	}  	// make sure a notif doesn't already exist for this announce -	err := p.db.GetWhere(ctx, []db.Where{ +	err := p.state.DB.GetWhere(ctx, []db.Where{  		{Key: "notification_type", Value: gtsmodel.NotificationReblog},  		{Key: "target_account_id", Value: status.BoostOfAccountID},  		{Key: "origin_account_id", Value: status.AccountID}, @@ -292,7 +292,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)  		Status:           status,  	} -	if err := p.db.Put(ctx, notif); err != nil { +	if err := p.state.DB.Put(ctx, notif); err != nil {  		return fmt.Errorf("notifyAnnounce: error putting notification in database: %s", err)  	} @@ -314,7 +314,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)  func (p *Processor) timelineStatus(ctx context.Context, status *gtsmodel.Status) error {  	// make sure the author account is pinned onto the status  	if status.Account == nil { -		a, err := p.db.GetAccountByID(ctx, status.AccountID) +		a, err := p.state.DB.GetAccountByID(ctx, status.AccountID)  		if err != nil {  			return fmt.Errorf("timelineStatus: error getting author account with id %s: %s", status.AccountID, err)  		} @@ -322,7 +322,7 @@ func (p *Processor) timelineStatus(ctx context.Context, status *gtsmodel.Status)  	}  	// get local followers of the account that posted the status -	follows, err := p.db.GetAccountFollowedBy(ctx, status.AccountID, true) +	follows, err := p.state.DB.GetAccountFollowedBy(ctx, status.AccountID, true)  	if err != nil {  		return fmt.Errorf("timelineStatus: error getting followers for account id %s: %s", status.AccountID, err)  	} @@ -374,7 +374,7 @@ func (p *Processor) timelineStatusForAccount(ctx context.Context, status *gtsmod  	defer wg.Done()  	// get the timeline owner account -	timelineAccount, err := p.db.GetAccountByID(ctx, accountID) +	timelineAccount, err := p.state.DB.GetAccountByID(ctx, accountID)  	if err != nil {  		errors <- fmt.Errorf("timelineStatusForAccount: error getting account for timeline with id %s: %s", accountID, err)  		return @@ -446,28 +446,28 @@ func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Sta  	// delete all mention entries generated by this status  	for _, m := range statusToDelete.MentionIDs { -		if err := p.db.DeleteByID(ctx, m, >smodel.Mention{}); err != nil { +		if err := p.state.DB.DeleteByID(ctx, m, >smodel.Mention{}); err != nil {  			return err  		}  	}  	// delete all notification entries generated by this status -	if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil { +	if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil {  		return err  	}  	// delete all bookmarks that point to this status -	if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil { +	if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil {  		return err  	}  	// delete all boosts for this status + remove them from timelines -	if boosts, err := p.db.GetStatusReblogs(ctx, statusToDelete); err == nil { +	if boosts, err := p.state.DB.GetStatusReblogs(ctx, statusToDelete); err == nil {  		for _, b := range boosts {  			if err := p.deleteStatusFromTimelines(ctx, b); err != nil {  				return err  			} -			if err := p.db.DeleteStatusByID(ctx, b.ID); err != nil { +			if err := p.state.DB.DeleteStatusByID(ctx, b.ID); err != nil {  				return err  			}  		} @@ -479,7 +479,7 @@ func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Sta  	}  	// delete the status itself -	if err := p.db.DeleteStatusByID(ctx, statusToDelete.ID); err != nil { +	if err := p.state.DB.DeleteStatusByID(ctx, statusToDelete.ID); err != nil {  		return err  	} diff --git a/internal/processing/fromfederator.go b/internal/processing/fromfederator.go index eea3c529d..afddedf93 100644 --- a/internal/processing/fromfederator.go +++ b/internal/processing/fromfederator.go @@ -139,7 +139,7 @@ func (p *Processor) processCreateStatusFromFederator(ctx context.Context, federa  	// make sure the account is pinned  	if status.Account == nil { -		a, err := p.db.GetAccountByID(ctx, status.AccountID) +		a, err := p.state.DB.GetAccountByID(ctx, status.AccountID)  		if err != nil {  			return err  		} @@ -185,7 +185,7 @@ func (p *Processor) processCreateFaveFromFederator(ctx context.Context, federato  	// make sure the account is pinned  	if incomingFave.Account == nil { -		a, err := p.db.GetAccountByID(ctx, incomingFave.AccountID) +		a, err := p.state.DB.GetAccountByID(ctx, incomingFave.AccountID)  		if err != nil {  			return err  		} @@ -227,7 +227,7 @@ func (p *Processor) processCreateFollowRequestFromFederator(ctx context.Context,  	// make sure the account is pinned  	if followRequest.Account == nil { -		a, err := p.db.GetAccountByID(ctx, followRequest.AccountID) +		a, err := p.state.DB.GetAccountByID(ctx, followRequest.AccountID)  		if err != nil {  			return err  		} @@ -254,7 +254,7 @@ func (p *Processor) processCreateFollowRequestFromFederator(ctx context.Context,  	}  	if followRequest.TargetAccount == nil { -		a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID) +		a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)  		if err != nil {  			return err  		} @@ -267,7 +267,7 @@ func (p *Processor) processCreateFollowRequestFromFederator(ctx context.Context,  	}  	// if the target account isn't locked, we should already accept the follow and notify about the new follower instead -	follow, err := p.db.AcceptFollowRequest(ctx, followRequest.AccountID, followRequest.TargetAccountID) +	follow, err := p.state.DB.AcceptFollowRequest(ctx, followRequest.AccountID, followRequest.TargetAccountID)  	if err != nil {  		return err  	} @@ -288,7 +288,7 @@ func (p *Processor) processCreateAnnounceFromFederator(ctx context.Context, fede  	// make sure the account is pinned  	if incomingAnnounce.Account == nil { -		a, err := p.db.GetAccountByID(ctx, incomingAnnounce.AccountID) +		a, err := p.state.DB.GetAccountByID(ctx, incomingAnnounce.AccountID)  		if err != nil {  			return err  		} @@ -324,7 +324,7 @@ func (p *Processor) processCreateAnnounceFromFederator(ctx context.Context, fede  	}  	incomingAnnounce.ID = incomingAnnounceID -	if err := p.db.PutStatus(ctx, incomingAnnounce); err != nil { +	if err := p.state.DB.PutStatus(ctx, incomingAnnounce); err != nil {  		return fmt.Errorf("error adding dereferenced announce to the db: %s", err)  	} diff --git a/internal/processing/fromfederator_test.go b/internal/processing/fromfederator_test.go index 6913b22af..d8f8ad6e1 100644 --- a/internal/processing/fromfederator_test.go +++ b/internal/processing/fromfederator_test.go @@ -344,7 +344,6 @@ func (suite *FromFederatorTestSuite) TestProcessAccountDelete() {  	suite.NoError(err)  	// now they are mufos! -  	err = suite.processor.ProcessFromFederator(ctx, messages.FromFederator{  		APObjectType:     ap.ObjectProfile,  		APActivityType:   ap.ActivityDelete, diff --git a/internal/processing/instance.go b/internal/processing/instance.go index c3dc4dcea..3ca807af3 100644 --- a/internal/processing/instance.go +++ b/internal/processing/instance.go @@ -35,7 +35,7 @@ import (  func (p *Processor) getThisInstance(ctx context.Context) (*gtsmodel.Instance, error) {  	i := >smodel.Instance{} -	if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: config.GetHost()}}, i); err != nil { +	if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: config.GetHost()}}, i); err != nil {  		return nil, err  	}  	return i, nil @@ -73,7 +73,7 @@ func (p *Processor) InstancePeersGet(ctx context.Context, includeSuspended bool,  	domains := []*apimodel.Domain{}  	if includeOpen { -		instances, err := p.db.GetInstancePeers(ctx, false) +		instances, err := p.state.DB.GetInstancePeers(ctx, false)  		if err != nil && err != db.ErrNoEntries {  			err = fmt.Errorf("error selecting instance peers: %s", err)  			return nil, gtserror.NewErrorInternalError(err) @@ -87,7 +87,7 @@ func (p *Processor) InstancePeersGet(ctx context.Context, includeSuspended bool,  	if includeSuspended {  		domainBlocks := []*gtsmodel.DomainBlock{} -		if err := p.db.GetAll(ctx, &domainBlocks); err != nil && err != db.ErrNoEntries { +		if err := p.state.DB.GetAll(ctx, &domainBlocks); err != nil && err != db.ErrNoEntries {  			return nil, gtserror.NewErrorInternalError(err)  		} @@ -124,12 +124,12 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe  	// fetch the instance entry from the db for processing  	i := >smodel.Instance{}  	host := config.GetHost() -	if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, i); err != nil { +	if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, i); err != nil {  		return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", host, err))  	}  	// fetch the instance account from the db for processing -	ia, err := p.db.GetInstanceAccount(ctx, "") +	ia, err := p.state.DB.GetInstanceAccount(ctx, "")  	if err != nil {  		return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance account %s: %s", host, err))  	} @@ -148,12 +148,12 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe  	// validate & update site contact account if it's set on the form  	if form.ContactUsername != nil {  		// make sure the account with the given username exists in the db -		contactAccount, err := p.db.GetAccountByUsernameDomain(ctx, *form.ContactUsername, "") +		contactAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, *form.ContactUsername, "")  		if err != nil {  			return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("account with username %s not retrievable", *form.ContactUsername))  		}  		// make sure it has a user associated with it -		contactUser, err := p.db.GetUserByAccountID(ctx, contactAccount.ID) +		contactUser, err := p.state.DB.GetUserByAccountID(ctx, contactAccount.ID)  		if err != nil {  			return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("user for account with username %s not retrievable", *form.ContactUsername))  		} @@ -233,7 +233,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe  	} else if form.AvatarDescription != nil && ia.AvatarMediaAttachment != nil {  		// process just the description for the existing avatar  		ia.AvatarMediaAttachment.Description = *form.AvatarDescription -		if err := p.db.UpdateByID(ctx, ia.AvatarMediaAttachment, ia.AvatarMediaAttachmentID, "description"); err != nil { +		if err := p.state.DB.UpdateByID(ctx, ia.AvatarMediaAttachment, ia.AvatarMediaAttachmentID, "description"); err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance avatar description: %s", err))  		}  	} @@ -252,13 +252,13 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe  	if updateInstanceAccount {  		// if either avatar or header is updated, we need  		// to update the instance account that stores them -		if err := p.db.UpdateAccount(ctx, ia); err != nil { +		if err := p.state.DB.UpdateAccount(ctx, ia); err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance account: %s", err))  		}  	}  	if len(updatingColumns) != 0 { -		if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil { +		if err := p.state.DB.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", host, err))  		}  	} diff --git a/internal/processing/media/delete.go b/internal/processing/media/delete.go index 6507fcae4..02bd6cd0d 100644 --- a/internal/processing/media/delete.go +++ b/internal/processing/media/delete.go @@ -13,7 +13,7 @@ import (  // Delete deletes the media attachment with the given ID, including all files pertaining to that attachment.  func (p *Processor) Delete(ctx context.Context, mediaAttachmentID string) gtserror.WithCode { -	attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) +	attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID)  	if err != nil {  		if err == db.ErrNoEntries {  			// attachment already gone @@ -27,20 +27,20 @@ func (p *Processor) Delete(ctx context.Context, mediaAttachmentID string) gtserr  	// delete the thumbnail from storage  	if attachment.Thumbnail.Path != "" { -		if err := p.storage.Delete(ctx, attachment.Thumbnail.Path); err != nil && !errors.Is(err, storage.ErrNotFound) { +		if err := p.state.Storage.Delete(ctx, attachment.Thumbnail.Path); err != nil && !errors.Is(err, storage.ErrNotFound) {  			errs = append(errs, fmt.Sprintf("remove thumbnail at path %s: %s", attachment.Thumbnail.Path, err))  		}  	}  	// delete the file from storage  	if attachment.File.Path != "" { -		if err := p.storage.Delete(ctx, attachment.File.Path); err != nil && !errors.Is(err, storage.ErrNotFound) { +		if err := p.state.Storage.Delete(ctx, attachment.File.Path); err != nil && !errors.Is(err, storage.ErrNotFound) {  			errs = append(errs, fmt.Sprintf("remove file at path %s: %s", attachment.File.Path, err))  		}  	}  	// delete the attachment -	if err := p.db.DeleteByID(ctx, mediaAttachmentID, attachment); err != nil && !errors.Is(err, db.ErrNoEntries) { +	if err := p.state.DB.DeleteByID(ctx, mediaAttachmentID, attachment); err != nil && !errors.Is(err, db.ErrNoEntries) {  		errs = append(errs, fmt.Sprintf("remove attachment: %s", err))  	} diff --git a/internal/processing/media/getemoji.go b/internal/processing/media/getemoji.go index 4c0ce9930..fba059f60 100644 --- a/internal/processing/media/getemoji.go +++ b/internal/processing/media/getemoji.go @@ -31,7 +31,7 @@ import (  // GetCustomEmojis returns a list of all useable local custom emojis stored on this instance.  // 'useable' in this context means visible and picker, and not disabled.  func (p *Processor) GetCustomEmojis(ctx context.Context) ([]*apimodel.Emoji, gtserror.WithCode) { -	emojis, err := p.db.GetUseableEmojis(ctx) +	emojis, err := p.state.DB.GetUseableEmojis(ctx)  	if err != nil {  		if err != db.ErrNoEntries {  			return nil, gtserror.NewErrorNotFound(fmt.Errorf("db error retrieving custom emojis: %s", err)) diff --git a/internal/processing/media/getfile.go b/internal/processing/media/getfile.go index 2a4ef2097..f9c6c23c2 100644 --- a/internal/processing/media/getfile.go +++ b/internal/processing/media/getfile.go @@ -54,7 +54,7 @@ func (p *Processor) GetFile(ctx context.Context, requestingAccount *gtsmodel.Acc  	owningAccountID := form.AccountID  	// get the account that owns the media and make sure it's not suspended -	owningAccount, err := p.db.GetAccountByID(ctx, owningAccountID) +	owningAccount, err := p.state.DB.GetAccountByID(ctx, owningAccountID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("account with id %s could not be selected from the db: %s", owningAccountID, err))  	} @@ -64,7 +64,7 @@ func (p *Processor) GetFile(ctx context.Context, requestingAccount *gtsmodel.Acc  	// make sure the requesting account and the media account don't block each other  	if requestingAccount != nil { -		blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, owningAccountID, true) +		blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, owningAccountID, true)  		if err != nil {  			return nil, gtserror.NewErrorNotFound(fmt.Errorf("block status could not be established between accounts %s and %s: %s", owningAccountID, requestingAccount.ID, err))  		} @@ -117,7 +117,7 @@ func parseSize(s string) (media.Size, error) {  func (p *Processor) getAttachmentContent(ctx context.Context, requestingAccount *gtsmodel.Account, wantedMediaID string, owningAccountID string, mediaSize media.Size) (*apimodel.Content, gtserror.WithCode) {  	// retrieve attachment from the database and do basic checks on it -	a, err := p.db.GetAttachmentByID(ctx, wantedMediaID) +	a, err := p.state.DB.GetAttachmentByID(ctx, wantedMediaID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("attachment %s could not be taken from the db: %s", wantedMediaID, err))  	} @@ -209,7 +209,7 @@ func (p *Processor) getEmojiContent(ctx context.Context, fileName string, owning  	// so this is more reliable than using full size url  	imageStaticURL := uris.GenerateURIForAttachment(owningAccountID, string(media.TypeEmoji), string(media.SizeStatic), fileName, "png") -	e, err := p.db.GetEmojiByStaticURL(ctx, imageStaticURL) +	e, err := p.state.DB.GetEmojiByStaticURL(ctx, imageStaticURL)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("emoji %s could not be taken from the db: %s", fileName, err))  	} @@ -237,12 +237,12 @@ func (p *Processor) getEmojiContent(ctx context.Context, fileName string, owning  func (p *Processor) retrieveFromStorage(ctx context.Context, storagePath string, content *apimodel.Content) (*apimodel.Content, gtserror.WithCode) {  	// If running on S3 storage with proxying disabled then  	// just fetch a pre-signed URL instead of serving the content. -	if url := p.storage.URL(ctx, storagePath); url != nil { +	if url := p.state.Storage.URL(ctx, storagePath); url != nil {  		content.URL = url  		return content, nil  	} -	reader, err := p.storage.GetStream(ctx, storagePath) +	reader, err := p.state.Storage.GetStream(ctx, storagePath)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("error retrieving from storage: %s", err))  	} diff --git a/internal/processing/media/getmedia.go b/internal/processing/media/getmedia.go index 03d5ba770..dad6ac538 100644 --- a/internal/processing/media/getmedia.go +++ b/internal/processing/media/getmedia.go @@ -30,7 +30,7 @@ import (  )  func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) { -	attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) +	attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID)  	if err != nil {  		if err == db.ErrNoEntries {  			// attachment doesn't exist diff --git a/internal/processing/media/media.go b/internal/processing/media/media.go index ca95e276f..51585102a 100644 --- a/internal/processing/media/media.go +++ b/internal/processing/media/media.go @@ -19,28 +19,25 @@  package media  import ( -	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/storage" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/transport"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  )  type Processor struct { +	state               *state.State  	tc                  typeutils.TypeConverter  	mediaManager        media.Manager  	transportController transport.Controller -	storage             *storage.Driver -	db                  db.DB  }  // New returns a new media processor. -func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller, storage *storage.Driver) Processor { +func New(state *state.State, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller) Processor {  	return Processor{ +		state:               state,  		tc:                  tc,  		mediaManager:        mediaManager,  		transportController: transportController, -		storage:             storage, -		db:                  db,  	}  } diff --git a/internal/processing/media/media_test.go b/internal/processing/media/media_test.go index 1d223a66c..e706dbd7a 100644 --- a/internal/processing/media/media_test.go +++ b/internal/processing/media/media_test.go @@ -20,12 +20,11 @@ package media_test  import (  	"github.com/stretchr/testify/suite" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	mediaprocessing "github.com/superseriousbusiness/gotosocial/internal/processing/media" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/transport"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils" @@ -38,6 +37,7 @@ type MediaStandardTestSuite struct {  	db                  db.DB  	tc                  typeutils.TypeConverter  	storage             *storage.Driver +	state               state.State  	mediaManager        media.Manager  	transportController transport.Controller @@ -67,15 +67,19 @@ func (suite *MediaStandardTestSuite) SetupSuite() {  }  func (suite *MediaStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.tc = testrig.NewTestTypeConverter(suite.db)  	suite.storage = testrig.NewInMemoryStorage() -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.transportController = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, concurrency.NewWorkerPool[messages.FromFederator](-1, -1)) -	suite.mediaProcessor = mediaprocessing.New(suite.db, suite.tc, suite.mediaManager, suite.transportController, suite.storage) +	suite.state.Storage = suite.storage +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.transportController = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")) +	suite.mediaProcessor = mediaprocessing.New(&suite.state, suite.tc, suite.mediaManager, suite.transportController)  	testrig.StandardDBSetup(suite.db, nil)  	testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")  } diff --git a/internal/processing/media/unattach.go b/internal/processing/media/unattach.go index 816b5134e..7c6f7dbac 100644 --- a/internal/processing/media/unattach.go +++ b/internal/processing/media/unattach.go @@ -33,7 +33,7 @@ import (  // Unattach unattaches the media attachment with the given ID from any statuses it was attached to, making it available  // for reattachment again.  func (p *Processor) Unattach(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) { -	attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) +	attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID)  	if err != nil {  		if err == db.ErrNoEntries {  			return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db")) @@ -49,7 +49,7 @@ func (p *Processor) Unattach(ctx context.Context, account *gtsmodel.Account, med  	attachment.UpdatedAt = time.Now()  	attachment.StatusID = "" -	if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil { +	if err := p.state.DB.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("db error updating attachment: %s", err))  	} diff --git a/internal/processing/media/update.go b/internal/processing/media/update.go index c03df705b..cf49168f0 100644 --- a/internal/processing/media/update.go +++ b/internal/processing/media/update.go @@ -32,7 +32,7 @@ import (  // Update updates a media attachment with the given id, using the provided form parameters.  func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) { -	attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) +	attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID)  	if err != nil {  		if err == db.ErrNoEntries {  			// attachment doesn't exist @@ -62,7 +62,7 @@ func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, media  		updatingColumns = append(updatingColumns, "focus_x", "focus_y")  	} -	if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil { +	if err := p.state.DB.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil {  		return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error updating media: %s", err))  	} diff --git a/internal/processing/notification.go b/internal/processing/notification.go index 05d0e82ee..57100e743 100644 --- a/internal/processing/notification.go +++ b/internal/processing/notification.go @@ -29,7 +29,7 @@ import (  )  func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, excludeTypes []string, limit int, maxID string, sinceID string) (*apimodel.PageableResponse, gtserror.WithCode) { -	notifs, err := p.db.GetNotifications(ctx, authed.Account.ID, excludeTypes, limit, maxID, sinceID) +	notifs, err := p.state.DB.GetNotifications(ctx, authed.Account.ID, excludeTypes, limit, maxID, sinceID)  	if err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} @@ -72,7 +72,7 @@ func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, ex  }  func (p *Processor) NotificationsClear(ctx context.Context, authed *oauth.Auth) gtserror.WithCode { -	err := p.db.ClearNotifications(ctx, authed.Account.ID) +	err := p.state.DB.ClearNotifications(ctx, authed.Account.ID)  	if err != nil {  		return gtserror.NewErrorInternalError(err)  	} diff --git a/internal/processing/processor.go b/internal/processing/processor.go index 07fcdb8b3..bb75aab76 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -19,10 +19,11 @@  package processing  import ( -	"github.com/superseriousbusiness/gotosocial/internal/concurrency" -	"github.com/superseriousbusiness/gotosocial/internal/db" +	"context" +  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation" +	"github.com/superseriousbusiness/gotosocial/internal/log"  	mm "github.com/superseriousbusiness/gotosocial/internal/media"  	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/oauth" @@ -34,23 +35,19 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/processing/status"  	"github.com/superseriousbusiness/gotosocial/internal/processing/stream"  	"github.com/superseriousbusiness/gotosocial/internal/processing/user" -	"github.com/superseriousbusiness/gotosocial/internal/storage" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/timeline"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/internal/visibility"  )  type Processor struct { -	clientWorker *concurrency.WorkerPool[messages.FromClientAPI] -	fedWorker    *concurrency.WorkerPool[messages.FromFederator] -  	federator       federation.Federator  	tc              typeutils.TypeConverter  	oauthServer     oauth.Server  	mediaManager    mm.Manager -	storage         *storage.Driver  	statusTimelines timeline.Manager -	db              db.DB +	state           *state.State  	filter          visibility.Filter  	/* @@ -105,76 +102,65 @@ func NewProcessor(  	federator federation.Federator,  	oauthServer oauth.Server,  	mediaManager mm.Manager, -	storage *storage.Driver, -	db db.DB, +	state *state.State,  	emailSender email.Sender, -	clientWorker *concurrency.WorkerPool[messages.FromClientAPI], -	fedWorker *concurrency.WorkerPool[messages.FromFederator],  ) *Processor { -	parseMentionFunc := GetParseMentionFunc(db, federator) - -	filter := visibility.NewFilter(db) - -	return &Processor{ -		clientWorker: clientWorker, -		fedWorker:    fedWorker, - -		federator:       federator, -		tc:              tc, -		oauthServer:     oauthServer, -		mediaManager:    mediaManager, -		storage:         storage, -		statusTimelines: timeline.NewManager(StatusGrabFunction(db), StatusFilterFunction(db, filter), StatusPrepareFunction(db, tc), StatusSkipInsertFunction()), -		db:              db, -		filter:          filter, - -		// sub processors -		account: account.New(db, tc, mediaManager, oauthServer, clientWorker, federator, parseMentionFunc), -		admin:   admin.New(db, tc, mediaManager, federator.TransportController(), storage, clientWorker), -		fedi:    fedi.New(db, tc, federator), -		media:   media.New(db, tc, mediaManager, federator.TransportController(), storage), -		report:  report.New(db, tc, clientWorker), -		status:  status.New(db, tc, clientWorker, parseMentionFunc), -		stream:  stream.New(db, oauthServer), -		user:    user.New(db, emailSender), +	parseMentionFunc := GetParseMentionFunc(state.DB, federator) + +	filter := visibility.NewFilter(state.DB) + +	processor := &Processor{ +		federator:    federator, +		tc:           tc, +		oauthServer:  oauthServer, +		mediaManager: mediaManager, +		statusTimelines: timeline.NewManager( +			StatusGrabFunction(state.DB), +			StatusFilterFunction(state.DB, filter), +			StatusPrepareFunction(state.DB, tc), +			StatusSkipInsertFunction(), +		), +		state:  state, +		filter: filter,  	} -} -// Start starts the Processor, reading from its channels and passing messages back and forth. -func (p *Processor) Start() error { -	// Setup and start the client API worker pool -	p.clientWorker.SetProcessor(p.ProcessFromClientAPI) -	if err := p.clientWorker.Start(); err != nil { -		return err -	} +	// sub processors +	processor.account = account.New(state, tc, mediaManager, oauthServer, federator, parseMentionFunc) +	processor.admin = admin.New(state, tc, mediaManager, federator.TransportController()) +	processor.fedi = fedi.New(state, tc, federator) +	processor.media = media.New(state, tc, mediaManager, federator.TransportController()) +	processor.report = report.New(state, tc) +	processor.status = status.New(state, tc, parseMentionFunc) +	processor.stream = stream.New(state, oauthServer) +	processor.user = user.New(state, emailSender) + +	return processor +} -	// Setup and start the federator worker pool -	p.fedWorker.SetProcessor(p.ProcessFromFederator) -	if err := p.fedWorker.Start(); err != nil { -		return err -	} +func (p *Processor) EnqueueClientAPI(ctx context.Context, msg messages.FromClientAPI) { +	log.WithContext(ctx).WithField("msg", msg).Trace("enqueuing client API") +	_ = p.state.Workers.ClientAPI.MustEnqueueCtx(ctx, func(ctx context.Context) { +		if err := p.ProcessFromClientAPI(ctx, msg); err != nil { +			log.Errorf(ctx, "error processing client API message: %v", err) +		} +	}) +} -	// Start status timelines -	if err := p.statusTimelines.Start(); err != nil { -		return err -	} +func (p *Processor) EnqueueFederator(ctx context.Context, msg messages.FromFederator) { +	log.WithContext(ctx).WithField("msg", msg).Trace("enqueuing federator") +	_ = p.state.Workers.Federator.MustEnqueueCtx(ctx, func(ctx context.Context) { +		if err := p.ProcessFromFederator(ctx, msg); err != nil { +			log.Errorf(ctx, "error processing federator message: %v", err) +		} +	}) +} -	return nil +// Start starts the Processor. +func (p *Processor) Start() error { +	return p.statusTimelines.Start()  } -// Stop stops the processor cleanly, finishing handling any remaining messages before closing down. +// Stop stops the processor cleanly.  func (p *Processor) Stop() error { -	if err := p.clientWorker.Stop(); err != nil { -		return err -	} - -	if err := p.fedWorker.Stop(); err != nil { -		return err -	} - -	if err := p.statusTimelines.Stop(); err != nil { -		return err -	} - -	return nil +	return p.statusTimelines.Stop()  } diff --git a/internal/processing/processor_test.go b/internal/processing/processor_test.go index 44857cb47..d8da87bcc 100644 --- a/internal/processing/processor_test.go +++ b/internal/processing/processor_test.go @@ -20,15 +20,14 @@ package processing_test  import (  	"github.com/stretchr/testify/suite" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/oauth"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/transport"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils" @@ -40,6 +39,7 @@ type ProcessingStandardTestSuite struct {  	suite.Suite  	db                  db.DB  	storage             *storage.Driver +	state               state.State  	mediaManager        media.Manager  	typeconverter       typeutils.TypeConverter  	httpClient          *testrig.MockHTTPClient @@ -86,25 +86,29 @@ func (suite *ProcessingStandardTestSuite) SetupSuite() {  }  func (suite *ProcessingStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.testActivities = testrig.NewTestActivities(suite.testAccounts)  	suite.storage = testrig.NewInMemoryStorage() +	suite.state.Storage = suite.storage  	suite.typeconverter = testrig.NewTestTypeConverter(suite.db)  	suite.httpClient = testrig.NewMockHTTPClient(nil, "../../testrig/media") -	clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - -	suite.transportController = testrig.NewTestTransportController(suite.httpClient, suite.db, fedWorker) -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, suite.transportController, suite.storage, suite.mediaManager, fedWorker) +	suite.transportController = testrig.NewTestTransportController(&suite.state, suite.httpClient) +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager)  	suite.oauthServer = testrig.NewTestOauthServer(suite.db)  	suite.emailSender = testrig.NewEmailSender("../../web/template/", nil) -	suite.processor = processing.NewProcessor(suite.typeconverter, suite.federator, suite.oauthServer, suite.mediaManager, suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker) +	suite.processor = processing.NewProcessor(suite.typeconverter, suite.federator, suite.oauthServer, suite.mediaManager, &suite.state, suite.emailSender) +	suite.state.Workers.EnqueueClientAPI = suite.processor.EnqueueClientAPI +	suite.state.Workers.EnqueueFederator = suite.processor.EnqueueFederator  	testrig.StandardDBSetup(suite.db, suite.testAccounts)  	testrig.StandardStorageSetup(suite.storage, "../../testrig/media") @@ -119,4 +123,5 @@ func (suite *ProcessingStandardTestSuite) TearDownTest() {  	if err := suite.processor.Stop(); err != nil {  		panic(err)  	} +	testrig.StopWorkers(&suite.state)  } diff --git a/internal/processing/report/create.go b/internal/processing/report/create.go index 726d11666..e0918554e 100644 --- a/internal/processing/report/create.go +++ b/internal/processing/report/create.go @@ -41,7 +41,7 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form  	}  	// validate + fetch target account -	targetAccount, err := p.db.GetAccountByID(ctx, form.AccountID) +	targetAccount, err := p.state.DB.GetAccountByID(ctx, form.AccountID)  	if err != nil {  		if errors.Is(err, db.ErrNoEntries) {  			err = fmt.Errorf("account with ID %s does not exist", form.AccountID) @@ -52,7 +52,7 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form  	}  	// fetch statuses by IDs given in the report form (noop if no statuses given) -	statuses, err := p.db.GetStatuses(ctx, form.StatusIDs) +	statuses, err := p.state.DB.GetStatuses(ctx, form.StatusIDs)  	if err != nil {  		err = fmt.Errorf("db error fetching report target statuses: %w", err)  		return nil, gtserror.NewErrorInternalError(err) @@ -79,11 +79,11 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form  		Forwarded:       &form.Forward,  	} -	if err := p.db.PutReport(ctx, report); err != nil { +	if err := p.state.DB.PutReport(ctx, report); err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} -	p.clientWorker.Queue(messages.FromClientAPI{ +	p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  		APObjectType:   ap.ObjectProfile,  		APActivityType: ap.ActivityFlag,  		GTSModel:       report, diff --git a/internal/processing/report/get.go b/internal/processing/report/get.go index af2079b8a..0348c397c 100644 --- a/internal/processing/report/get.go +++ b/internal/processing/report/get.go @@ -32,7 +32,7 @@ import (  // Get returns the user view of a moderation report, with the given id.  func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.Report, gtserror.WithCode) { -	report, err := p.db.GetReportByID(ctx, id) +	report, err := p.state.DB.GetReportByID(ctx, id)  	if err != nil {  		if err == db.ErrNoEntries {  			return nil, gtserror.NewErrorNotFound(err) @@ -64,7 +64,7 @@ func (p *Processor) GetMultiple(  	minID string,  	limit int,  ) (*apimodel.PageableResponse, gtserror.WithCode) { -	reports, err := p.db.GetReports(ctx, resolved, account.ID, targetAccountID, maxID, sinceID, minID, limit) +	reports, err := p.state.DB.GetReports(ctx, resolved, account.ID, targetAccountID, maxID, sinceID, minID, limit)  	if err != nil {  		if err == db.ErrNoEntries {  			return util.EmptyPageableResponse(), nil diff --git a/internal/processing/report/report.go b/internal/processing/report/report.go index b5f4b301e..bc634af2e 100644 --- a/internal/processing/report/report.go +++ b/internal/processing/report/report.go @@ -19,22 +19,18 @@  package report  import ( -	"github.com/superseriousbusiness/gotosocial/internal/concurrency" -	"github.com/superseriousbusiness/gotosocial/internal/db" -	"github.com/superseriousbusiness/gotosocial/internal/messages" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  )  type Processor struct { -	db           db.DB -	tc           typeutils.TypeConverter -	clientWorker *concurrency.WorkerPool[messages.FromClientAPI] +	state *state.State +	tc    typeutils.TypeConverter  } -func New(db db.DB, tc typeutils.TypeConverter, clientWorker *concurrency.WorkerPool[messages.FromClientAPI]) Processor { +func New(state *state.State, tc typeutils.TypeConverter) Processor {  	return Processor{ -		tc:           tc, -		db:           db, -		clientWorker: clientWorker, +		state: state, +		tc:    tc,  	}  } diff --git a/internal/processing/search.go b/internal/processing/search.go index 05a1fe353..c5592fffd 100644 --- a/internal/processing/search.go +++ b/internal/processing/search.go @@ -88,7 +88,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a  	if username, domain, err := util.ExtractNamestringParts(maybeNamestring); err == nil {  		l.Trace("search term is a mention, looking it up...") -		blocked, err := p.db.IsDomainBlocked(ctx, domain) +		blocked, err := p.state.DB.IsDomainBlocked(ctx, domain)  		if err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking domain block: %w", err))  		} @@ -120,7 +120,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a  		if uri, err := url.Parse(query); err == nil {  			if uri.Scheme == "https" || uri.Scheme == "http" {  				l.Trace("search term is a uri, looking it up...") -				blocked, err := p.db.IsURIBlocked(ctx, uri) +				blocked, err := p.state.DB.IsURIBlocked(ctx, uri)  				if err != nil {  					return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking domain block: %w", err))  				} @@ -178,7 +178,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a  	*/  	for _, foundAccount := range foundAccounts {  		// make sure there's no block in either direction between the account and the requester -		blocked, err := p.db.IsBlocked(ctx, authed.Account.ID, foundAccount.ID, true) +		blocked, err := p.state.DB.IsBlocked(ctx, authed.Account.ID, foundAccount.ID, true)  		if err != nil {  			err = fmt.Errorf("SearchGet: error checking block between %s and %s: %s", authed.Account.ID, foundAccount.ID, err)  			return nil, gtserror.NewErrorInternalError(err) @@ -246,14 +246,14 @@ func (p *Processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth,  		)  		// Search the database for existing account with ID URI. -		account, err = p.db.GetAccountByURI(ctx, uriStr) +		account, err = p.state.DB.GetAccountByURI(ctx, uriStr)  		if err != nil && !errors.Is(err, db.ErrNoEntries) {  			return nil, fmt.Errorf("searchAccountByURI: error checking database for account %s: %w", uriStr, err)  		}  		if account == nil {  			// Else, search the database for existing by ID URL. -			account, err = p.db.GetAccountByURL(ctx, uriStr) +			account, err = p.state.DB.GetAccountByURL(ctx, uriStr)  			if err != nil {  				if !errors.Is(err, db.ErrNoEntries) {  					return nil, fmt.Errorf("searchAccountByURI: error checking database for account %s: %w", uriStr, err) @@ -281,7 +281,7 @@ func (p *Processor) searchAccountByUsernameDomain(ctx context.Context, authed *o  		}  		// Search the database for existing account with USERNAME@DOMAIN -		account, err := p.db.GetAccountByUsernameDomain(ctx, username, domain) +		account, err := p.state.DB.GetAccountByUsernameDomain(ctx, username, domain)  		if err != nil {  			if !errors.Is(err, db.ErrNoEntries) {  				return nil, fmt.Errorf("searchAccountByUsernameDomain: error checking database for account %s@%s: %w", username, domain, err) diff --git a/internal/processing/status/bookmark.go b/internal/processing/status/bookmark.go index dde31ea7d..cf3787da2 100644 --- a/internal/processing/status/bookmark.go +++ b/internal/processing/status/bookmark.go @@ -32,7 +32,7 @@ import (  // BookmarkCreate adds a bookmark for the requestingAccount, targeting the given status (no-op if bookmark already exists).  func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { -	targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) +	targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))  	} @@ -50,7 +50,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo  	// first check if the status is already bookmarked, if so we don't need to do anything  	newBookmark := true  	gtsBookmark := >smodel.StatusBookmark{} -	if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil { +	if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil {  		// we already have a bookmark for this status  		newBookmark = false  	} @@ -67,7 +67,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo  			Status:          targetStatus,  		} -		if err := p.db.Put(ctx, gtsBookmark); err != nil { +		if err := p.state.DB.Put(ctx, gtsBookmark); err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("error putting bookmark in database: %s", err))  		}  	} @@ -83,7 +83,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo  // BookmarkRemove removes a bookmark for the requesting account, targeting the given status (no-op if bookmark doesn't exist).  func (p *Processor) BookmarkRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { -	targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) +	targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))  	} @@ -101,13 +101,13 @@ func (p *Processor) BookmarkRemove(ctx context.Context, requestingAccount *gtsmo  	// first check if the status is actually bookmarked  	toUnbookmark := false  	gtsBookmark := >smodel.StatusBookmark{} -	if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil { +	if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil {  		// we have a bookmark for this status  		toUnbookmark = true  	}  	if toUnbookmark { -		if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err != nil { +		if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("error unfaveing status: %s", err))  		}  	} diff --git a/internal/processing/status/boost.go b/internal/processing/status/boost.go index 4dfe17019..6756d816c 100644 --- a/internal/processing/status/boost.go +++ b/internal/processing/status/boost.go @@ -33,7 +33,7 @@ import (  // BoostCreate processes the boost/reblog of a given status, returning the newly-created boost if all is well.  func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { -	targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) +	targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))  	} @@ -47,7 +47,7 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel  	// boost boosts, and it looks absolutely bizarre in the UI  	if targetStatus.BoostOfID != "" {  		if targetStatus.BoostOf == nil { -			b, err := p.db.GetStatusByID(ctx, targetStatus.BoostOfID) +			b, err := p.state.DB.GetStatusByID(ctx, targetStatus.BoostOfID)  			if err != nil {  				return nil, gtserror.NewErrorNotFound(fmt.Errorf("couldn't fetch boosted status %s", targetStatus.BoostOfID))  			} @@ -74,12 +74,12 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel  	boostWrapperStatus.BoostOfAccount = targetStatus.Account  	// put the boost in the database -	if err := p.db.PutStatus(ctx, boostWrapperStatus); err != nil { +	if err := p.state.DB.PutStatus(ctx, boostWrapperStatus); err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	}  	// send it back to the processor for async processing -	p.clientWorker.Queue(messages.FromClientAPI{ +	p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  		APObjectType:   ap.ActivityAnnounce,  		APActivityType: ap.ActivityCreate,  		GTSModel:       boostWrapperStatus, @@ -98,7 +98,7 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel  // BoostRemove processes the unboost/unreblog of a given status, returning the status if all is well.  func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { -	targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) +	targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))  	} @@ -128,7 +128,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel  			Value: requestingAccount.ID,  		},  	} -	err = p.db.GetWhere(ctx, where, gtsBoost) +	err = p.state.DB.GetWhere(ctx, where, gtsBoost)  	if err == nil {  		// we have a boost  		toUnboost = true @@ -151,7 +151,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel  		gtsBoost.BoostOf.Account = targetStatus.Account  		// send it back to the processor for async processing -		p.clientWorker.Queue(messages.FromClientAPI{ +		p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  			APObjectType:   ap.ActivityAnnounce,  			APActivityType: ap.ActivityUndo,  			GTSModel:       gtsBoost, @@ -170,7 +170,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel  // StatusBoostedBy returns a slice of accounts that have boosted the given status, filtered according to privacy settings.  func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) { -	targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) +	targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)  	if err != nil {  		wrapped := fmt.Errorf("BoostedBy: error fetching status %s: %s", targetStatusID, err)  		if !errors.Is(err, db.ErrNoEntries) { @@ -181,7 +181,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm  	if boostOfID := targetStatus.BoostOfID; boostOfID != "" {  		// the target status is a boost wrapper, redirect this request to the status it boosts -		boostedStatus, err := p.db.GetStatusByID(ctx, boostOfID) +		boostedStatus, err := p.state.DB.GetStatusByID(ctx, boostOfID)  		if err != nil {  			wrapped := fmt.Errorf("BoostedBy: error fetching status %s: %s", boostOfID, err)  			if !errors.Is(err, db.ErrNoEntries) { @@ -202,7 +202,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm  		return nil, gtserror.NewErrorNotFound(err)  	} -	statusReblogs, err := p.db.GetStatusReblogs(ctx, targetStatus) +	statusReblogs, err := p.state.DB.GetStatusReblogs(ctx, targetStatus)  	if err != nil {  		err = fmt.Errorf("BoostedBy: error seeing who boosted status: %s", err)  		return nil, gtserror.NewErrorNotFound(err) @@ -211,7 +211,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm  	// filter account IDs so the user doesn't see accounts they blocked or which blocked them  	accountIDs := make([]string, 0, len(statusReblogs))  	for _, s := range statusReblogs { -		blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, s.AccountID, true) +		blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, s.AccountID, true)  		if err != nil {  			err = fmt.Errorf("BoostedBy: error checking blocks: %s", err)  			return nil, gtserror.NewErrorNotFound(err) @@ -226,7 +226,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm  	// fetch accounts + create their API representations  	apiAccounts := make([]*apimodel.Account, 0, len(accountIDs))  	for _, accountID := range accountIDs { -		account, err := p.db.GetAccountByID(ctx, accountID) +		account, err := p.state.DB.GetAccountByID(ctx, accountID)  		if err != nil {  			wrapped := fmt.Errorf("BoostedBy: error fetching account %s: %s", accountID, err)  			if !errors.Is(err, db.ErrNoEntries) { diff --git a/internal/processing/status/create.go b/internal/processing/status/create.go index f47c850dd..4e5399469 100644 --- a/internal/processing/status/create.go +++ b/internal/processing/status/create.go @@ -61,11 +61,11 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, appli  		Text:                     form.Status,  	} -	if errWithCode := processReplyToID(ctx, p.db, form, account.ID, newStatus); errWithCode != nil { +	if errWithCode := processReplyToID(ctx, p.state.DB, form, account.ID, newStatus); errWithCode != nil {  		return nil, errWithCode  	} -	if errWithCode := processMediaIDs(ctx, p.db, form, account.ID, newStatus); errWithCode != nil { +	if errWithCode := processMediaIDs(ctx, p.state.DB, form, account.ID, newStatus); errWithCode != nil {  		return nil, errWithCode  	} @@ -77,17 +77,17 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, appli  		return nil, gtserror.NewErrorInternalError(err)  	} -	if err := processContent(ctx, p.db, p.formatter, p.parseMention, form, account.ID, newStatus); err != nil { +	if err := processContent(ctx, p.state.DB, p.formatter, p.parseMention, form, account.ID, newStatus); err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	}  	// put the new status in the database -	if err := p.db.PutStatus(ctx, newStatus); err != nil { +	if err := p.state.DB.PutStatus(ctx, newStatus); err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	}  	// send it back to the processor for async processing -	p.clientWorker.Queue(messages.FromClientAPI{ +	p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  		APObjectType:   ap.ObjectNote,  		APActivityType: ap.ActivityCreate,  		GTSModel:       newStatus, diff --git a/internal/processing/status/delete.go b/internal/processing/status/delete.go index d3a03aad6..0e9510e08 100644 --- a/internal/processing/status/delete.go +++ b/internal/processing/status/delete.go @@ -32,7 +32,7 @@ import (  // Delete processes the delete of a given status, returning the deleted status if the delete goes through.  func (p *Processor) Delete(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { -	targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) +	targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))  	} @@ -50,7 +50,7 @@ func (p *Processor) Delete(ctx context.Context, requestingAccount *gtsmodel.Acco  	}  	// send the status back to the processor for async processing -	p.clientWorker.Queue(messages.FromClientAPI{ +	p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  		APObjectType:   ap.ObjectNote,  		APActivityType: ap.ActivityDelete,  		GTSModel:       targetStatus, diff --git a/internal/processing/status/fave.go b/internal/processing/status/fave.go index 3bcb1835f..3025c720d 100644 --- a/internal/processing/status/fave.go +++ b/internal/processing/status/fave.go @@ -35,7 +35,7 @@ import (  // FaveCreate processes the faving of a given status, returning the updated status if the fave goes through.  func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { -	targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) +	targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))  	} @@ -57,7 +57,7 @@ func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel.  	// first check if the status is already faved, if so we don't need to do anything  	newFave := true  	gtsFave := >smodel.StatusFave{} -	if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err == nil { +	if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err == nil {  		// we already have a fave for this status  		newFave = false  	} @@ -77,12 +77,12 @@ func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel.  			URI:             uris.GenerateURIForLike(requestingAccount.Username, thisFaveID),  		} -		if err := p.db.Put(ctx, gtsFave); err != nil { +		if err := p.state.DB.Put(ctx, gtsFave); err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("error putting fave in database: %s", err))  		}  		// send it back to the processor for async processing -		p.clientWorker.Queue(messages.FromClientAPI{ +		p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  			APObjectType:   ap.ActivityLike,  			APActivityType: ap.ActivityCreate,  			GTSModel:       gtsFave, @@ -102,7 +102,7 @@ func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel.  // FaveRemove processes the unfaving of a given status, returning the updated status if the fave goes through.  func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { -	targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) +	targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))  	} @@ -122,7 +122,7 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel.  	var toUnfave bool  	gtsFave := >smodel.StatusFave{} -	err = p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave) +	err = p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave)  	if err == nil {  		// we have a fave  		toUnfave = true @@ -138,12 +138,12 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel.  	if toUnfave {  		// we had a fave, so take some action to get rid of it -		if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err != nil { +		if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("error unfaveing status: %s", err))  		}  		// send it back to the processor for async processing -		p.clientWorker.Queue(messages.FromClientAPI{ +		p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{  			APObjectType:   ap.ActivityLike,  			APActivityType: ap.ActivityUndo,  			GTSModel:       gtsFave, @@ -162,7 +162,7 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel.  // FavedBy returns a slice of accounts that have liked the given status, filtered according to privacy settings.  func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) { -	targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) +	targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))  	} @@ -178,7 +178,7 @@ func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Acc  		return nil, gtserror.NewErrorNotFound(errors.New("status is not visible"))  	} -	statusFaves, err := p.db.GetStatusFaves(ctx, targetStatus) +	statusFaves, err := p.state.DB.GetStatusFaves(ctx, targetStatus)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing who faved status: %s", err))  	} @@ -186,7 +186,7 @@ func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Acc  	// filter the list so the user doesn't see accounts they blocked or which blocked them  	filteredAccounts := []*gtsmodel.Account{}  	for _, fave := range statusFaves { -		blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, fave.AccountID, true) +		blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, fave.AccountID, true)  		if err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking blocks: %s", err))  		} diff --git a/internal/processing/status/get.go b/internal/processing/status/get.go index edefeb440..51c384c44 100644 --- a/internal/processing/status/get.go +++ b/internal/processing/status/get.go @@ -31,7 +31,7 @@ import (  // Get gets the given status, taking account of privacy settings and blocks etc.  func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { -	targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) +	targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))  	} @@ -57,7 +57,7 @@ func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account  // ContextGet returns the context (previous and following posts) from the given status ID.  func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Context, gtserror.WithCode) { -	targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) +	targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)  	if err != nil {  		return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))  	} @@ -78,7 +78,7 @@ func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel.  		Descendants: []apimodel.Status{},  	} -	parents, err := p.db.GetStatusParents(ctx, targetStatus, false) +	parents, err := p.state.DB.GetStatusParents(ctx, targetStatus, false)  	if err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} @@ -96,7 +96,7 @@ func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel.  		return context.Ancestors[i].ID < context.Ancestors[j].ID  	}) -	children, err := p.db.GetStatusChildren(ctx, targetStatus, false, "") +	children, err := p.state.DB.GetStatusChildren(ctx, targetStatus, false, "")  	if err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} diff --git a/internal/processing/status/pin.go b/internal/processing/status/pin.go index 3e50b0c73..6001a147f 100644 --- a/internal/processing/status/pin.go +++ b/internal/processing/status/pin.go @@ -39,7 +39,7 @@ const allowedPinnedCount = 10  //   - Status is public, unlisted, or followers-only.  //   - Status is not a boost.  func (p *Processor) getPinnableStatus(ctx context.Context, targetStatusID string, requestingAccountID string) (*gtsmodel.Status, gtserror.WithCode) { -	targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) +	targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)  	if err != nil {  		err = fmt.Errorf("error fetching status %s: %w", targetStatusID, err)  		return nil, gtserror.NewErrorNotFound(err) @@ -84,7 +84,7 @@ func (p *Processor) PinCreate(ctx context.Context, requestingAccount *gtsmodel.A  		return nil, gtserror.NewErrorUnprocessableEntity(err, err.Error())  	} -	pinnedCount, err := p.db.CountAccountPinned(ctx, requestingAccount.ID) +	pinnedCount, err := p.state.DB.CountAccountPinned(ctx, requestingAccount.ID)  	if err != nil {  		return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking number of pinned statuses: %w", err))  	} @@ -95,7 +95,7 @@ func (p *Processor) PinCreate(ctx context.Context, requestingAccount *gtsmodel.A  	}  	targetStatus.PinnedAt = time.Now() -	if err := p.db.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil { +	if err := p.state.DB.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil {  		return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error pinning status: %w", err))  	} @@ -126,7 +126,7 @@ func (p *Processor) PinRemove(ctx context.Context, requestingAccount *gtsmodel.A  	if targetStatus.PinnedAt.IsZero() {  		targetStatus.PinnedAt = time.Time{} -		if err := p.db.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil { +		if err := p.state.DB.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil {  			return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error unpinning status: %w", err))  		}  	} diff --git a/internal/processing/status/status.go b/internal/processing/status/status.go index c91fd85d1..909b06481 100644 --- a/internal/processing/status/status.go +++ b/internal/processing/status/status.go @@ -19,32 +19,28 @@  package status  import ( -	"github.com/superseriousbusiness/gotosocial/internal/concurrency" -	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -	"github.com/superseriousbusiness/gotosocial/internal/messages" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/text"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/internal/visibility"  )  type Processor struct { +	state        *state.State  	tc           typeutils.TypeConverter -	db           db.DB  	filter       visibility.Filter  	formatter    text.Formatter -	clientWorker *concurrency.WorkerPool[messages.FromClientAPI]  	parseMention gtsmodel.ParseMentionFunc  }  // New returns a new status processor. -func New(db db.DB, tc typeutils.TypeConverter, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor { +func New(state *state.State, tc typeutils.TypeConverter, parseMention gtsmodel.ParseMentionFunc) Processor {  	return Processor{ +		state:        state,  		tc:           tc, -		db:           db, -		filter:       visibility.NewFilter(db), -		formatter:    text.NewFormatter(db), -		clientWorker: clientWorker, +		filter:       visibility.NewFilter(state.DB), +		formatter:    text.NewFormatter(state.DB),  		parseMention: parseMention,  	}  } diff --git a/internal/processing/status/status_test.go b/internal/processing/status/status_test.go index 272d2c8ea..1b35b69db 100644 --- a/internal/processing/status/status_test.go +++ b/internal/processing/status/status_test.go @@ -19,17 +19,14 @@  package status_test  import ( -	"context" -  	"github.com/stretchr/testify/suite" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/processing"  	"github.com/superseriousbusiness/gotosocial/internal/processing/status" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage"  	"github.com/superseriousbusiness/gotosocial/internal/transport"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils" @@ -42,9 +39,9 @@ type StatusStandardTestSuite struct {  	typeConverter typeutils.TypeConverter  	tc            transport.Controller  	storage       *storage.Driver +	state         state.State  	mediaManager  media.Manager  	federator     federation.Federator -	clientWorker  *concurrency.WorkerPool[messages.FromClientAPI]  	// standard suite models  	testTokens       map[string]*gtsmodel.Token @@ -74,21 +71,22 @@ func (suite *StatusStandardTestSuite) SetupSuite() {  }  func (suite *StatusStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state)  	suite.typeConverter = testrig.NewTestTypeConverter(suite.db) -	suite.clientWorker = concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) -	suite.tc = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker) +	suite.state.DB = suite.db + +	suite.tc = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media"))  	suite.storage = testrig.NewInMemoryStorage() -	suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) -	suite.federator = testrig.NewTestFederator(suite.db, suite.tc, suite.storage, suite.mediaManager, fedWorker) -	suite.status = status.New(suite.db, suite.typeConverter, suite.clientWorker, processing.GetParseMentionFunc(suite.db, suite.federator)) -	suite.clientWorker.SetProcessor(func(ctx context.Context, msg messages.FromClientAPI) error { return nil }) -	suite.NoError(suite.clientWorker.Start()) +	suite.state.Storage = suite.storage +	suite.mediaManager = testrig.NewTestMediaManager(&suite.state) +	suite.federator = testrig.NewTestFederator(&suite.state, suite.tc, suite.mediaManager) +	suite.status = status.New(&suite.state, suite.typeConverter, processing.GetParseMentionFunc(suite.db, suite.federator))  	testrig.StandardDBSetup(suite.db, suite.testAccounts)  	testrig.StandardStorageSetup(suite.storage, "../../../testrig/media") @@ -97,4 +95,5 @@ func (suite *StatusStandardTestSuite) SetupTest() {  func (suite *StatusStandardTestSuite) TearDownTest() {  	testrig.StandardDBTeardown(suite.db)  	testrig.StandardStorageTeardown(suite.storage) +	testrig.StopWorkers(&suite.state)  } diff --git a/internal/processing/statustimeline.go b/internal/processing/statustimeline.go index 7c9f36f16..8c8e20316 100644 --- a/internal/processing/statustimeline.go +++ b/internal/processing/statustimeline.go @@ -173,7 +173,7 @@ func (p *Processor) HomeTimelineGet(ctx context.Context, authed *oauth.Auth, max  }  func (p *Processor) PublicTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.PageableResponse, gtserror.WithCode) { -	statuses, err := p.db.GetPublicTimeline(ctx, maxID, sinceID, minID, limit, local) +	statuses, err := p.state.DB.GetPublicTimeline(ctx, maxID, sinceID, minID, limit, local)  	if err != nil {  		if err == db.ErrNoEntries {  			// there are just no entries left @@ -218,7 +218,7 @@ func (p *Processor) PublicTimelineGet(ctx context.Context, authed *oauth.Auth, m  }  func (p *Processor) FavedTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.PageableResponse, gtserror.WithCode) { -	statuses, nextMaxID, prevMinID, err := p.db.GetFavedTimeline(ctx, authed.Account.ID, maxID, minID, limit) +	statuses, nextMaxID, prevMinID, err := p.state.DB.GetFavedTimeline(ctx, authed.Account.ID, maxID, minID, limit)  	if err != nil {  		if err == db.ErrNoEntries {  			// there are just no entries left @@ -255,7 +255,7 @@ func (p *Processor) filterPublicStatuses(ctx context.Context, authed *oauth.Auth  	apiStatuses := []*apimodel.Status{}  	for _, s := range statuses {  		targetAccount := >smodel.Account{} -		if err := p.db.GetByID(ctx, s.AccountID, targetAccount); err != nil { +		if err := p.state.DB.GetByID(ctx, s.AccountID, targetAccount); err != nil {  			if err == db.ErrNoEntries {  				log.Debugf(ctx, "skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)  				continue @@ -288,7 +288,7 @@ func (p *Processor) filterFavedStatuses(ctx context.Context, authed *oauth.Auth,  	apiStatuses := []*apimodel.Status{}  	for _, s := range statuses {  		targetAccount := >smodel.Account{} -		if err := p.db.GetByID(ctx, s.AccountID, targetAccount); err != nil { +		if err := p.state.DB.GetByID(ctx, s.AccountID, targetAccount); err != nil {  			if err == db.ErrNoEntries {  				log.Debugf(ctx, "skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)  				continue diff --git a/internal/processing/stream/authorize.go b/internal/processing/stream/authorize.go index 5f6811db9..a30e6fb33 100644 --- a/internal/processing/stream/authorize.go +++ b/internal/processing/stream/authorize.go @@ -41,7 +41,7 @@ func (p *Processor) Authorize(ctx context.Context, accessToken string) (*gtsmode  		return nil, gtserror.NewErrorUnauthorized(err)  	} -	user, err := p.db.GetUserByID(ctx, uid) +	user, err := p.state.DB.GetUserByID(ctx, uid)  	if err != nil {  		if err == db.ErrNoEntries {  			err := fmt.Errorf("no user found for validated uid %s", uid) @@ -50,7 +50,7 @@ func (p *Processor) Authorize(ctx context.Context, accessToken string) (*gtsmode  		return nil, gtserror.NewErrorInternalError(err)  	} -	acct, err := p.db.GetAccountByID(ctx, user.AccountID) +	acct, err := p.state.DB.GetAccountByID(ctx, user.AccountID)  	if err != nil {  		if err == db.ErrNoEntries {  			err := fmt.Errorf("no account found for validated uid %s", uid) diff --git a/internal/processing/stream/stream.go b/internal/processing/stream/stream.go index 3c38e720a..a10ab2474 100644 --- a/internal/processing/stream/stream.go +++ b/internal/processing/stream/stream.go @@ -22,22 +22,21 @@ import (  	"errors"  	"sync" -	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/oauth" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/stream"  )  type Processor struct { -	db          db.DB +	state       *state.State  	oauthServer oauth.Server -	streamMap   *sync.Map +	streamMap   sync.Map  } -func New(db db.DB, oauthServer oauth.Server) Processor { +func New(state *state.State, oauthServer oauth.Server) Processor {  	return Processor{ -		db:          db, +		state:       state,  		oauthServer: oauthServer, -		streamMap:   &sync.Map{},  	}  } diff --git a/internal/processing/stream/stream_test.go b/internal/processing/stream/stream_test.go index 907c7e1d0..9e1eb57f2 100644 --- a/internal/processing/stream/stream_test.go +++ b/internal/processing/stream/stream_test.go @@ -24,6 +24,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/oauth"  	"github.com/superseriousbusiness/gotosocial/internal/processing/stream" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -33,19 +34,23 @@ type StreamTestSuite struct {  	testTokens   map[string]*gtsmodel.Token  	db           db.DB  	oauthServer  oauth.Server +	state        state.State  	streamProcessor stream.Processor  }  func (suite *StreamTestSuite) SetupTest() { +	suite.state.Caches.Init() +  	testrig.InitTestLog()  	testrig.InitTestConfig()  	suite.testAccounts = testrig.NewTestAccounts()  	suite.testTokens = testrig.NewTestTokens() -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db  	suite.oauthServer = testrig.NewTestOauthServer(suite.db) -	suite.streamProcessor = stream.New(suite.db, suite.oauthServer) +	suite.streamProcessor = stream.New(&suite.state, suite.oauthServer)  	testrig.StandardDBSetup(suite.db, suite.testAccounts)  } diff --git a/internal/processing/user/email.go b/internal/processing/user/email.go index 349e27f47..c55488954 100644 --- a/internal/processing/user/email.go +++ b/internal/processing/user/email.go @@ -56,7 +56,7 @@ func (p *Processor) EmailSendConfirmation(ctx context.Context, user *gtsmodel.Us  	// pull our instance entry from the database so we can greet the user nicely in the email  	instance := >smodel.Instance{}  	host := config.GetHost() -	if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, instance); err != nil { +	if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, instance); err != nil {  		return fmt.Errorf("SendConfirmEmail: error getting instance: %s", err)  	} @@ -78,7 +78,7 @@ func (p *Processor) EmailSendConfirmation(ctx context.Context, user *gtsmodel.Us  	user.LastEmailedAt = time.Now()  	user.UpdatedAt = time.Now() -	if err := p.db.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil { +	if err := p.state.DB.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil {  		return fmt.Errorf("SendConfirmEmail: error updating user entry after email sent: %s", err)  	} @@ -92,7 +92,7 @@ func (p *Processor) EmailConfirm(ctx context.Context, token string) (*gtsmodel.U  		return nil, gtserror.NewErrorNotFound(errors.New("no token provided"))  	} -	user, err := p.db.GetUserByConfirmationToken(ctx, token) +	user, err := p.state.DB.GetUserByConfirmationToken(ctx, token)  	if err != nil {  		if err == db.ErrNoEntries {  			return nil, gtserror.NewErrorNotFound(err) @@ -101,7 +101,7 @@ func (p *Processor) EmailConfirm(ctx context.Context, token string) (*gtsmodel.U  	}  	if user.Account == nil { -		a, err := p.db.GetAccountByID(ctx, user.AccountID) +		a, err := p.state.DB.GetAccountByID(ctx, user.AccountID)  		if err != nil {  			return nil, gtserror.NewErrorNotFound(err)  		} @@ -129,7 +129,7 @@ func (p *Processor) EmailConfirm(ctx context.Context, token string) (*gtsmodel.U  	user.ConfirmationToken = ""  	user.UpdatedAt = time.Now() -	if err := p.db.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil { +	if err := p.state.DB.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} diff --git a/internal/processing/user/password.go b/internal/processing/user/password.go index 3475e005e..72ef5ffa7 100644 --- a/internal/processing/user/password.go +++ b/internal/processing/user/password.go @@ -44,7 +44,7 @@ func (p *Processor) PasswordChange(ctx context.Context, user *gtsmodel.User, old  	user.EncryptedPassword = string(newPasswordHash) -	if err := p.db.UpdateUser(ctx, user, "encrypted_password"); err != nil { +	if err := p.state.DB.UpdateUser(ctx, user, "encrypted_password"); err != nil {  		return gtserror.NewErrorInternalError(err)  	} diff --git a/internal/processing/user/user.go b/internal/processing/user/user.go index fce628d0c..4fda4c1f6 100644 --- a/internal/processing/user/user.go +++ b/internal/processing/user/user.go @@ -19,19 +19,19 @@  package user  import ( -	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email" +	"github.com/superseriousbusiness/gotosocial/internal/state"  )  type Processor struct { +	state       *state.State  	emailSender email.Sender -	db          db.DB  }  // New returns a new user processor -func New(db db.DB, emailSender email.Sender) Processor { +func New(state *state.State, emailSender email.Sender) Processor {  	return Processor{ +		state:       state,  		emailSender: emailSender, -		db:          db,  	}  } diff --git a/internal/processing/user/user_test.go b/internal/processing/user/user_test.go index 83ab5892e..7379b568e 100644 --- a/internal/processing/user/user_test.go +++ b/internal/processing/user/user_test.go @@ -24,6 +24,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/processing/user" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -31,6 +32,7 @@ type UserStandardTestSuite struct {  	suite.Suite  	emailSender email.Sender  	db          db.DB +	state       state.State  	testUsers map[string]*gtsmodel.User @@ -40,15 +42,19 @@ type UserStandardTestSuite struct {  }  func (suite *UserStandardTestSuite) SetupTest() { +	suite.state.Caches.Init() +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&suite.state) +	suite.state.DB = suite.db +  	suite.sentEmails = make(map[string]string)  	suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails)  	suite.testUsers = testrig.NewTestUsers() -	suite.user = user.New(suite.db, suite.emailSender) +	suite.user = user.New(&suite.state, suite.emailSender)  	testrig.StandardDBSetup(suite.db, nil)  } diff --git a/internal/text/formatter_test.go b/internal/text/formatter_test.go index 32ae74488..304a538fc 100644 --- a/internal/text/formatter_test.go +++ b/internal/text/formatter_test.go @@ -20,12 +20,12 @@ package text_test  import (  	"context" +  	"github.com/stretchr/testify/suite" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/text"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -66,13 +66,15 @@ func (suite *TextStandardTestSuite) SetupSuite() {  }  func (suite *TextStandardTestSuite) SetupTest() { +	var state state.State +	state.Caches.Init() +  	testrig.InitTestLog()  	testrig.InitTestConfig() -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&state) -	fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) -	federator := testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../testrig/media"), suite.db, fedWorker), nil, nil, fedWorker) +	federator := testrig.NewTestFederator(&state, testrig.NewTestTransportController(&state, testrig.NewMockHTTPClient(nil, "../../testrig/media")), nil)  	suite.parseMention = processing.GetParseMentionFunc(suite.db, federator)  	suite.formatter = text.NewFormatter(suite.db) diff --git a/internal/timeline/get_test.go b/internal/timeline/get_test.go index 9be1fdb90..0c866c7a8 100644 --- a/internal/timeline/get_test.go +++ b/internal/timeline/get_test.go @@ -27,6 +27,7 @@ import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/timeline"  	"github.com/superseriousbusiness/gotosocial/internal/visibility"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -42,10 +43,13 @@ func (suite *GetTestSuite) SetupSuite() {  }  func (suite *GetTestSuite) SetupTest() { +	var state state.State +	state.Caches.Init() +  	testrig.InitTestLog()  	testrig.InitTestConfig() -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&state)  	suite.tc = testrig.NewTestTypeConverter(suite.db)  	suite.filter = visibility.NewFilter(suite.db) diff --git a/internal/timeline/index_test.go b/internal/timeline/index_test.go index 692688aba..9d79f12c2 100644 --- a/internal/timeline/index_test.go +++ b/internal/timeline/index_test.go @@ -26,6 +26,7 @@ import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/timeline"  	"github.com/superseriousbusiness/gotosocial/internal/visibility"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -41,10 +42,13 @@ func (suite *IndexTestSuite) SetupSuite() {  }  func (suite *IndexTestSuite) SetupTest() { +	var state state.State +	state.Caches.Init() +  	testrig.InitTestLog()  	testrig.InitTestConfig() -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&state)  	suite.tc = testrig.NewTestTypeConverter(suite.db)  	suite.filter = visibility.NewFilter(suite.db) diff --git a/internal/timeline/manager_test.go b/internal/timeline/manager_test.go index 03804bf78..e033ffda4 100644 --- a/internal/timeline/manager_test.go +++ b/internal/timeline/manager_test.go @@ -24,6 +24,7 @@ import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/timeline"  	"github.com/superseriousbusiness/gotosocial/internal/visibility"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -39,10 +40,13 @@ func (suite *ManagerTestSuite) SetupSuite() {  }  func (suite *ManagerTestSuite) SetupTest() { +	var state state.State +	state.Caches.Init() +  	testrig.InitTestLog()  	testrig.InitTestConfig() -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&state)  	suite.tc = testrig.NewTestTypeConverter(suite.db)  	suite.filter = visibility.NewFilter(suite.db) diff --git a/internal/timeline/prune_test.go b/internal/timeline/prune_test.go index 9d539e0e0..48bba41dc 100644 --- a/internal/timeline/prune_test.go +++ b/internal/timeline/prune_test.go @@ -26,6 +26,7 @@ import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/processing" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/timeline"  	"github.com/superseriousbusiness/gotosocial/internal/visibility"  	"github.com/superseriousbusiness/gotosocial/testrig" @@ -41,10 +42,13 @@ func (suite *PruneTestSuite) SetupSuite() {  }  func (suite *PruneTestSuite) SetupTest() { +	var state state.State +	state.Caches.Init() +  	testrig.InitTestLog()  	testrig.InitTestConfig() -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&state)  	suite.tc = testrig.NewTestTypeConverter(suite.db)  	suite.filter = visibility.NewFilter(suite.db) diff --git a/internal/trans/import_test.go b/internal/trans/import_test.go index 128ac58a3..a53305c79 100644 --- a/internal/trans/import_test.go +++ b/internal/trans/import_test.go @@ -27,6 +27,7 @@ import (  	"github.com/google/uuid"  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/trans"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -57,8 +58,11 @@ func (suite *ImportMinimalTestSuite) TestImportMinimalOK() {  	suite.NotEmpty(b)  	fmt.Println(string(b)) +	var state state.State +	state.Caches.Init() +  	// create a new database with just the tables created, no entries -	newDB := testrig.NewTestDB() +	newDB := testrig.NewTestDB(&state)  	importer := trans.NewImporter(newDB)  	err = importer.Import(ctx, tempFilePath) diff --git a/internal/trans/trans_test.go b/internal/trans/trans_test.go index 9364891a0..2b6bbb57b 100644 --- a/internal/trans/trans_test.go +++ b/internal/trans/trans_test.go @@ -22,6 +22,7 @@ import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -32,12 +33,15 @@ type TransTestSuite struct {  }  func (suite *TransTestSuite) SetupTest() { +	var state state.State +	state.Caches.Init() +  	testrig.InitTestConfig()  	testrig.InitTestLog()  	suite.testAccounts = testrig.NewTestAccounts() -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&state)  	testrig.StandardDBSetup(suite.db, nil)  } diff --git a/internal/typeutils/converter_test.go b/internal/typeutils/converter_test.go index c6f3c2579..bc81a7c6d 100644 --- a/internal/typeutils/converter_test.go +++ b/internal/typeutils/converter_test.go @@ -23,6 +23,7 @@ import (  	"github.com/superseriousbusiness/activity/streams/vocab"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -481,10 +482,13 @@ type TypeUtilsTestSuite struct {  }  func (suite *TypeUtilsTestSuite) SetupSuite() { +	var state state.State +	state.Caches.Init() +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&state)  	suite.testAccounts = testrig.NewTestAccounts()  	suite.testStatuses = testrig.NewTestStatuses()  	suite.testAttachments = testrig.NewTestAttachments() diff --git a/internal/visibility/filter_test.go b/internal/visibility/filter_test.go index bd7a8671e..9697dd72c 100644 --- a/internal/visibility/filter_test.go +++ b/internal/visibility/filter_test.go @@ -22,6 +22,7 @@ import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/visibility"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -60,10 +61,13 @@ func (suite *FilterStandardTestSuite) SetupSuite() {  }  func (suite *FilterStandardTestSuite) SetupTest() { +	var state state.State +	state.Caches.Init() +  	testrig.InitTestConfig()  	testrig.InitTestLog() -	suite.db = testrig.NewTestDB() +	suite.db = testrig.NewTestDB(&state)  	suite.filter = visibility.NewFilter(suite.db)  	testrig.StandardDBSetup(suite.db, nil) diff --git a/internal/workers/workers.go b/internal/workers/workers.go index 77b3065ce..b29d115aa 100644 --- a/internal/workers/workers.go +++ b/internal/workers/workers.go @@ -19,20 +19,28 @@ along with this program.  If not, see <http://www.gnu.org/licenses/>.  package workers  import ( +	"context"  	"log"  	"runtime"  	"codeberg.org/gruf/go-runners"  	"codeberg.org/gruf/go-sched" +	"github.com/superseriousbusiness/gotosocial/internal/messages"  )  type Workers struct {  	// Main task scheduler instance.  	Scheduler sched.Scheduler -	// Processor / federator worker pools. -	// ClientAPI runners.WorkerPool -	// Federator runners.WorkerPool +	// ClientAPI / federator worker pools. +	ClientAPI runners.WorkerPool +	Federator runners.WorkerPool + +	// Enqueue functions for clientAPI / federator worker pools, +	// these are pointers to Processor{}.Enqueue___() msg functions. +	// This prevents dependency cycling as Processor depends on Workers. +	EnqueueClientAPI func(context.Context, messages.FromClientAPI) +	EnqueueFederator func(context.Context, messages.FromFederator)  	// Media manager worker pools.  	Media runners.WorkerPool @@ -50,13 +58,13 @@ func (w *Workers) Start() {  		return w.Scheduler.Start(nil)  	}) -	// tryUntil("starting client API workerpool", 5, func() bool { -	// 	return w.ClientAPI.Start(4*maxprocs, 400*maxprocs) -	// }) +	tryUntil("starting client API workerpool", 5, func() bool { +		return w.ClientAPI.Start(4*maxprocs, 400*maxprocs) +	}) -	// tryUntil("starting federator workerpool", 5, func() bool { -	// 	return w.Federator.Start(4*maxprocs, 400*maxprocs) -	// }) +	tryUntil("starting federator workerpool", 5, func() bool { +		return w.Federator.Start(4*maxprocs, 400*maxprocs) +	})  	tryUntil("starting media workerpool", 5, func() bool {  		return w.Media.Start(8*maxprocs, 80*maxprocs) @@ -66,8 +74,8 @@ func (w *Workers) Start() {  // Stop will stop all of the contained worker pools (and global scheduler).  func (w *Workers) Stop() {  	tryUntil("stopping scheduler", 5, w.Scheduler.Stop) -	// tryUntil("stopping client API workerpool", 5, w.ClientAPI.Stop) -	// tryUntil("stopping federator workerpool", 5, w.Federator.Stop) +	tryUntil("stopping client API workerpool", 5, w.ClientAPI.Stop) +	tryUntil("stopping federator workerpool", 5, w.Federator.Stop)  	tryUntil("stopping media workerpool", 5, w.Media.Stop)  } diff --git a/testrig/db.go b/testrig/db.go index 8479347eb..1a29aa8b9 100644 --- a/testrig/db.go +++ b/testrig/db.go @@ -71,7 +71,7 @@ var testModels = []interface{}{  //  // If the environment variable GTS_DB_PORT is set, it will take that  // value as the port instead. -func NewTestDB() db.DB { +func NewTestDB(state *state.State) db.DB {  	if alternateAddress := os.Getenv("GTS_DB_ADDRESS"); alternateAddress != "" {  		config.SetDbAddress(alternateAddress)  	} @@ -88,10 +88,9 @@ func NewTestDB() db.DB {  		config.SetDbPort(int(port))  	} -	var state state.State  	state.Caches.Init() -	testDB, err := bundb.NewBunDBService(context.Background(), &state) +	testDB, err := bundb.NewBunDBService(context.Background(), state)  	if err != nil {  		log.Panic(nil, err)  	} diff --git a/testrig/federatingdb.go b/testrig/federatingdb.go index 9b1f1961e..27adc4c51 100644 --- a/testrig/federatingdb.go +++ b/testrig/federatingdb.go @@ -19,13 +19,11 @@  package testrig  import ( -	"github.com/superseriousbusiness/gotosocial/internal/concurrency" -	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb" -	"github.com/superseriousbusiness/gotosocial/internal/messages" +	"github.com/superseriousbusiness/gotosocial/internal/state"  )  // NewTestFederatingDB returns a federating DB with the underlying db -func NewTestFederatingDB(db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) federatingdb.DB { -	return federatingdb.New(db, fedWorker, NewTestTypeConverter(db)) +func NewTestFederatingDB(state *state.State) federatingdb.DB { +	return federatingdb.New(state, NewTestTypeConverter(state.DB))  } diff --git a/testrig/federator.go b/testrig/federator.go index 605a2c8f3..bc150633e 100644 --- a/testrig/federator.go +++ b/testrig/federator.go @@ -19,16 +19,13 @@  package testrig  import ( -	"github.com/superseriousbusiness/gotosocial/internal/concurrency" -	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages" -	"github.com/superseriousbusiness/gotosocial/internal/storage" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/transport"  )  // NewTestFederator returns a federator with the given database and (mock!!) transport controller. -func NewTestFederator(db db.DB, tc transport.Controller, storage *storage.Driver, mediaManager media.Manager, fedWorker *concurrency.WorkerPool[messages.FromFederator]) federation.Federator { -	return federation.NewFederator(db, NewTestFederatingDB(db, fedWorker), tc, NewTestTypeConverter(db), mediaManager) +func NewTestFederator(state *state.State, tc transport.Controller, mediaManager media.Manager) federation.Federator { +	return federation.NewFederator(state.DB, NewTestFederatingDB(state), tc, NewTestTypeConverter(state.DB), mediaManager)  } diff --git a/testrig/mediahandler.go b/testrig/mediahandler.go index a1863218c..b4b992b0b 100644 --- a/testrig/mediahandler.go +++ b/testrig/mediahandler.go @@ -19,17 +19,12 @@  package testrig  import ( -	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/media"  	"github.com/superseriousbusiness/gotosocial/internal/state" -	"github.com/superseriousbusiness/gotosocial/internal/storage"  )  // NewTestMediaManager returns a media handler with the default test config, and the given db and storage. -func NewTestMediaManager(db db.DB, storage *storage.Driver) media.Manager { -	var state state.State -	state.DB = db -	state.Storage = storage -	state.Workers.Start() -	return media.NewManager(&state) +func NewTestMediaManager(state *state.State) media.Manager { +	StartWorkers(state) // ensure started +	return media.NewManager(state)  } diff --git a/testrig/processor.go b/testrig/processor.go index f451d4ad0..856ee523d 100644 --- a/testrig/processor.go +++ b/testrig/processor.go @@ -19,17 +19,17 @@  package testrig  import ( -	"github.com/superseriousbusiness/gotosocial/internal/concurrency" -	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/email"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/media" -	"github.com/superseriousbusiness/gotosocial/internal/messages"  	"github.com/superseriousbusiness/gotosocial/internal/processing" -	"github.com/superseriousbusiness/gotosocial/internal/storage" +	"github.com/superseriousbusiness/gotosocial/internal/state"  )  // NewTestProcessor returns a Processor suitable for testing purposes -func NewTestProcessor(db db.DB, storage *storage.Driver, federator federation.Federator, emailSender email.Sender, mediaManager media.Manager, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], fedWorker *concurrency.WorkerPool[messages.FromFederator]) *processing.Processor { -	return processing.NewProcessor(NewTestTypeConverter(db), federator, NewTestOauthServer(db), mediaManager, storage, db, emailSender, clientWorker, fedWorker) +func NewTestProcessor(state *state.State, federator federation.Federator, emailSender email.Sender, mediaManager media.Manager) *processing.Processor { +	p := processing.NewProcessor(NewTestTypeConverter(state.DB), federator, NewTestOauthServer(state.DB), mediaManager, state, emailSender) +	state.Workers.EnqueueClientAPI = p.EnqueueClientAPI +	state.Workers.EnqueueFederator = p.EnqueueFederator +	return p  } diff --git a/testrig/transportcontroller.go b/testrig/transportcontroller.go index 7565a741c..9657205f6 100644 --- a/testrig/transportcontroller.go +++ b/testrig/transportcontroller.go @@ -30,12 +30,10 @@ import (  	"github.com/superseriousbusiness/activity/streams"  	"github.com/superseriousbusiness/activity/streams/vocab"  	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" -	"github.com/superseriousbusiness/gotosocial/internal/concurrency" -	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log" -	"github.com/superseriousbusiness/gotosocial/internal/messages" +	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/transport"  ) @@ -53,8 +51,8 @@ const (  // Unlike the other test interfaces provided in this package, you'll probably want to call this function  // PER TEST rather than per suite, so that the do function can be set on a test by test (or even more granular)  // basis. -func NewTestTransportController(client pub.HttpClient, db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) transport.Controller { -	return transport.NewController(db, NewTestFederatingDB(db, fedWorker), &federation.Clock{}, client) +func NewTestTransportController(state *state.State, client pub.HttpClient) transport.Controller { +	return transport.NewController(state.DB, NewTestFederatingDB(state), &federation.Clock{}, client)  }  type MockHTTPClient struct { diff --git a/testrig/util.go b/testrig/util.go index cc392b315..0cda93024 100644 --- a/testrig/util.go +++ b/testrig/util.go @@ -20,13 +20,34 @@ package testrig  import (  	"bytes" +	"context"  	"io"  	"mime/multipart"  	"net/url"  	"os"  	"time" + +	"github.com/superseriousbusiness/gotosocial/internal/messages" +	"github.com/superseriousbusiness/gotosocial/internal/state"  ) +func StartWorkers(state *state.State) { +	state.Workers.EnqueueClientAPI = func(context.Context, messages.FromClientAPI) {} +	state.Workers.EnqueueFederator = func(context.Context, messages.FromFederator) {} + +	_ = state.Workers.Scheduler.Start(nil) +	_ = state.Workers.ClientAPI.Start(1, 10) +	_ = state.Workers.Federator.Start(1, 10) +	_ = state.Workers.Media.Start(1, 10) +} + +func StopWorkers(state *state.State) { +	_ = state.Workers.Scheduler.Stop() +	_ = state.Workers.ClientAPI.Stop() +	_ = state.Workers.Federator.Stop() +	_ = state.Workers.Media.Stop() +} +  // CreateMultipartFormData is a handy function for taking a fieldname and a filename, and creating a multipart form bytes buffer  // with the file contents set in the given fieldname. The extraFields param can be used to add extra FormFields to the request, as necessary.  // The returned bytes.Buffer b can be used like so: | 
