diff options
Diffstat (limited to 'internal/db')
| -rw-r--r-- | internal/db/account.go | 2 | ||||
| -rw-r--r-- | internal/db/bundb/account.go | 40 | ||||
| -rw-r--r-- | internal/db/bundb/emoji.go | 16 | ||||
| -rw-r--r-- | internal/db/bundb/list.go | 9 | ||||
| -rw-r--r-- | internal/db/bundb/media.go | 44 | ||||
| -rw-r--r-- | internal/db/bundb/relationship.go | 215 | ||||
| -rw-r--r-- | internal/db/bundb/relationship_block.go | 37 | ||||
| -rw-r--r-- | internal/db/bundb/relationship_follow.go | 22 | ||||
| -rw-r--r-- | internal/db/bundb/relationship_follow_req.go | 25 | ||||
| -rw-r--r-- | internal/db/bundb/status.go | 5 | ||||
| -rw-r--r-- | internal/db/relationship.go | 4 | 
11 files changed, 236 insertions, 183 deletions
| diff --git a/internal/db/account.go b/internal/db/account.go index 21b8d6a1f..505ca4004 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -104,8 +104,6 @@ type Account interface {  	// In the case of no statuses, this function will return db.ErrNoEntries.  	GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) -	GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) -  	// GetAccountLastPosted simply gets the timestamp of the most recent post by the account.  	//  	// If webOnly is true, then the time of the last non-reply, non-boost, public status of the account will be returned. diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 2ef1618db..e57c01a82 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -694,46 +694,6 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,  	return a.statusesFromIDs(ctx, statusIDs)  } -func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) { -	blocks := []*gtsmodel.Block{} - -	fq := a.db. -		NewSelect(). -		Model(&blocks). -		Where("? = ?", bun.Ident("block.account_id"), accountID). -		Relation("TargetAccount"). -		Order("block.id DESC") - -	if maxID != "" { -		fq = fq.Where("? < ?", bun.Ident("block.id"), maxID) -	} - -	if sinceID != "" { -		fq = fq.Where("? > ?", bun.Ident("block.id"), sinceID) -	} - -	if limit > 0 { -		fq = fq.Limit(limit) -	} - -	if err := fq.Scan(ctx); err != nil { -		return nil, "", "", a.db.ProcessError(err) -	} - -	if len(blocks) == 0 { -		return nil, "", "", db.ErrNoEntries -	} - -	accounts := []*gtsmodel.Account{} -	for _, b := range blocks { -		accounts = append(accounts, b.TargetAccount) -	} - -	nextMaxID := blocks[len(blocks)-1].ID -	prevMinID := blocks[0].ID -	return accounts, nextMaxID, prevMinID, nil -} -  func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, error) {  	// Catch case of no statuses early  	if len(statusIDs) == 0 { diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index 90bcd134d..04f22b6e9 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -126,16 +126,12 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {  			return err  		} -		// Prepare SELECT accounts query. -		aq := tx.NewSelect(). -			Table("accounts"). -			Column("id") - -		// Append a WHERE LIKE clause to the query +		// Prepare a SELECT query with a WHERE LIKE  		// that checks the `emoji` column for any  		// text containing this specific emoji ID.  		//  		// (see GetStatusesUsingEmoji() for details.) +		aq := tx.NewSelect().Table("accounts").Column("id")  		aq = whereLike(aq, "emojis", id)  		// Select all accounts using this emoji into accountIDss. @@ -170,16 +166,12 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {  			}  		} -		// Prepare SELECT statuses query. -		sq := tx.NewSelect(). -			Table("statuses"). -			Column("id") - -		// Append a WHERE LIKE clause to the query +		// Prepare a SELECT query with a WHERE LIKE  		// that checks the `emoji` column for any  		// text containing this specific emoji ID.  		//  		// (see GetStatusesUsingEmoji() for details.) +		sq := tx.NewSelect().Table("statuses").Column("id")  		sq = whereLike(sq, "emojis", id)  		// Select all statuses using this emoji into statusIDs. diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go index 25bb3a65d..70faf837a 100644 --- a/internal/db/bundb/list.go +++ b/internal/db/bundb/list.go @@ -189,11 +189,10 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error {  		gtscontext.SetBarebones(ctx),  		id,  	) -	if err != nil { -		if errors.Is(err, db.ErrNoEntries) { -			// Already gone. -			return nil -		} +	if err != nil && !errors.Is(err, db.ErrNoEntries) { +		// NOTE: even if db.ErrNoEntries is returned, we +		// still run the below transaction to ensure related +		// objects are appropriately deleted.  		return err  	} diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go index 3b885af61..b8120b87a 100644 --- a/internal/db/bundb/media.go +++ b/internal/db/bundb/media.go @@ -106,8 +106,6 @@ func (m *mediaDB) UpdateAttachment(ctx context.Context, media *gtsmodel.MediaAtt  }  func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error { -	defer m.state.Caches.GTS.Media().Invalidate("ID", id) -  	// Load media into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -120,10 +118,8 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {  		return err  	} -	var ( -		invalidateAccount bool -		invalidateStatus  bool -	) +	// On return, ensure that media with ID is invalidated. +	defer m.state.Caches.GTS.Media().Invalidate("ID", id)  	// Delete media attachment in new transaction.  	err = m.db.RunInTx(ctx, func(tx bun.Tx) error { @@ -161,9 +157,6 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {  				if _, err := set(q).Exec(ctx); err != nil {  					return gtserror.Newf("error updating account: %w", err)  				} - -				// Mark as needing invalidate. -				invalidateAccount = true  			}  		} @@ -178,33 +171,18 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {  				return gtserror.Newf("error selecting status: %w", err)  			} -			// Get length of attachments beforehand. -			before := len(status.AttachmentIDs) - -			for i := 0; i < len(status.AttachmentIDs); { -				if status.AttachmentIDs[i] == id { -					// Remove this reference to deleted attachment ID. -					copy(status.AttachmentIDs[i:], status.AttachmentIDs[i+1:]) -					status.AttachmentIDs = status.AttachmentIDs[:len(status.AttachmentIDs)-1] -					continue -				} -				i++ -			} - -			if before != len(status.AttachmentIDs) { -				// Note: this accounts for status not found. +			if updatedIDs := dropID(status.AttachmentIDs, id); // nocollapse +			len(updatedIDs) != len(status.AttachmentIDs) { +				// Note: this handles not found.  				//  				// Attachments changed, update the status.  				if _, err := tx.NewUpdate().  					Table("statuses").  					Where("? = ?", bun.Ident("id"), status.ID). -					Set("? = ?", bun.Ident("attachment_ids"), status.AttachmentIDs). +					Set("? = ?", bun.Ident("attachment_ids"), updatedIDs).  					Exec(ctx); err != nil {  					return gtserror.Newf("error updating status: %w", err)  				} - -				// Mark as needing invalidate. -				invalidateStatus = true  			}  		} @@ -219,16 +197,6 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {  		return nil  	}) -	if invalidateAccount { -		// The account for given ID will have been updated in transaction. -		m.state.Caches.GTS.Account().Invalidate("ID", media.AccountID) -	} - -	if invalidateStatus { -		// The status for given ID will have been updated in transaction. -		m.state.Caches.GTS.Status().Invalidate("ID", media.StatusID) -	} -  	return m.db.ProcessError(err)  } diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index eddd73b49..e7b563f2e 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -20,11 +20,12 @@ package bundb  import (  	"context"  	"errors" -	"fmt"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/paging"  	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/uptrace/bun"  ) @@ -45,7 +46,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount  		targetAccount,  	)  	if err != nil && !errors.Is(err, db.ErrNoEntries) { -		return nil, fmt.Errorf("GetRelationship: error fetching follow: %w", err) +		return nil, gtserror.Newf("error fetching follow: %w", err)  	}  	if follow != nil { @@ -61,7 +62,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount  		requestingAccount,  	)  	if err != nil { -		return nil, fmt.Errorf("GetRelationship: error checking followedBy: %w", err) +		return nil, gtserror.Newf("error checking followedBy: %w", err)  	}  	// check if requesting has follow requested target @@ -70,19 +71,19 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount  		targetAccount,  	)  	if err != nil { -		return nil, fmt.Errorf("GetRelationship: error checking requested: %w", err) +		return nil, gtserror.Newf("error checking requested: %w", err)  	}  	// check if the requesting account is blocking the target account  	rel.Blocking, err = r.IsBlocked(ctx, requestingAccount, targetAccount)  	if err != nil { -		return nil, fmt.Errorf("GetRelationship: error checking blocking: %w", err) +		return nil, gtserror.Newf("error checking blocking: %w", err)  	}  	// check if the requesting account is blocked by the target account  	rel.BlockedBy, err = r.IsBlocked(ctx, targetAccount, requestingAccount)  	if err != nil { -		return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %w", err) +		return nil, gtserror.Newf("error checking blockedBy: %w", err)  	}  	// retrieve a note by the requesting account on the target account, if there is one @@ -92,7 +93,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount  		targetAccount,  	)  	if err != nil && !errors.Is(err, db.ErrNoEntries) { -		return nil, fmt.Errorf("GetRelationship: error fetching note: %w", err) +		return nil, gtserror.Newf("error fetching note: %w", err)  	}  	if note != nil {  		rel.Note = note.Comment @@ -102,87 +103,186 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount  }  func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { -	var followIDs []string -	if err := newSelectFollows(r.db, accountID). -		Scan(ctx, &followIDs); err != nil { -		return nil, r.db.ProcessError(err) +	followIDs, err := r.getAccountFollowIDs(ctx, accountID) +	if err != nil { +		return nil, err  	}  	return r.GetFollowsByIDs(ctx, followIDs)  }  func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { -	var followIDs []string -	if err := newSelectLocalFollows(r.db, accountID). -		Scan(ctx, &followIDs); err != nil { -		return nil, r.db.ProcessError(err) +	followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID) +	if err != nil { +		return nil, err  	}  	return r.GetFollowsByIDs(ctx, followIDs)  }  func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { -	var followIDs []string -	if err := newSelectFollowers(r.db, accountID). -		Scan(ctx, &followIDs); err != nil { -		return nil, r.db.ProcessError(err) +	followerIDs, err := r.getAccountFollowerIDs(ctx, accountID) +	if err != nil { +		return nil, err  	} -	return r.GetFollowsByIDs(ctx, followIDs) +	return r.GetFollowsByIDs(ctx, followerIDs)  }  func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { -	var followIDs []string -	if err := newSelectLocalFollowers(r.db, accountID). -		Scan(ctx, &followIDs); err != nil { -		return nil, r.db.ProcessError(err) +	followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID) +	if err != nil { +		return nil, err  	} -	return r.GetFollowsByIDs(ctx, followIDs) +	return r.GetFollowsByIDs(ctx, followerIDs) +} + +func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { +	followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID) +	if err != nil { +		return nil, err +	} +	return r.GetFollowRequestsByIDs(ctx, followReqIDs) +} + +func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { +	followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID) +	if err != nil { +		return nil, err +	} +	return r.GetFollowRequestsByIDs(ctx, followReqIDs) +} + +func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Pager) ([]*gtsmodel.Block, error) { +	// Load block IDs from cache with database loader callback. +	blockIDs, err := r.state.Caches.GTS.BlockIDs().LoadRange(accountID, func() ([]string, error) { +		var blockIDs []string + +		// Block IDs not in cache, perform DB query! +		q := newSelectBlocks(r.db, accountID) +		if _, err := q.Exec(ctx, &blockIDs); err != nil { +			return nil, r.db.ProcessError(err) +		} + +		return blockIDs, nil +	}, page.PageDesc) +	if err != nil { +		return nil, err +	} + +	// Convert these IDs to full block objects. +	return r.GetBlocksByIDs(ctx, blockIDs)  }  func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) { -	n, err := newSelectFollows(r.db, accountID).Count(ctx) -	return n, r.db.ProcessError(err) +	followIDs, err := r.getAccountFollowIDs(ctx, accountID) +	return len(followIDs), err  }  func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) { -	n, err := newSelectLocalFollows(r.db, accountID).Count(ctx) -	return n, r.db.ProcessError(err) +	followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID) +	return len(followIDs), err  }  func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) { -	n, err := newSelectFollowers(r.db, accountID).Count(ctx) -	return n, r.db.ProcessError(err) +	followerIDs, err := r.getAccountFollowerIDs(ctx, accountID) +	return len(followerIDs), err  }  func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) { -	n, err := newSelectLocalFollowers(r.db, accountID).Count(ctx) -	return n, r.db.ProcessError(err) +	followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID) +	return len(followerIDs), err  } -func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { -	var followReqIDs []string -	if err := newSelectFollowRequests(r.db, accountID). -		Scan(ctx, &followReqIDs); err != nil { -		return nil, r.db.ProcessError(err) -	} -	return r.GetFollowRequestsByIDs(ctx, followReqIDs) +func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { +	followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID) +	return len(followReqIDs), err  } -func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { -	var followReqIDs []string -	if err := newSelectFollowRequesting(r.db, accountID). -		Scan(ctx, &followReqIDs); err != nil { -		return nil, r.db.ProcessError(err) -	} -	return r.GetFollowRequestsByIDs(ctx, followReqIDs) +func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { +	followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID) +	return len(followReqIDs), err  } -func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { -	n, err := newSelectFollowRequests(r.db, accountID).Count(ctx) -	return n, r.db.ProcessError(err) +func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string) ([]string, error) { +	return r.state.Caches.GTS.FollowIDs().Load(">"+accountID, func() ([]string, error) { +		var followIDs []string + +		// Follow IDs not in cache, perform DB query! +		q := newSelectFollows(r.db, accountID) +		if _, err := q.Exec(ctx, &followIDs); err != nil { +			return nil, r.db.ProcessError(err) +		} + +		return followIDs, nil +	})  } -func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { -	n, err := newSelectFollowRequesting(r.db, accountID).Count(ctx) -	return n, r.db.ProcessError(err) +func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) { +	return r.state.Caches.GTS.FollowIDs().Load("l>"+accountID, func() ([]string, error) { +		var followIDs []string + +		// Follow IDs not in cache, perform DB query! +		q := newSelectLocalFollows(r.db, accountID) +		if _, err := q.Exec(ctx, &followIDs); err != nil { +			return nil, r.db.ProcessError(err) +		} + +		return followIDs, nil +	}) +} + +func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string) ([]string, error) { +	return r.state.Caches.GTS.FollowIDs().Load("<"+accountID, func() ([]string, error) { +		var followIDs []string + +		// Follow IDs not in cache, perform DB query! +		q := newSelectFollowers(r.db, accountID) +		if _, err := q.Exec(ctx, &followIDs); err != nil { +			return nil, r.db.ProcessError(err) +		} + +		return followIDs, nil +	}) +} + +func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) { +	return r.state.Caches.GTS.FollowIDs().Load("l<"+accountID, func() ([]string, error) { +		var followIDs []string + +		// Follow IDs not in cache, perform DB query! +		q := newSelectLocalFollowers(r.db, accountID) +		if _, err := q.Exec(ctx, &followIDs); err != nil { +			return nil, r.db.ProcessError(err) +		} + +		return followIDs, nil +	}) +} + +func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string) ([]string, error) { +	return r.state.Caches.GTS.FollowRequestIDs().Load(">"+accountID, func() ([]string, error) { +		var followReqIDs []string + +		// Follow request IDs not in cache, perform DB query! +		q := newSelectFollowRequests(r.db, accountID) +		if _, err := q.Exec(ctx, &followReqIDs); err != nil { +			return nil, r.db.ProcessError(err) +		} + +		return followReqIDs, nil +	}) +} + +func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string) ([]string, error) { +	return r.state.Caches.GTS.FollowRequestIDs().Load("<"+accountID, func() ([]string, error) { +		var followReqIDs []string + +		// Follow request IDs not in cache, perform DB query! +		q := newSelectFollowRequesting(r.db, accountID) +		if _, err := q.Exec(ctx, &followReqIDs); err != nil { +			return nil, r.db.ProcessError(err) +		} + +		return followReqIDs, nil +	})  }  // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID. @@ -256,3 +356,12 @@ func newSelectLocalFollowers(db *WrappedDB, accountID string) *bun.SelectQuery {  		).  		OrderExpr("? DESC", bun.Ident("updated_at"))  } + +// newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID. +func newSelectBlocks(db *WrappedDB, accountID string) *bun.SelectQuery { +	return db.NewSelect(). +		TableExpr("?", bun.Ident("blocks")). +		ColumnExpr("?", bun.Ident("?")). +		Where("? = ?", bun.Ident("account_id"), accountID). +		OrderExpr("? DESC", bun.Ident("updated_at")) +} diff --git a/internal/db/bundb/relationship_block.go b/internal/db/bundb/relationship_block.go index 948e82fcb..2a042bed4 100644 --- a/internal/db/bundb/relationship_block.go +++ b/internal/db/bundb/relationship_block.go @@ -25,6 +25,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/uptrace/bun"  ) @@ -97,6 +98,25 @@ func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, t  	)  } +func (r *relationshipDB) GetBlocksByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Block, error) { +	// Preallocate slice of expected length. +	blocks := make([]*gtsmodel.Block, 0, len(ids)) + +	for _, id := range ids { +		// Fetch block model for this ID. +		block, err := r.GetBlockByID(ctx, id) +		if err != nil { +			log.Errorf(ctx, "error getting block %q: %v", id, err) +			continue +		} + +		// Append to return slice. +		blocks = append(blocks, block) +	} + +	return blocks, nil +} +  func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) {  	// Fetch block from cache with loader callback  	block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) { @@ -148,8 +168,6 @@ func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) er  }  func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error { -	defer r.state.Caches.GTS.Block().Invalidate("ID", id) -  	// Load block into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -162,6 +180,9 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {  		return err  	} +	// Drop this now-cached block on return after delete. +	defer r.state.Caches.GTS.Block().Invalidate("ID", id) +  	// Finally delete block from DB.  	_, err = r.db.NewDelete().  		Table("blocks"). @@ -171,8 +192,6 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {  }  func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error { -	defer r.state.Caches.GTS.Block().Invalidate("URI", uri) -  	// Load block into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -185,6 +204,9 @@ func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error  		return err  	} +	// Drop this now-cached block on return after delete. +	defer r.state.Caches.GTS.Block().Invalidate("URI", uri) +  	// Finally delete block from DB.  	_, err = r.db.NewDelete().  		Table("blocks"). @@ -211,10 +233,9 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri  	}  	defer func() { -		// Invalidate all IDs on return. -		for _, id := range blockIDs { -			r.state.Caches.GTS.Block().Invalidate("ID", id) -		} +		// Invalidate all account's incoming / outoing blocks on return. +		r.state.Caches.GTS.Block().Invalidate("AccountID", accountID) +		r.state.Caches.GTS.Block().Invalidate("TargetAccountID", accountID)  	}()  	// Load all blocks into cache, this *really* isn't great diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go index 84501b0be..3b0597612 100644 --- a/internal/db/bundb/relationship_follow.go +++ b/internal/db/bundb/relationship_follow.go @@ -233,8 +233,6 @@ func (r *relationshipDB) deleteFollow(ctx context.Context, id string) error {  }  func (r *relationshipDB) DeleteFollow(ctx context.Context, sourceAccountID string, targetAccountID string) error { -	defer r.state.Caches.GTS.Follow().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) -  	// Load follow into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -251,13 +249,14 @@ func (r *relationshipDB) DeleteFollow(ctx context.Context, sourceAccountID strin  		return err  	} +	// Drop this now-cached follow on return after delete. +	defer r.state.Caches.GTS.Follow().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) +  	// Finally delete follow from DB.  	return r.deleteFollow(ctx, follow.ID)  }  func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error { -	defer r.state.Caches.GTS.Follow().Invalidate("ID", id) -  	// Load follow into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -270,13 +269,14 @@ func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error  		return err  	} +	// Drop this now-cached follow on return after delete. +	defer r.state.Caches.GTS.Follow().Invalidate("ID", id) +  	// Finally delete follow from DB.  	return r.deleteFollow(ctx, follow.ID)  }  func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) error { -	defer r.state.Caches.GTS.Follow().Invalidate("URI", uri) -  	// Load follow into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -289,6 +289,9 @@ func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) erro  		return err  	} +	// Drop this now-cached follow on return after delete. +	defer r.state.Caches.GTS.Follow().Invalidate("URI", uri) +  	// Finally delete follow from DB.  	return r.deleteFollow(ctx, follow.ID)  } @@ -312,10 +315,9 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str  	}  	defer func() { -		// Invalidate all IDs on return. -		for _, id := range followIDs { -			r.state.Caches.GTS.Follow().Invalidate("ID", id) -		} +		// Invalidate all account's incoming / outoing follows on return. +		r.state.Caches.GTS.Follow().Invalidate("AccountID", accountID) +		r.state.Caches.GTS.Follow().Invalidate("TargetAccountID", accountID)  	}()  	// Load all follows into cache, this *really* isn't great diff --git a/internal/db/bundb/relationship_follow_req.go b/internal/db/bundb/relationship_follow_req.go index a6e913953..dc5e760e6 100644 --- a/internal/db/bundb/relationship_follow_req.go +++ b/internal/db/bundb/relationship_follow_req.go @@ -208,9 +208,6 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI  		return nil, err  	} -	// Invalidate follow request from cache lookups on return. -	defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", followReq.ID) -  	// Delete original follow request.  	if _, err := r.db.  		NewDelete(). @@ -243,8 +240,6 @@ func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountI  }  func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) error { -	defer r.state.Caches.GTS.FollowRequest().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) -  	// Load followreq into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -261,6 +256,9 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI  		return err  	} +	// Drop this now-cached follow request on return after delete. +	defer r.state.Caches.GTS.FollowRequest().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) +  	// Finally delete followreq from DB.  	_, err = r.db.NewDelete().  		Table("follow_requests"). @@ -270,8 +268,6 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI  }  func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error { -	defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", id) -  	// Load followreq into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -284,6 +280,9 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string)  		return err  	} +	// Drop this now-cached follow request on return after delete. +	defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", id) +  	// Finally delete followreq from DB.  	_, err = r.db.NewDelete().  		Table("follow_requests"). @@ -293,8 +292,6 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string)  }  func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error { -	defer r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri) -  	// Load followreq into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -307,6 +304,9 @@ func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri strin  		return err  	} +	// Drop this now-cached follow request on return after delete. +	defer r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri) +  	// Finally delete followreq from DB.  	_, err = r.db.NewDelete().  		Table("follow_requests"). @@ -334,10 +334,9 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun  	}  	defer func() { -		// Invalidate all IDs on return. -		for _, id := range followReqIDs { -			r.state.Caches.GTS.FollowRequest().Invalidate("ID", id) -		} +		// Invalidate all account's incoming / outoing follow requests on return. +		r.state.Caches.GTS.FollowRequest().Invalidate("AccountID", accountID) +		r.state.Caches.GTS.FollowRequest().Invalidate("TargetAccountID", accountID)  	}()  	// Load all followreqs into cache, this *really* isn't diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index a019216d0..4dc7d8468 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -381,8 +381,6 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co  }  func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error { -	defer s.state.Caches.GTS.Status().Invalidate("ID", id) -  	// Load status into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate  	// callback. This in turn invalidates others. @@ -397,6 +395,9 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error {  		return err  	} +	// On return ensure status invalidated from cache. +	defer s.state.Caches.GTS.Status().Invalidate("ID", id) +  	return s.db.RunInTx(ctx, func(tx bun.Tx) error {  		// delete links between this status and any emojis it uses  		if _, err := tx. diff --git a/internal/db/relationship.go b/internal/db/relationship.go index e19aee646..6ba9fdf8c 100644 --- a/internal/db/relationship.go +++ b/internal/db/relationship.go @@ -21,6 +21,7 @@ import (  	"context"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/paging"  )  // Relationship contains functions for getting or modifying the relationship between two accounts. @@ -166,6 +167,9 @@ type Relationship interface {  	// CountAccountFollowerRequests returns number of follow requests originating from the given account.  	CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) +	// GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters. +	GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Pager) ([]*gtsmodel.Block, error) +  	// GetNote gets a private note from a source account on a target account, if it exists.  	GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error) | 
