diff options
author | 2023-07-31 11:25:29 +0100 | |
---|---|---|
committer | 2023-07-31 11:25:29 +0100 | |
commit | ed2477ebea4c3ceec5949821f4950db9669a4a15 (patch) | |
tree | 1038d7abdfc787ddfc1febb326fd38775b189b85 /internal/db/bundb/relationship.go | |
parent | [bugfix/frontend] Decode URI component domain before showing on frontend (#2043) (diff) | |
download | gotosocial-ed2477ebea4c3ceec5949821f4950db9669a4a15.tar.xz |
[performance] cache follow, follow request and block ID lists (#2027)
Diffstat (limited to 'internal/db/bundb/relationship.go')
-rw-r--r-- | internal/db/bundb/relationship.go | 215 |
1 files changed, 162 insertions, 53 deletions
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index eddd73b49..e7b563f2e 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -20,11 +20,12 @@ package bundb import ( "context" "errors" - "fmt" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/uptrace/bun" ) @@ -45,7 +46,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount targetAccount, ) if err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, fmt.Errorf("GetRelationship: error fetching follow: %w", err) + return nil, gtserror.Newf("error fetching follow: %w", err) } if follow != nil { @@ -61,7 +62,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount requestingAccount, ) if err != nil { - return nil, fmt.Errorf("GetRelationship: error checking followedBy: %w", err) + return nil, gtserror.Newf("error checking followedBy: %w", err) } // check if requesting has follow requested target @@ -70,19 +71,19 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount targetAccount, ) if err != nil { - return nil, fmt.Errorf("GetRelationship: error checking requested: %w", err) + return nil, gtserror.Newf("error checking requested: %w", err) } // check if the requesting account is blocking the target account rel.Blocking, err = r.IsBlocked(ctx, requestingAccount, targetAccount) if err != nil { - return nil, fmt.Errorf("GetRelationship: error checking blocking: %w", err) + return nil, gtserror.Newf("error checking blocking: %w", err) } // check if the requesting account is blocked by the target account rel.BlockedBy, err = r.IsBlocked(ctx, targetAccount, requestingAccount) if err != nil { - return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %w", err) + return nil, gtserror.Newf("error checking blockedBy: %w", err) } // retrieve a note by the requesting account on the target account, if there is one @@ -92,7 +93,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount targetAccount, ) if err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, fmt.Errorf("GetRelationship: error fetching note: %w", err) + return nil, gtserror.Newf("error fetching note: %w", err) } if note != nil { rel.Note = note.Comment @@ -102,87 +103,186 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount } func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { - var followIDs []string - if err := newSelectFollows(r.db, accountID). - Scan(ctx, &followIDs); err != nil { - return nil, r.db.ProcessError(err) + followIDs, err := r.getAccountFollowIDs(ctx, accountID) + if err != nil { + return nil, err } return r.GetFollowsByIDs(ctx, followIDs) } func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { - var followIDs []string - if err := newSelectLocalFollows(r.db, accountID). - Scan(ctx, &followIDs); err != nil { - return nil, r.db.ProcessError(err) + followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID) + if err != nil { + return nil, err } return r.GetFollowsByIDs(ctx, followIDs) } func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { - var followIDs []string - if err := newSelectFollowers(r.db, accountID). - Scan(ctx, &followIDs); err != nil { - return nil, r.db.ProcessError(err) + followerIDs, err := r.getAccountFollowerIDs(ctx, accountID) + if err != nil { + return nil, err } - return r.GetFollowsByIDs(ctx, followIDs) + return r.GetFollowsByIDs(ctx, followerIDs) } func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { - var followIDs []string - if err := newSelectLocalFollowers(r.db, accountID). - Scan(ctx, &followIDs); err != nil { - return nil, r.db.ProcessError(err) + followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID) + if err != nil { + return nil, err } - return r.GetFollowsByIDs(ctx, followIDs) + return r.GetFollowsByIDs(ctx, followerIDs) +} + +func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { + followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID) + if err != nil { + return nil, err + } + return r.GetFollowRequestsByIDs(ctx, followReqIDs) +} + +func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { + followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID) + if err != nil { + return nil, err + } + return r.GetFollowRequestsByIDs(ctx, followReqIDs) +} + +func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Pager) ([]*gtsmodel.Block, error) { + // Load block IDs from cache with database loader callback. + blockIDs, err := r.state.Caches.GTS.BlockIDs().LoadRange(accountID, func() ([]string, error) { + var blockIDs []string + + // Block IDs not in cache, perform DB query! + q := newSelectBlocks(r.db, accountID) + if _, err := q.Exec(ctx, &blockIDs); err != nil { + return nil, r.db.ProcessError(err) + } + + return blockIDs, nil + }, page.PageDesc) + if err != nil { + return nil, err + } + + // Convert these IDs to full block objects. + return r.GetBlocksByIDs(ctx, blockIDs) } func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) { - n, err := newSelectFollows(r.db, accountID).Count(ctx) - return n, r.db.ProcessError(err) + followIDs, err := r.getAccountFollowIDs(ctx, accountID) + return len(followIDs), err } func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) { - n, err := newSelectLocalFollows(r.db, accountID).Count(ctx) - return n, r.db.ProcessError(err) + followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID) + return len(followIDs), err } func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) { - n, err := newSelectFollowers(r.db, accountID).Count(ctx) - return n, r.db.ProcessError(err) + followerIDs, err := r.getAccountFollowerIDs(ctx, accountID) + return len(followerIDs), err } func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) { - n, err := newSelectLocalFollowers(r.db, accountID).Count(ctx) - return n, r.db.ProcessError(err) + followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID) + return len(followerIDs), err } -func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { - var followReqIDs []string - if err := newSelectFollowRequests(r.db, accountID). - Scan(ctx, &followReqIDs); err != nil { - return nil, r.db.ProcessError(err) - } - return r.GetFollowRequestsByIDs(ctx, followReqIDs) +func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { + followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID) + return len(followReqIDs), err } -func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { - var followReqIDs []string - if err := newSelectFollowRequesting(r.db, accountID). - Scan(ctx, &followReqIDs); err != nil { - return nil, r.db.ProcessError(err) - } - return r.GetFollowRequestsByIDs(ctx, followReqIDs) +func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { + followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID) + return len(followReqIDs), err } -func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { - n, err := newSelectFollowRequests(r.db, accountID).Count(ctx) - return n, r.db.ProcessError(err) +func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string) ([]string, error) { + return r.state.Caches.GTS.FollowIDs().Load(">"+accountID, func() ([]string, error) { + var followIDs []string + + // Follow IDs not in cache, perform DB query! + q := newSelectFollows(r.db, accountID) + if _, err := q.Exec(ctx, &followIDs); err != nil { + return nil, r.db.ProcessError(err) + } + + return followIDs, nil + }) } -func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { - n, err := newSelectFollowRequesting(r.db, accountID).Count(ctx) - return n, r.db.ProcessError(err) +func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) { + return r.state.Caches.GTS.FollowIDs().Load("l>"+accountID, func() ([]string, error) { + var followIDs []string + + // Follow IDs not in cache, perform DB query! + q := newSelectLocalFollows(r.db, accountID) + if _, err := q.Exec(ctx, &followIDs); err != nil { + return nil, r.db.ProcessError(err) + } + + return followIDs, nil + }) +} + +func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string) ([]string, error) { + return r.state.Caches.GTS.FollowIDs().Load("<"+accountID, func() ([]string, error) { + var followIDs []string + + // Follow IDs not in cache, perform DB query! + q := newSelectFollowers(r.db, accountID) + if _, err := q.Exec(ctx, &followIDs); err != nil { + return nil, r.db.ProcessError(err) + } + + return followIDs, nil + }) +} + +func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) { + return r.state.Caches.GTS.FollowIDs().Load("l<"+accountID, func() ([]string, error) { + var followIDs []string + + // Follow IDs not in cache, perform DB query! + q := newSelectLocalFollowers(r.db, accountID) + if _, err := q.Exec(ctx, &followIDs); err != nil { + return nil, r.db.ProcessError(err) + } + + return followIDs, nil + }) +} + +func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string) ([]string, error) { + return r.state.Caches.GTS.FollowRequestIDs().Load(">"+accountID, func() ([]string, error) { + var followReqIDs []string + + // Follow request IDs not in cache, perform DB query! + q := newSelectFollowRequests(r.db, accountID) + if _, err := q.Exec(ctx, &followReqIDs); err != nil { + return nil, r.db.ProcessError(err) + } + + return followReqIDs, nil + }) +} + +func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string) ([]string, error) { + return r.state.Caches.GTS.FollowRequestIDs().Load("<"+accountID, func() ([]string, error) { + var followReqIDs []string + + // Follow request IDs not in cache, perform DB query! + q := newSelectFollowRequesting(r.db, accountID) + if _, err := q.Exec(ctx, &followReqIDs); err != nil { + return nil, r.db.ProcessError(err) + } + + return followReqIDs, nil + }) } // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID. @@ -256,3 +356,12 @@ func newSelectLocalFollowers(db *WrappedDB, accountID string) *bun.SelectQuery { ). OrderExpr("? DESC", bun.Ident("updated_at")) } + +// newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID. +func newSelectBlocks(db *WrappedDB, accountID string) *bun.SelectQuery { + return db.NewSelect(). + TableExpr("?", bun.Ident("blocks")). + ColumnExpr("?", bun.Ident("?")). + Where("? = ?", bun.Ident("account_id"), accountID). + OrderExpr("? DESC", bun.Ident("updated_at")) +} |