diff options
| author | 2023-09-12 14:00:35 +0100 | |
|---|---|---|
| committer | 2023-09-12 14:00:35 +0100 | |
| commit | 7293d6029b43db693fd170c0c087394339da0677 (patch) | |
| tree | 09063243faf1b178fde35973486e311f66b1ca33 /internal/db | |
| parent | [feature] Allow admins to expire remote public keys; refetch expired keys on ... (diff) | |
| download | gotosocial-7293d6029b43db693fd170c0c087394339da0677.tar.xz | |
[feature] add paging to account follows, followers and follow requests endpoints (#2186)
Diffstat (limited to 'internal/db')
| -rw-r--r-- | internal/db/bundb/relationship.go | 101 | ||||
| -rw-r--r-- | internal/db/bundb/relationship_test.go | 6 | ||||
| -rw-r--r-- | internal/db/bundb/timeline.go | 1 | ||||
| -rw-r--r-- | internal/db/bundb/timeline_test.go | 2 | ||||
| -rw-r--r-- | internal/db/bundb/util.go | 25 | ||||
| -rw-r--r-- | internal/db/relationship.go | 33 | 
6 files changed, 96 insertions, 72 deletions
| diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index f1bdcf52b..822e697c1 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -102,8 +102,8 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount  	return &rel, nil  } -func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { -	followIDs, err := r.getAccountFollowIDs(ctx, accountID) +func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) { +	followIDs, err := r.getAccountFollowIDs(ctx, accountID, page)  	if err != nil {  		return nil, err  	} @@ -118,8 +118,8 @@ func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID s  	return r.GetFollowsByIDs(ctx, followIDs)  } -func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { -	followerIDs, err := r.getAccountFollowerIDs(ctx, accountID) +func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) { +	followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, page)  	if err != nil {  		return nil, err  	} @@ -134,16 +134,16 @@ func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID  	return r.GetFollowsByIDs(ctx, followerIDs)  } -func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { -	followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID) +func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) { +	followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, page)  	if err != nil {  		return nil, err  	}  	return r.GetFollowRequestsByIDs(ctx, followReqIDs)  } -func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { -	followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID) +func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) { +	followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, page)  	if err != nil {  		return nil, err  	} @@ -151,39 +151,15 @@ func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, account  }  func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) { -	// Load block IDs from cache with database loader callback. -	blockIDs, err := r.state.Caches.GTS.BlockIDs().Load(accountID, func() ([]string, error) { -		var blockIDs []string - -		// Block IDs not in cache, perform DB query! -		q := newSelectBlocks(r.db, accountID) -		if _, err := q.Exec(ctx, &blockIDs); err != nil { -			return nil, err -		} - -		return blockIDs, nil -	}) +	blockIDs, err := r.getAccountBlockIDs(ctx, accountID, page)  	if err != nil {  		return nil, err  	} - -	// Our cached / selected block IDs are -	// ALWAYS stored in descending order. -	// Depending on the paging requested -	// this may be an unexpected order. -	if !page.GetOrder().Ascending() { -		blockIDs = paging.Reverse(blockIDs) -	} - -	// Page the resulting block IDs. -	blockIDs = page.Page(blockIDs) - -	// Convert these IDs to full block objects.  	return r.GetBlocksByIDs(ctx, blockIDs)  }  func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) { -	followIDs, err := r.getAccountFollowIDs(ctx, accountID) +	followIDs, err := r.getAccountFollowIDs(ctx, accountID, nil)  	return len(followIDs), err  } @@ -193,7 +169,7 @@ func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID  }  func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) { -	followerIDs, err := r.getAccountFollowerIDs(ctx, accountID) +	followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, nil)  	return len(followerIDs), err  } @@ -203,17 +179,22 @@ func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, account  }  func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { -	followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID) +	followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, nil)  	return len(followReqIDs), err  }  func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { -	followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID) +	followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, nil)  	return len(followReqIDs), err  } -func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string) ([]string, error) { -	return r.state.Caches.GTS.FollowIDs().Load(">"+accountID, func() ([]string, error) { +func (r *relationshipDB) CountAccountBlocks(ctx context.Context, accountID string) (int, error) { +	blockIDs, err := r.getAccountBlockIDs(ctx, accountID, nil) +	return len(blockIDs), err +} + +func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { +	return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), ">"+accountID, page, func() ([]string, error) {  		var followIDs []string  		// Follow IDs not in cache, perform DB query! @@ -240,8 +221,8 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID  	})  } -func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string) ([]string, error) { -	return r.state.Caches.GTS.FollowIDs().Load("<"+accountID, func() ([]string, error) { +func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { +	return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), "<"+accountID, page, func() ([]string, error) {  		var followIDs []string  		// Follow IDs not in cache, perform DB query! @@ -268,8 +249,8 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account  	})  } -func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string) ([]string, error) { -	return r.state.Caches.GTS.FollowRequestIDs().Load(">"+accountID, func() ([]string, error) { +func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { +	return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), ">"+accountID, page, func() ([]string, error) {  		var followReqIDs []string  		// Follow request IDs not in cache, perform DB query! @@ -282,8 +263,8 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account  	})  } -func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string) ([]string, error) { -	return r.state.Caches.GTS.FollowRequestIDs().Load("<"+accountID, func() ([]string, error) { +func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { +	return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), "<"+accountID, page, func() ([]string, error) {  		var followReqIDs []string  		// Follow request IDs not in cache, perform DB query! @@ -296,13 +277,27 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco  	})  } +func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { +	return loadPagedIDs(r.state.Caches.GTS.BlockIDs(), accountID, page, func() ([]string, error) { +		var blockIDs []string + +		// Block IDs not in cache, perform DB query! +		q := newSelectBlocks(r.db, accountID) +		if _, err := q.Exec(ctx, &blockIDs); err != nil { +			return nil, err +		} + +		return blockIDs, nil +	}) +} +  // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID.  func newSelectFollowRequests(db *DB, accountID string) *bun.SelectQuery {  	return db.NewSelect().  		TableExpr("?", bun.Ident("follow_requests")).  		ColumnExpr("?", bun.Ident("id")).  		Where("? = ?", bun.Ident("target_account_id"), accountID). -		OrderExpr("? DESC", bun.Ident("updated_at")) +		OrderExpr("? DESC", bun.Ident("id"))  }  // newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID. @@ -311,7 +306,7 @@ func newSelectFollowRequesting(db *DB, accountID string) *bun.SelectQuery {  		TableExpr("?", bun.Ident("follow_requests")).  		ColumnExpr("?", bun.Ident("id")).  		Where("? = ?", bun.Ident("target_account_id"), accountID). -		OrderExpr("? DESC", bun.Ident("updated_at")) +		OrderExpr("? DESC", bun.Ident("id"))  }  // newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID. @@ -320,7 +315,7 @@ func newSelectFollows(db *DB, accountID string) *bun.SelectQuery {  		Table("follows").  		Column("id").  		Where("? = ?", bun.Ident("account_id"), accountID). -		OrderExpr("? DESC", bun.Ident("updated_at")) +		OrderExpr("? DESC", bun.Ident("id"))  }  // newSelectLocalFollows returns a new select query for all rows in the follows table with @@ -338,7 +333,7 @@ func newSelectLocalFollows(db *DB, accountID string) *bun.SelectQuery {  				Column("id").  				Where("? IS NULL", bun.Ident("domain")),  		). -		OrderExpr("? DESC", bun.Ident("updated_at")) +		OrderExpr("? DESC", bun.Ident("id"))  }  // newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID. @@ -347,7 +342,7 @@ func newSelectFollowers(db *DB, accountID string) *bun.SelectQuery {  		Table("follows").  		Column("id").  		Where("? = ?", bun.Ident("target_account_id"), accountID). -		OrderExpr("? DESC", bun.Ident("updated_at")) +		OrderExpr("? DESC", bun.Ident("id"))  }  // newSelectLocalFollowers returns a new select query for all rows in the follows table with @@ -365,14 +360,14 @@ func newSelectLocalFollowers(db *DB, accountID string) *bun.SelectQuery {  				Column("id").  				Where("? IS NULL", bun.Ident("domain")),  		). -		OrderExpr("? DESC", bun.Ident("updated_at")) +		OrderExpr("? DESC", bun.Ident("id"))  }  // newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID.  func newSelectBlocks(db *DB, accountID string) *bun.SelectQuery {  	return db.NewSelect().  		TableExpr("?", bun.Ident("blocks")). -		ColumnExpr("?", bun.Ident("?")). +		ColumnExpr("?", bun.Ident("id")).  		Where("? = ?", bun.Ident("account_id"), accountID). -		OrderExpr("? DESC", bun.Ident("updated_at")) +		OrderExpr("? DESC", bun.Ident("id"))  } diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go index d7c93ff0e..aa2353961 100644 --- a/internal/db/bundb/relationship_test.go +++ b/internal/db/bundb/relationship_test.go @@ -753,14 +753,14 @@ func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {  		suite.FailNow(err.Error())  	} -	followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID) +	followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID, nil)  	suite.NoError(err)  	suite.Len(followRequests, 1)  }  func (suite *RelationshipTestSuite) TestGetAccountFollows() {  	account := suite.testAccounts["local_account_1"] -	follows, err := suite.db.GetAccountFollows(context.Background(), account.ID) +	follows, err := suite.db.GetAccountFollows(context.Background(), account.ID, nil)  	suite.NoError(err)  	suite.Len(follows, 2)  } @@ -781,7 +781,7 @@ func (suite *RelationshipTestSuite) TestCountAccountFollows() {  func (suite *RelationshipTestSuite) TestGetAccountFollowers() {  	account := suite.testAccounts["local_account_1"] -	follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID) +	follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID, nil)  	suite.NoError(err)  	suite.Len(follows, 2)  } diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go index f63937bc1..229245899 100644 --- a/internal/db/bundb/timeline.go +++ b/internal/db/bundb/timeline.go @@ -114,6 +114,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI  	follows, err := t.state.DB.GetAccountFollows(  		gtscontext.SetBarebones(ctx),  		accountID, +		nil, // select all  	)  	if err != nil && !errors.Is(err, db.ErrNoEntries) {  		return nil, gtserror.Newf("db error getting follows for account %s: %w", accountID, err) diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go index e5a78dfd1..ac169ec4a 100644 --- a/internal/db/bundb/timeline_test.go +++ b/internal/db/bundb/timeline_test.go @@ -167,8 +167,8 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineNoFollowing() {  	follows, err := suite.state.DB.GetAccountFollows(  		gtscontext.SetBarebones(ctx),  		viewingAccount.ID, +		nil, // select all  	) -  	if err != nil {  		suite.FailNow(err.Error())  	} diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go index 3c3249daf..1d820d081 100644 --- a/internal/db/bundb/util.go +++ b/internal/db/bundb/util.go @@ -20,7 +20,9 @@ package bundb  import (  	"strings" +	"github.com/superseriousbusiness/gotosocial/internal/cache"  	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/paging"  	"github.com/uptrace/bun"  ) @@ -83,6 +85,29 @@ func whereStartsLike(  	)  } +// loadPagedIDs loads a page of IDs from given SliceCache by `key`, resorting to `loadDESC` if required. Uses `page` to sort + page resulting IDs. +// NOTE: IDs returned from `cache` / `loadDESC` MUST be in descending order, otherwise paging will not work correctly / return things out of order. +func loadPagedIDs(cache *cache.SliceCache[string], key string, page *paging.Page, loadDESC func() ([]string, error)) ([]string, error) { +	// Check cache for IDs, else load. +	ids, err := cache.Load(key, loadDESC) +	if err != nil { +		return nil, err +	} + +	// Our cached / selected IDs are ALWAYS +	// fetched from `loadDESC` in descending +	// order. Depending on the paging requested +	// this may be an unexpected order. +	if page.GetOrder().Ascending() { +		ids = paging.Reverse(ids) +	} + +	// Page the resulting IDs. +	ids = page.Page(ids) + +	return ids, nil +} +  // updateWhere parses []db.Where and adds it to the given update query.  func updateWhere(q *bun.UpdateQuery, where []db.Where) {  	for _, w := range where { diff --git a/internal/db/relationship.go b/internal/db/relationship.go index 91c98644c..b3b45551b 100644 --- a/internal/db/relationship.go +++ b/internal/db/relationship.go @@ -138,43 +138,46 @@ type Relationship interface {  	RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) error  	// GetAccountFollows returns a slice of follows owned by the given accountID. -	GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) +	GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error)  	// GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance.  	GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) +	// GetAccountFollowers fetches follows that target given accountID. +	GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) + +	// GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance. +	GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) + +	// GetAccountFollowRequests returns all follow requests targeting the given account. +	GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) + +	// GetAccountFollowRequesting returns all follow requests originating from the given account. +	GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) + +	// GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters. +	GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Page) ([]*gtsmodel.Block, error) +  	// CountAccountFollows returns the amount of accounts that the given accountID is following.  	CountAccountFollows(ctx context.Context, accountID string) (int, error)  	// CountAccountLocalFollows returns the amount of accounts that the given accountID is following, only including follows from this instance.  	CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) -	// GetAccountFollowers fetches follows that target given accountID. -	GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) - -	// GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance. -	GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) -  	// CountAccountFollowers returns the amounts that the given ID is followed by.  	CountAccountFollowers(ctx context.Context, accountID string) (int, error)  	// CountAccountLocalFollowers returns the amounts that the given ID is followed by, only including follows from this instance.  	CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) -	// GetAccountFollowRequests returns all follow requests targeting the given account. -	GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) - -	// GetAccountFollowRequesting returns all follow requests originating from the given account. -	GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) -  	// CountAccountFollowRequests returns number of follow requests targeting the given account.  	CountAccountFollowRequests(ctx context.Context, accountID string) (int, error)  	// CountAccountFollowerRequests returns number of follow requests originating from the given account.  	CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) -	// GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters. -	GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Page) ([]*gtsmodel.Block, error) +	// CountAccountBlocks ... +	CountAccountBlocks(ctx context.Context, accountID string) (int, error)  	// GetNote gets a private note from a source account on a target account, if it exists.  	GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error) | 
