diff options
Diffstat (limited to 'internal/processing/common/status.go')
-rw-r--r-- | internal/processing/common/status.go | 82 |
1 files changed, 53 insertions, 29 deletions
diff --git a/internal/processing/common/status.go b/internal/processing/common/status.go index 0a1f495fb..ae03a5306 100644 --- a/internal/processing/common/status.go +++ b/internal/processing/common/status.go @@ -30,10 +30,12 @@ import ( // GetTargetStatusBy fetches the target status with db load function, given the authorized (or, nil) requester's // account. This returns an approprate gtserror.WithCode accounting for not found and visibility to requester. +// The refresh argument allows specifying whether the returned copy should be force refreshed. func (p *Processor) GetTargetStatusBy( ctx context.Context, requester *gtsmodel.Account, getTargetFromDB func() (*gtsmodel.Status, error), + refresh bool, ) ( status *gtsmodel.Status, visible bool, @@ -61,47 +63,52 @@ func (p *Processor) GetTargetStatusBy( } if requester != nil && visible { - // Ensure remote status is up-to-date. - p.federator.RefreshStatusAsync(ctx, - requester.Username, - target, - nil, - false, - ) + // We only bother refreshing if this status + // is visible to requester, AND there *is* + // a requester (i.e. request is authorized) + // to prevent a possible DOS vector. + + if refresh { + // Refresh required, forcibly do synchronously. + _, _, err := p.federator.RefreshStatus(ctx, + requester.Username, + target, + nil, + true, // force + ) + if err != nil { + log.Errorf(ctx, "error refreshing status: %v", err) + } + } else { + // Only refresh async *if* out-of-date. + p.federator.RefreshStatusAsync(ctx, + requester.Username, + target, + nil, + false, // force + ) + } } return target, visible, nil } -// GetTargetStatusByID is a call-through to GetTargetStatus() using the db GetStatusByID() function. -func (p *Processor) GetTargetStatusByID( - ctx context.Context, - requester *gtsmodel.Account, - targetID string, -) ( - status *gtsmodel.Status, - visible bool, - errWithCode gtserror.WithCode, -) { - return p.GetTargetStatusBy(ctx, requester, func() (*gtsmodel.Status, error) { - return p.state.DB.GetStatusByID(ctx, targetID) - }) -} - -// GetVisibleTargetStatus calls GetTargetStatusByID(), +// GetVisibleTargetStatus calls GetTargetStatusBy(), // but converts a non-visible result to not-found error. -func (p *Processor) GetVisibleTargetStatus( +func (p *Processor) GetVisibleTargetStatusBy( ctx context.Context, requester *gtsmodel.Account, - targetID string, + getTargetFromDB func() (*gtsmodel.Status, error), + refresh bool, ) ( status *gtsmodel.Status, errWithCode gtserror.WithCode, ) { // Fetch the target status by ID from the database. - target, visible, errWithCode := p.GetTargetStatusByID(ctx, + target, visible, errWithCode := p.GetTargetStatusBy(ctx, requester, - targetID, + getTargetFromDB, + refresh, ) if errWithCode != nil { return nil, errWithCode @@ -119,6 +126,22 @@ func (p *Processor) GetVisibleTargetStatus( return target, nil } +// GetVisibleTargetStatus calls GetVisibleTargetStatusBy(), +// passing in a database function that fetches by status ID. +func (p *Processor) GetVisibleTargetStatus( + ctx context.Context, + requester *gtsmodel.Account, + targetID string, + refresh bool, +) ( + status *gtsmodel.Status, + errWithCode gtserror.WithCode, +) { + return p.GetVisibleTargetStatusBy(ctx, requester, func() (*gtsmodel.Status, error) { + return p.state.DB.GetStatusByID(ctx, targetID) + }, refresh) +} + // UnwrapIfBoost "unwraps" the given status if // it's a boost wrapper, by returning the boosted // status it targets (pending visibility checks). @@ -132,9 +155,10 @@ func (p *Processor) UnwrapIfBoost( if status.BoostOfID == "" { return status, nil } - return p.GetVisibleTargetStatus(ctx, - requester, status.BoostOfID, + requester, + status.BoostOfID, + false, ) } |