diff options
| -rw-r--r-- | internal/cache/gts.go | 1 | ||||
| -rw-r--r-- | internal/db/bundb/migrations/20231215115920_add_status_poll_index.go | 66 | ||||
| -rw-r--r-- | internal/db/bundb/poll.go | 14 | ||||
| -rw-r--r-- | internal/db/bundb/poll_test.go | 8 | ||||
| -rw-r--r-- | internal/db/bundb/status.go | 11 | ||||
| -rw-r--r-- | internal/db/poll.go | 3 | ||||
| -rw-r--r-- | internal/db/status.go | 9 | ||||
| -rw-r--r-- | internal/federation/dereferencing/status.go | 27 | ||||
| -rw-r--r-- | internal/processing/common/status.go | 82 | ||||
| -rw-r--r-- | internal/processing/fedi/status.go | 1 | ||||
| -rw-r--r-- | internal/processing/polls/poll.go | 38 | ||||
| -rw-r--r-- | internal/processing/status/bookmark.go | 6 | ||||
| -rw-r--r-- | internal/processing/status/boost.go | 2 | ||||
| -rw-r--r-- | internal/processing/status/fave.go | 7 | ||||
| -rw-r--r-- | internal/processing/status/get.go | 18 | ||||
| -rw-r--r-- | internal/processing/status/mute.go | 6 | ||||
| -rw-r--r-- | internal/processing/status/pin.go | 6 | 
17 files changed, 207 insertions, 98 deletions
diff --git a/internal/cache/gts.go b/internal/cache/gts.go index 339605354..507947305 100644 --- a/internal/cache/gts.go +++ b/internal/cache/gts.go @@ -952,6 +952,7 @@ func (c *GTSCaches) initStatus() {  		{Name: "ID"},  		{Name: "URI"},  		{Name: "URL"}, +		{Name: "PollID"},  		{Name: "BoostOfID.AccountID"},  		{Name: "ThreadID", Multi: true},  	}, copyF, cap) diff --git a/internal/db/bundb/migrations/20231215115920_add_status_poll_index.go b/internal/db/bundb/migrations/20231215115920_add_status_poll_index.go new file mode 100644 index 000000000..54b585d60 --- /dev/null +++ b/internal/db/bundb/migrations/20231215115920_add_status_poll_index.go @@ -0,0 +1,66 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( +	"context" + +	"github.com/uptrace/bun" +) + +func init() { +	up := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			type spec struct { +				index   string +				table   string +				columns []string +			} + +			for _, spec := range []spec{ +				{ +					index:   "statuses_poll_id_idx", +					table:   "statuses", +					columns: []string{"poll_id"}, +				}, +			} { +				if _, err := tx. +					NewCreateIndex(). +					Table(spec.table). +					Index(spec.index). +					Column(spec.columns...). +					IfNotExists(). +					Exec(ctx); err != nil { +					return err +				} +			} + +			return nil +		}) +	} + +	down := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			return nil +		}) +	} + +	if err := Migrations.Register(up, down); err != nil { +		panic(err) +	} +} diff --git a/internal/db/bundb/poll.go b/internal/db/bundb/poll.go index 830fb88ec..3e77fb6c5 100644 --- a/internal/db/bundb/poll.go +++ b/internal/db/bundb/poll.go @@ -50,20 +50,6 @@ func (p *pollDB) GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, er  	)  } -func (p *pollDB) GetPollByStatusID(ctx context.Context, statusID string) (*gtsmodel.Poll, error) { -	return p.getPoll( -		ctx, -		"StatusID", -		func(poll *gtsmodel.Poll) error { -			return p.db.NewSelect(). -				Model(poll). -				Where("? = ?", bun.Ident("poll.status_id"), statusID). -				Scan(ctx) -		}, -		statusID, -	) -} -  func (p *pollDB) getPoll(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Poll) error, keyParts ...any) (*gtsmodel.Poll, error) {  	// Fetch poll from database cache with loader callback  	poll, err := p.state.Caches.GTS.Poll().Load(lookup, func() (*gtsmodel.Poll, error) { diff --git a/internal/db/bundb/poll_test.go b/internal/db/bundb/poll_test.go index 479557c55..6bdbdb983 100644 --- a/internal/db/bundb/poll_test.go +++ b/internal/db/bundb/poll_test.go @@ -67,10 +67,6 @@ func (suite *PollTestSuite) TestGetPollBy() {  			"id": func() (*gtsmodel.Poll, error) {  				return suite.db.GetPollByID(ctx, poll.ID)  			}, - -			"status_id": func() (*gtsmodel.Poll, error) { -				return suite.db.GetPollByStatusID(ctx, poll.StatusID) -			},  		} {  			// Clear database caches. @@ -287,10 +283,6 @@ func (suite *PollTestSuite) TestDeletePoll() {  		// Ensure that afterwards we cannot fetch poll.  		_, err = suite.db.GetPollByID(ctx, poll.ID)  		suite.ErrorIs(err, db.ErrNoEntries) - -		// Or again by the status it's attached to. -		_, err = suite.db.GetPollByStatusID(ctx, poll.StatusID) -		suite.ErrorIs(err, db.ErrNoEntries)  	}  } diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index dd161e1ec..da252c7f7 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -87,6 +87,17 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.St  	)  } +func (s *statusDB) GetStatusByPollID(ctx context.Context, pollID string) (*gtsmodel.Status, error) { +	return s.getStatus( +		ctx, +		"PollID", +		func(status *gtsmodel.Status) error { +			return s.db.NewSelect().Model(status).Where("? = ?", bun.Ident("status.poll_id"), pollID).Scan(ctx) +		}, +		pollID, +	) +} +  func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccountID string) (*gtsmodel.Status, error) {  	return s.getStatus(  		ctx, diff --git a/internal/db/poll.go b/internal/db/poll.go index b59d27c73..ac0229855 100644 --- a/internal/db/poll.go +++ b/internal/db/poll.go @@ -27,9 +27,6 @@ type Poll interface {  	// GetPollByID fetches the Poll with given ID from the database.  	GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, error) -	// GetPollByStatusID fetches the Poll with given status ID column value from the database. -	GetPollByStatusID(ctx context.Context, statusID string) (*gtsmodel.Poll, error) -  	// GetOpenPolls fetches all local Polls in the database with an unset `closed_at` column.  	GetOpenPolls(ctx context.Context) ([]*gtsmodel.Poll, error) diff --git a/internal/db/status.go b/internal/db/status.go index 1ebf503a8..8034d39e7 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -25,15 +25,18 @@ import (  // Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses.  type Status interface { -	// GetStatusByID returns one status from the database, with no rel fields populated, only their linking ID / URIs +	// GetStatusByID fetches the status from the database with matching id column.  	GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, error) -	// GetStatusByURI returns one status from the database, with no rel fields populated, only their linking ID / URIs +	// GetStatusByURI fetches the status from the database with matching uri column.  	GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, error) -	// GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs +	// GetStatusByURL fetches the status from the database with matching url column.  	GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, error) +	// GetStatusByPollID fetches the status from the database with matching poll_id column. +	GetStatusByPollID(ctx context.Context, pollID string) (*gtsmodel.Status, error) +  	// GetStatusBoost fetches the status whose boost_of_id column refers to boostOfID, authored by given account ID.  	GetStatusBoost(ctx context.Context, boostOfID string, byAccountID string) (*gtsmodel.Status, error) diff --git a/internal/federation/dereferencing/status.go b/internal/federation/dereferencing/status.go index 8a8ec60b1..2a2b99d25 100644 --- a/internal/federation/dereferencing/status.go +++ b/internal/federation/dereferencing/status.go @@ -40,14 +40,25 @@ import (  // statusUpToDate returns whether the given status model is both updateable  // (i.e. remote status) and whether it needs an update based on `fetched_at`. -func statusUpToDate(status *gtsmodel.Status) bool { +func statusUpToDate(status *gtsmodel.Status, force bool) bool {  	if *status.Local {  		// Can't update local statuses.  		return true  	} -	// If this status was updated recently (last interval), we return as-is. -	if next := status.FetchedAt.Add(2 * time.Hour); time.Now().Before(next) { +	// Default limit we allow +	// statuses to be refreshed. +	limit := 2 * time.Hour + +	if force { +		// We specifically allow the force flag +		// to force an early refresh (on a much +		// smaller cooldown period). +		limit = 5 * time.Minute +	} + +	// If this status was updated recently (within limit), return as-is. +	if next := status.FetchedAt.Add(limit); time.Now().Before(next) {  		return true  	} @@ -125,7 +136,7 @@ func (d *Dereferencer) getStatusByURI(ctx context.Context, requestUser string, u  	}  	// Check whether needs update. -	if statusUpToDate(status) { +	if statusUpToDate(status, false) {  		// This is existing up-to-date status, ensure it is populated.  		if err := d.state.DB.PopulateStatus(ctx, status); err != nil {  			log.Errorf(ctx, "error populating existing status: %v", err) @@ -159,8 +170,8 @@ func (d *Dereferencer) RefreshStatus(  	statusable ap.Statusable,  	force bool,  ) (*gtsmodel.Status, ap.Statusable, error) { -	// Check whether needs update. -	if !force && statusUpToDate(status) { +	// Check whether status needs update. +	if statusUpToDate(status, force) {  		return status, nil, nil  	} @@ -204,8 +215,8 @@ func (d *Dereferencer) RefreshStatusAsync(  	statusable ap.Statusable,  	force bool,  ) { -	// Check whether needs update. -	if !force && statusUpToDate(status) { +	// Check whether status needs update. +	if statusUpToDate(status, force) {  		return  	} diff --git a/internal/processing/common/status.go b/internal/processing/common/status.go index 0a1f495fb..ae03a5306 100644 --- a/internal/processing/common/status.go +++ b/internal/processing/common/status.go @@ -30,10 +30,12 @@ import (  // GetTargetStatusBy fetches the target status with db load function, given the authorized (or, nil) requester's  // account. This returns an approprate gtserror.WithCode accounting for not found and visibility to requester. +// The refresh argument allows specifying whether the returned copy should be force refreshed.  func (p *Processor) GetTargetStatusBy(  	ctx context.Context,  	requester *gtsmodel.Account,  	getTargetFromDB func() (*gtsmodel.Status, error), +	refresh bool,  ) (  	status *gtsmodel.Status,  	visible bool, @@ -61,47 +63,52 @@ func (p *Processor) GetTargetStatusBy(  	}  	if requester != nil && visible { -		// Ensure remote status is up-to-date. -		p.federator.RefreshStatusAsync(ctx, -			requester.Username, -			target, -			nil, -			false, -		) +		// We only bother refreshing if this status +		// is visible to requester, AND there *is* +		// a requester (i.e. request is authorized) +		// to prevent a possible DOS vector. + +		if refresh { +			// Refresh required, forcibly do synchronously. +			_, _, err := p.federator.RefreshStatus(ctx, +				requester.Username, +				target, +				nil, +				true, // force +			) +			if err != nil { +				log.Errorf(ctx, "error refreshing status: %v", err) +			} +		} else { +			// Only refresh async *if* out-of-date. +			p.federator.RefreshStatusAsync(ctx, +				requester.Username, +				target, +				nil, +				false, // force +			) +		}  	}  	return target, visible, nil  } -// GetTargetStatusByID is a call-through to GetTargetStatus() using the db GetStatusByID() function. -func (p *Processor) GetTargetStatusByID( -	ctx context.Context, -	requester *gtsmodel.Account, -	targetID string, -) ( -	status *gtsmodel.Status, -	visible bool, -	errWithCode gtserror.WithCode, -) { -	return p.GetTargetStatusBy(ctx, requester, func() (*gtsmodel.Status, error) { -		return p.state.DB.GetStatusByID(ctx, targetID) -	}) -} - -// GetVisibleTargetStatus calls GetTargetStatusByID(), +// GetVisibleTargetStatus calls GetTargetStatusBy(),  // but converts a non-visible result to not-found error. -func (p *Processor) GetVisibleTargetStatus( +func (p *Processor) GetVisibleTargetStatusBy(  	ctx context.Context,  	requester *gtsmodel.Account, -	targetID string, +	getTargetFromDB func() (*gtsmodel.Status, error), +	refresh bool,  ) (  	status *gtsmodel.Status,  	errWithCode gtserror.WithCode,  ) {  	// Fetch the target status by ID from the database. -	target, visible, errWithCode := p.GetTargetStatusByID(ctx, +	target, visible, errWithCode := p.GetTargetStatusBy(ctx,  		requester, -		targetID, +		getTargetFromDB, +		refresh,  	)  	if errWithCode != nil {  		return nil, errWithCode @@ -119,6 +126,22 @@ func (p *Processor) GetVisibleTargetStatus(  	return target, nil  } +// GetVisibleTargetStatus calls GetVisibleTargetStatusBy(), +// passing in a database function that fetches by status ID. +func (p *Processor) GetVisibleTargetStatus( +	ctx context.Context, +	requester *gtsmodel.Account, +	targetID string, +	refresh bool, +) ( +	status *gtsmodel.Status, +	errWithCode gtserror.WithCode, +) { +	return p.GetVisibleTargetStatusBy(ctx, requester, func() (*gtsmodel.Status, error) { +		return p.state.DB.GetStatusByID(ctx, targetID) +	}, refresh) +} +  // UnwrapIfBoost "unwraps" the given status if  // it's a boost wrapper, by returning the boosted  // status it targets (pending visibility checks). @@ -132,9 +155,10 @@ func (p *Processor) UnwrapIfBoost(  	if status.BoostOfID == "" {  		return status, nil  	} -  	return p.GetVisibleTargetStatus(ctx, -		requester, status.BoostOfID, +		requester, +		status.BoostOfID, +		false,  	)  } diff --git a/internal/processing/fedi/status.go b/internal/processing/fedi/status.go index 1c1af9cb4..2674ebf68 100644 --- a/internal/processing/fedi/status.go +++ b/internal/processing/fedi/status.go @@ -100,6 +100,7 @@ func (p *Processor) StatusRepliesGet(  	status, errWithCode := p.c.GetVisibleTargetStatus(ctx,  		requester,  		statusID, +		false, // refresh  	)  	if errWithCode != nil {  		return nil, errWithCode diff --git a/internal/processing/polls/poll.go b/internal/processing/polls/poll.go index 3b258b76c..19cf555e5 100644 --- a/internal/processing/polls/poll.go +++ b/internal/processing/polls/poll.go @@ -19,11 +19,8 @@ package polls  import (  	"context" -	"errors"  	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" -	"github.com/superseriousbusiness/gotosocial/internal/db" -	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/processing/common" @@ -48,35 +45,24 @@ func New(common *common.Processor, state *state.State, converter *typeutils.Conv  }  // getTargetPoll fetches a target poll ID for requesting account, taking visibility of the poll's originating status into account. -func (p *Processor) getTargetPoll(ctx context.Context, requestingAccount *gtsmodel.Account, targetID string) (*gtsmodel.Poll, gtserror.WithCode) { -	// Load the requested poll with ID. -	// (barebones as we fetch status below) -	poll, err := p.state.DB.GetPollByID( -		gtscontext.SetBarebones(ctx), -		targetID, +func (p *Processor) getTargetPoll(ctx context.Context, requester *gtsmodel.Account, targetID string) (*gtsmodel.Poll, gtserror.WithCode) { +	// Load the status the poll is attached to by the poll ID, +	// checking for visibility and ensuring it is up-to-date. +	status, errWithCode := p.c.GetVisibleTargetStatusBy(ctx, +		requester, +		func() (*gtsmodel.Status, error) { +			return p.state.DB.GetStatusByPollID(ctx, targetID) +		}, +		true, // refresh  	) -	if err != nil && !errors.Is(err, db.ErrNoEntries) { -		return nil, gtserror.NewErrorInternalError(err) -	} - -	if poll == nil { -		// No poll could be found for given ID. -		const text = "target poll not found" -		return nil, gtserror.NewErrorNotFound( -			errors.New(text), -			text, -		) -	} - -	// Check that we can see + fetch the originating status for requesting account. -	status, errWithCode := p.c.GetVisibleTargetStatus(ctx, requestingAccount, poll.StatusID)  	if errWithCode != nil {  		return nil, errWithCode  	} -	// Update poll status. +	// Return most up-to-date +	// copy of the status poll. +	poll := status.Poll  	poll.Status = status -  	return poll, nil  } diff --git a/internal/processing/status/bookmark.go b/internal/processing/status/bookmark.go index 634529ba4..224445838 100644 --- a/internal/processing/status/bookmark.go +++ b/internal/processing/status/bookmark.go @@ -30,7 +30,11 @@ import (  )  func (p *Processor) getBookmarkableStatus(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*gtsmodel.Status, string, gtserror.WithCode) { -	targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, requestingAccount, targetStatusID) +	targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, +		requestingAccount, +		targetStatusID, +		false, // refresh +	)  	if errWithCode != nil {  		return nil, "", errWithCode  	} diff --git a/internal/processing/status/boost.go b/internal/processing/status/boost.go index 2062fb802..2fc96091e 100644 --- a/internal/processing/status/boost.go +++ b/internal/processing/status/boost.go @@ -43,6 +43,7 @@ func (p *Processor) BoostCreate(  		ctx,  		requester,  		targetID, +		false, // refresh  	)  	if errWithCode != nil {  		return nil, errWithCode @@ -112,6 +113,7 @@ func (p *Processor) BoostRemove(  		ctx,  		requester,  		targetID, +		false, // refresh  	)  	if errWithCode != nil {  		return nil, errWithCode diff --git a/internal/processing/status/fave.go b/internal/processing/status/fave.go index dbeba7fe9..7ac270e8c 100644 --- a/internal/processing/status/fave.go +++ b/internal/processing/status/fave.go @@ -47,6 +47,7 @@ func (p *Processor) getFaveableStatus(  		ctx,  		requester,  		targetID, +		false, // refresh  	)  	if errWithCode != nil {  		return nil, nil, errWithCode @@ -149,7 +150,11 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel.  // FavedBy returns a slice of accounts that have liked the given status, filtered according to privacy settings.  func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) { -	targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, requestingAccount, targetStatusID) +	targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, +		requestingAccount, +		targetStatusID, +		false, // refresh +	)  	if errWithCode != nil {  		return nil, errWithCode  	} diff --git a/internal/processing/status/get.go b/internal/processing/status/get.go index c182bd148..f8c037404 100644 --- a/internal/processing/status/get.go +++ b/internal/processing/status/get.go @@ -28,7 +28,11 @@ import (  // Get gets the given status, taking account of privacy settings and blocks etc.  func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { -	targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, requestingAccount, targetStatusID) +	targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, +		requestingAccount, +		targetStatusID, +		false, // refresh +	)  	if errWithCode != nil {  		return nil, errWithCode  	} @@ -38,7 +42,11 @@ func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account  // WebGet gets the given status for web use, taking account of privacy settings.  func (p *Processor) WebGet(ctx context.Context, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { -	targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, nil, targetStatusID) +	targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, +		nil, // requester +		targetStatusID, +		false, // refresh +	)  	if errWithCode != nil {  		return nil, errWithCode  	} @@ -57,7 +65,11 @@ func (p *Processor) contextGet(  	targetStatusID string,  	convert func(context.Context, *gtsmodel.Status, *gtsmodel.Account) (*apimodel.Status, error),  ) (*apimodel.Context, gtserror.WithCode) { -	targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, requestingAccount, targetStatusID) +	targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, +		requestingAccount, +		targetStatusID, +		false, // refresh +	)  	if errWithCode != nil {  		return nil, errWithCode  	} diff --git a/internal/processing/status/mute.go b/internal/processing/status/mute.go index 1663ee0bc..fb4f3b384 100644 --- a/internal/processing/status/mute.go +++ b/internal/processing/status/mute.go @@ -41,7 +41,11 @@ func (p *Processor) getMuteableStatus(  	requestingAccount *gtsmodel.Account,  	targetStatusID string,  ) (*gtsmodel.Status, gtserror.WithCode) { -	targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, requestingAccount, targetStatusID) +	targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, +		requestingAccount, +		targetStatusID, +		false, // refresh +	)  	if errWithCode != nil {  		return nil, errWithCode  	} diff --git a/internal/processing/status/pin.go b/internal/processing/status/pin.go index b31288a64..f08b9652c 100644 --- a/internal/processing/status/pin.go +++ b/internal/processing/status/pin.go @@ -39,7 +39,11 @@ const allowedPinnedCount = 10  //   - Status is public, unlisted, or followers-only.  //   - Status is not a boost.  func (p *Processor) getPinnableStatus(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*gtsmodel.Status, gtserror.WithCode) { -	targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, requestingAccount, targetStatusID) +	targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, +		requestingAccount, +		targetStatusID, +		false, // refresh +	)  	if errWithCode != nil {  		return nil, errWithCode  	}  | 
