diff options
Diffstat (limited to 'internal/processing/status')
-rw-r--r-- | internal/processing/status/bookmark.go | 12 | ||||
-rw-r--r-- | internal/processing/status/boost.go | 24 | ||||
-rw-r--r-- | internal/processing/status/create.go | 10 | ||||
-rw-r--r-- | internal/processing/status/delete.go | 4 | ||||
-rw-r--r-- | internal/processing/status/fave.go | 22 | ||||
-rw-r--r-- | internal/processing/status/get.go | 8 | ||||
-rw-r--r-- | internal/processing/status/pin.go | 8 | ||||
-rw-r--r-- | internal/processing/status/status.go | 16 | ||||
-rw-r--r-- | internal/processing/status/status_test.go | 29 |
9 files changed, 64 insertions, 69 deletions
diff --git a/internal/processing/status/bookmark.go b/internal/processing/status/bookmark.go index dde31ea7d..cf3787da2 100644 --- a/internal/processing/status/bookmark.go +++ b/internal/processing/status/bookmark.go @@ -32,7 +32,7 @@ import ( // BookmarkCreate adds a bookmark for the requestingAccount, targeting the given status (no-op if bookmark already exists). func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -50,7 +50,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo // first check if the status is already bookmarked, if so we don't need to do anything newBookmark := true gtsBookmark := >smodel.StatusBookmark{} - if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil { + if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil { // we already have a bookmark for this status newBookmark = false } @@ -67,7 +67,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo Status: targetStatus, } - if err := p.db.Put(ctx, gtsBookmark); err != nil { + if err := p.state.DB.Put(ctx, gtsBookmark); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error putting bookmark in database: %s", err)) } } @@ -83,7 +83,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo // BookmarkRemove removes a bookmark for the requesting account, targeting the given status (no-op if bookmark doesn't exist). func (p *Processor) BookmarkRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -101,13 +101,13 @@ func (p *Processor) BookmarkRemove(ctx context.Context, requestingAccount *gtsmo // first check if the status is actually bookmarked toUnbookmark := false gtsBookmark := >smodel.StatusBookmark{} - if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil { + if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil { // we have a bookmark for this status toUnbookmark = true } if toUnbookmark { - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err != nil { + if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error unfaveing status: %s", err)) } } diff --git a/internal/processing/status/boost.go b/internal/processing/status/boost.go index 4dfe17019..6756d816c 100644 --- a/internal/processing/status/boost.go +++ b/internal/processing/status/boost.go @@ -33,7 +33,7 @@ import ( // BoostCreate processes the boost/reblog of a given status, returning the newly-created boost if all is well. func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -47,7 +47,7 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel // boost boosts, and it looks absolutely bizarre in the UI if targetStatus.BoostOfID != "" { if targetStatus.BoostOf == nil { - b, err := p.db.GetStatusByID(ctx, targetStatus.BoostOfID) + b, err := p.state.DB.GetStatusByID(ctx, targetStatus.BoostOfID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("couldn't fetch boosted status %s", targetStatus.BoostOfID)) } @@ -74,12 +74,12 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel boostWrapperStatus.BoostOfAccount = targetStatus.Account // put the boost in the database - if err := p.db.PutStatus(ctx, boostWrapperStatus); err != nil { + if err := p.state.DB.PutStatus(ctx, boostWrapperStatus); err != nil { return nil, gtserror.NewErrorInternalError(err) } // send it back to the processor for async processing - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityAnnounce, APActivityType: ap.ActivityCreate, GTSModel: boostWrapperStatus, @@ -98,7 +98,7 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel // BoostRemove processes the unboost/unreblog of a given status, returning the status if all is well. func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -128,7 +128,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel Value: requestingAccount.ID, }, } - err = p.db.GetWhere(ctx, where, gtsBoost) + err = p.state.DB.GetWhere(ctx, where, gtsBoost) if err == nil { // we have a boost toUnboost = true @@ -151,7 +151,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel gtsBoost.BoostOf.Account = targetStatus.Account // send it back to the processor for async processing - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityAnnounce, APActivityType: ap.ActivityUndo, GTSModel: gtsBoost, @@ -170,7 +170,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel // StatusBoostedBy returns a slice of accounts that have boosted the given status, filtered according to privacy settings. func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { wrapped := fmt.Errorf("BoostedBy: error fetching status %s: %s", targetStatusID, err) if !errors.Is(err, db.ErrNoEntries) { @@ -181,7 +181,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm if boostOfID := targetStatus.BoostOfID; boostOfID != "" { // the target status is a boost wrapper, redirect this request to the status it boosts - boostedStatus, err := p.db.GetStatusByID(ctx, boostOfID) + boostedStatus, err := p.state.DB.GetStatusByID(ctx, boostOfID) if err != nil { wrapped := fmt.Errorf("BoostedBy: error fetching status %s: %s", boostOfID, err) if !errors.Is(err, db.ErrNoEntries) { @@ -202,7 +202,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm return nil, gtserror.NewErrorNotFound(err) } - statusReblogs, err := p.db.GetStatusReblogs(ctx, targetStatus) + statusReblogs, err := p.state.DB.GetStatusReblogs(ctx, targetStatus) if err != nil { err = fmt.Errorf("BoostedBy: error seeing who boosted status: %s", err) return nil, gtserror.NewErrorNotFound(err) @@ -211,7 +211,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm // filter account IDs so the user doesn't see accounts they blocked or which blocked them accountIDs := make([]string, 0, len(statusReblogs)) for _, s := range statusReblogs { - blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, s.AccountID, true) + blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, s.AccountID, true) if err != nil { err = fmt.Errorf("BoostedBy: error checking blocks: %s", err) return nil, gtserror.NewErrorNotFound(err) @@ -226,7 +226,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm // fetch accounts + create their API representations apiAccounts := make([]*apimodel.Account, 0, len(accountIDs)) for _, accountID := range accountIDs { - account, err := p.db.GetAccountByID(ctx, accountID) + account, err := p.state.DB.GetAccountByID(ctx, accountID) if err != nil { wrapped := fmt.Errorf("BoostedBy: error fetching account %s: %s", accountID, err) if !errors.Is(err, db.ErrNoEntries) { diff --git a/internal/processing/status/create.go b/internal/processing/status/create.go index f47c850dd..4e5399469 100644 --- a/internal/processing/status/create.go +++ b/internal/processing/status/create.go @@ -61,11 +61,11 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, appli Text: form.Status, } - if errWithCode := processReplyToID(ctx, p.db, form, account.ID, newStatus); errWithCode != nil { + if errWithCode := processReplyToID(ctx, p.state.DB, form, account.ID, newStatus); errWithCode != nil { return nil, errWithCode } - if errWithCode := processMediaIDs(ctx, p.db, form, account.ID, newStatus); errWithCode != nil { + if errWithCode := processMediaIDs(ctx, p.state.DB, form, account.ID, newStatus); errWithCode != nil { return nil, errWithCode } @@ -77,17 +77,17 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, appli return nil, gtserror.NewErrorInternalError(err) } - if err := processContent(ctx, p.db, p.formatter, p.parseMention, form, account.ID, newStatus); err != nil { + if err := processContent(ctx, p.state.DB, p.formatter, p.parseMention, form, account.ID, newStatus); err != nil { return nil, gtserror.NewErrorInternalError(err) } // put the new status in the database - if err := p.db.PutStatus(ctx, newStatus); err != nil { + if err := p.state.DB.PutStatus(ctx, newStatus); err != nil { return nil, gtserror.NewErrorInternalError(err) } // send it back to the processor for async processing - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ObjectNote, APActivityType: ap.ActivityCreate, GTSModel: newStatus, diff --git a/internal/processing/status/delete.go b/internal/processing/status/delete.go index d3a03aad6..0e9510e08 100644 --- a/internal/processing/status/delete.go +++ b/internal/processing/status/delete.go @@ -32,7 +32,7 @@ import ( // Delete processes the delete of a given status, returning the deleted status if the delete goes through. func (p *Processor) Delete(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -50,7 +50,7 @@ func (p *Processor) Delete(ctx context.Context, requestingAccount *gtsmodel.Acco } // send the status back to the processor for async processing - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ObjectNote, APActivityType: ap.ActivityDelete, GTSModel: targetStatus, diff --git a/internal/processing/status/fave.go b/internal/processing/status/fave.go index 3bcb1835f..3025c720d 100644 --- a/internal/processing/status/fave.go +++ b/internal/processing/status/fave.go @@ -35,7 +35,7 @@ import ( // FaveCreate processes the faving of a given status, returning the updated status if the fave goes through. func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -57,7 +57,7 @@ func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel. // first check if the status is already faved, if so we don't need to do anything newFave := true gtsFave := >smodel.StatusFave{} - if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err == nil { + if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err == nil { // we already have a fave for this status newFave = false } @@ -77,12 +77,12 @@ func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel. URI: uris.GenerateURIForLike(requestingAccount.Username, thisFaveID), } - if err := p.db.Put(ctx, gtsFave); err != nil { + if err := p.state.DB.Put(ctx, gtsFave); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error putting fave in database: %s", err)) } // send it back to the processor for async processing - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityLike, APActivityType: ap.ActivityCreate, GTSModel: gtsFave, @@ -102,7 +102,7 @@ func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel. // FaveRemove processes the unfaving of a given status, returning the updated status if the fave goes through. func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -122,7 +122,7 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel. var toUnfave bool gtsFave := >smodel.StatusFave{} - err = p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave) + err = p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave) if err == nil { // we have a fave toUnfave = true @@ -138,12 +138,12 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel. if toUnfave { // we had a fave, so take some action to get rid of it - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err != nil { + if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error unfaveing status: %s", err)) } // send it back to the processor for async processing - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityLike, APActivityType: ap.ActivityUndo, GTSModel: gtsFave, @@ -162,7 +162,7 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel. // FavedBy returns a slice of accounts that have liked the given status, filtered according to privacy settings. func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -178,7 +178,7 @@ func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Acc return nil, gtserror.NewErrorNotFound(errors.New("status is not visible")) } - statusFaves, err := p.db.GetStatusFaves(ctx, targetStatus) + statusFaves, err := p.state.DB.GetStatusFaves(ctx, targetStatus) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing who faved status: %s", err)) } @@ -186,7 +186,7 @@ func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Acc // filter the list so the user doesn't see accounts they blocked or which blocked them filteredAccounts := []*gtsmodel.Account{} for _, fave := range statusFaves { - blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, fave.AccountID, true) + blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, fave.AccountID, true) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking blocks: %s", err)) } diff --git a/internal/processing/status/get.go b/internal/processing/status/get.go index edefeb440..51c384c44 100644 --- a/internal/processing/status/get.go +++ b/internal/processing/status/get.go @@ -31,7 +31,7 @@ import ( // Get gets the given status, taking account of privacy settings and blocks etc. func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -57,7 +57,7 @@ func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account // ContextGet returns the context (previous and following posts) from the given status ID. func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Context, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -78,7 +78,7 @@ func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel. Descendants: []apimodel.Status{}, } - parents, err := p.db.GetStatusParents(ctx, targetStatus, false) + parents, err := p.state.DB.GetStatusParents(ctx, targetStatus, false) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -96,7 +96,7 @@ func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel. return context.Ancestors[i].ID < context.Ancestors[j].ID }) - children, err := p.db.GetStatusChildren(ctx, targetStatus, false, "") + children, err := p.state.DB.GetStatusChildren(ctx, targetStatus, false, "") if err != nil { return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/status/pin.go b/internal/processing/status/pin.go index 3e50b0c73..6001a147f 100644 --- a/internal/processing/status/pin.go +++ b/internal/processing/status/pin.go @@ -39,7 +39,7 @@ const allowedPinnedCount = 10 // - Status is public, unlisted, or followers-only. // - Status is not a boost. func (p *Processor) getPinnableStatus(ctx context.Context, targetStatusID string, requestingAccountID string) (*gtsmodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { err = fmt.Errorf("error fetching status %s: %w", targetStatusID, err) return nil, gtserror.NewErrorNotFound(err) @@ -84,7 +84,7 @@ func (p *Processor) PinCreate(ctx context.Context, requestingAccount *gtsmodel.A return nil, gtserror.NewErrorUnprocessableEntity(err, err.Error()) } - pinnedCount, err := p.db.CountAccountPinned(ctx, requestingAccount.ID) + pinnedCount, err := p.state.DB.CountAccountPinned(ctx, requestingAccount.ID) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking number of pinned statuses: %w", err)) } @@ -95,7 +95,7 @@ func (p *Processor) PinCreate(ctx context.Context, requestingAccount *gtsmodel.A } targetStatus.PinnedAt = time.Now() - if err := p.db.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil { + if err := p.state.DB.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error pinning status: %w", err)) } @@ -126,7 +126,7 @@ func (p *Processor) PinRemove(ctx context.Context, requestingAccount *gtsmodel.A if targetStatus.PinnedAt.IsZero() { targetStatus.PinnedAt = time.Time{} - if err := p.db.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil { + if err := p.state.DB.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error unpinning status: %w", err)) } } diff --git a/internal/processing/status/status.go b/internal/processing/status/status.go index c91fd85d1..909b06481 100644 --- a/internal/processing/status/status.go +++ b/internal/processing/status/status.go @@ -19,32 +19,28 @@ package status import ( - "github.com/superseriousbusiness/gotosocial/internal/concurrency" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/text" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/visibility" ) type Processor struct { + state *state.State tc typeutils.TypeConverter - db db.DB filter visibility.Filter formatter text.Formatter - clientWorker *concurrency.WorkerPool[messages.FromClientAPI] parseMention gtsmodel.ParseMentionFunc } // New returns a new status processor. -func New(db db.DB, tc typeutils.TypeConverter, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor { +func New(state *state.State, tc typeutils.TypeConverter, parseMention gtsmodel.ParseMentionFunc) Processor { return Processor{ + state: state, tc: tc, - db: db, - filter: visibility.NewFilter(db), - formatter: text.NewFormatter(db), - clientWorker: clientWorker, + filter: visibility.NewFilter(state.DB), + formatter: text.NewFormatter(state.DB), parseMention: parseMention, } } diff --git a/internal/processing/status/status_test.go b/internal/processing/status/status_test.go index 272d2c8ea..1b35b69db 100644 --- a/internal/processing/status/status_test.go +++ b/internal/processing/status/status_test.go @@ -19,17 +19,14 @@ package status_test import ( - "context" - "github.com/stretchr/testify/suite" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing/status" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/typeutils" @@ -42,9 +39,9 @@ type StatusStandardTestSuite struct { typeConverter typeutils.TypeConverter tc transport.Controller storage *storage.Driver + state state.State mediaManager media.Manager federator federation.Federator - clientWorker *concurrency.WorkerPool[messages.FromClientAPI] // standard suite models testTokens map[string]*gtsmodel.Token @@ -74,21 +71,22 @@ func (suite *StatusStandardTestSuite) SetupSuite() { } func (suite *StatusStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) suite.typeConverter = testrig.NewTestTypeConverter(suite.db) - suite.clientWorker = concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - suite.tc = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker) + suite.state.DB = suite.db + + suite.tc = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")) suite.storage = testrig.NewInMemoryStorage() - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, suite.tc, suite.storage, suite.mediaManager, fedWorker) - suite.status = status.New(suite.db, suite.typeConverter, suite.clientWorker, processing.GetParseMentionFunc(suite.db, suite.federator)) - suite.clientWorker.SetProcessor(func(ctx context.Context, msg messages.FromClientAPI) error { return nil }) - suite.NoError(suite.clientWorker.Start()) + suite.state.Storage = suite.storage + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, suite.tc, suite.mediaManager) + suite.status = status.New(&suite.state, suite.typeConverter, processing.GetParseMentionFunc(suite.db, suite.federator)) testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardStorageSetup(suite.storage, "../../../testrig/media") @@ -97,4 +95,5 @@ func (suite *StatusStandardTestSuite) SetupTest() { func (suite *StatusStandardTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } |