summaryrefslogtreecommitdiff
path: root/internal/db/bundb/relationship.go
diff options
context:
space:
mode:
authorLibravatar kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com>2023-03-28 14:03:14 +0100
committerLibravatar GitHub <noreply@github.com>2023-03-28 14:03:14 +0100
commitde6e3e5f2a8ea639d76e310a11cb9bc093fef3a9 (patch)
treee2b7044e22c943425a4d351a02f862fbde783657 /internal/db/bundb/relationship.go
parent[feature] Add list command to admin account (#1648) (diff)
downloadgotosocial-de6e3e5f2a8ea639d76e310a11cb9bc093fef3a9.tar.xz
[performance] refactoring + add fave / follow / request / visibility caching (#1607)
* refactor visibility checking, add caching for visibility * invalidate visibility cache items on account / status deletes * fix requester ID passed to visibility cache nil ptr * de-interface caches, fix home / public timeline caching + visibility * finish adding code comments for visibility filter * fix angry goconst linter warnings * actually finish adding filter visibility code comments for timeline functions * move home timeline status author check to after visibility * remove now-unused code * add more code comments * add TODO code comment, update printed cache start names * update printed cache names on stop * start adding separate follow(request) delete db functions, add specific visibility cache tests * add relationship type caching * fix getting local account follows / followed-bys, other small codebase improvements * simplify invalidation using cache hooks, add more GetAccountBy___() functions * fix boosting to return 404 if not boostable but no error (to not leak status ID) * remove dead code * improved placement of cache invalidation * update license headers * add example follow, follow-request config entries * add example visibility cache configuration to config file * use specific PutFollowRequest() instead of just Put() * add tests for all GetAccountBy() * add GetBlockBy() tests * update block to check primitive fields * update and finish adding Get{Account,Block,Follow,FollowRequest}By() tests * fix copy-pasted code * update envparsing test * whitespace * fix bun struct tag * add license header to gtscontext * fix old license header * improved error creation to not use fmt.Errorf() when not needed * fix various rebase conflicts, fix account test * remove commented-out code, fix-up mention caching * fix mention select bun statement * ensure mention target account populated, pass in context to customrenderer logging * remove more uncommented code, fix typeutil test * add statusfave database model caching * add status fave cache configuration * add status fave cache example config * woops, catch missed error. nice catch linter! * add back testrig panic on nil db * update example configuration to match defaults, slight tweak to cache configuration defaults * update envparsing test with new defaults * fetch followingget to use the follow target account * use accounnt.IsLocal() instead of empty domain check * use constants for the cache visibility type check * use bun.In() for notification type restriction in db query * include replies when fetching PublicTimeline() (to account for single-author threads in Visibility{}.StatusPublicTimelineable()) * use bun query building for nested select statements to ensure working with postgres * update public timeline future status checks to match visibility filter * same as previous, for home timeline * update public timeline tests to dynamically check for appropriate statuses * migrate accounts to allow unique constraint on public_key * provide minimal account with publicKey --------- Signed-off-by: kim <grufwub@gmail.com> Co-authored-by: tsmethurst <tobi.smethurst@protonmail.com>
Diffstat (limited to 'internal/db/bundb/relationship.go')
-rw-r--r--internal/db/bundb/relationship.go685
1 files changed, 147 insertions, 538 deletions
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go
index 21a29b5dc..82559a213 100644
--- a/internal/db/bundb/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -23,8 +23,8 @@ import (
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/uptrace/bun"
)
@@ -34,603 +34,212 @@ type relationshipDB struct {
state *state.State
}
-func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
- // Look for a block in direction of account1->account2
- block1, err := r.getBlock(ctx, account1, account2)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return false, err
- }
-
- if block1 != nil {
- // account1 blocks account2
- return true, nil
- } else if !eitherDirection {
- // Don't check for mutli-directional
- return false, nil
- }
-
- // Look for a block in direction of account2->account1
- block2, err := r.getBlock(ctx, account2, account1)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return false, err
- }
-
- return (block2 != nil), nil
-}
-
-func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
- // Fetch block from database
- block, err := r.getBlock(ctx, account1, account2)
- if err != nil {
- return nil, err
- }
-
- // Set the block originating account
- block.Account, err = r.state.DB.GetAccountByID(ctx, block.AccountID)
- if err != nil {
- return nil, err
- }
-
- // Set the block target account
- block.TargetAccount, err = r.state.DB.GetAccountByID(ctx, block.TargetAccountID)
- if err != nil {
- return nil, err
- }
-
- return block, nil
-}
-
-func (r *relationshipDB) getBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
- return r.state.Caches.GTS.Block().Load("AccountID.TargetAccountID", func() (*gtsmodel.Block, error) {
- var block gtsmodel.Block
-
- q := r.conn.NewSelect().Model(&block).
- Where("? = ?", bun.Ident("block.account_id"), account1).
- Where("? = ?", bun.Ident("block.target_account_id"), account2)
- if err := q.Scan(ctx); err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- return &block, nil
- }, account1, account2)
-}
-
-func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) db.Error {
- return r.state.Caches.GTS.Block().Store(block, func() error {
- _, err := r.conn.NewInsert().Model(block).Exec(ctx)
- return r.conn.ProcessError(err)
- })
-}
-
-func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) db.Error {
- if _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
- Where("? = ?", bun.Ident("block.id"), id).
- Exec(ctx); err != nil {
- return r.conn.ProcessError(err)
- }
-
- // Drop any old value from cache by this ID
- r.state.Caches.GTS.Block().Invalidate("ID", id)
- return nil
-}
-
-func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) db.Error {
- if _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
- Where("? = ?", bun.Ident("block.uri"), uri).
- Exec(ctx); err != nil {
- return r.conn.ProcessError(err)
- }
-
- // Drop any old value from cache by this URI
- r.state.Caches.GTS.Block().Invalidate("URI", uri)
- return nil
-}
-
-func (r *relationshipDB) DeleteBlocksByOriginAccountID(ctx context.Context, originAccountID string) db.Error {
- blockIDs := []string{}
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
- Column("block.id").
- Where("? = ?", bun.Ident("block.account_id"), originAccountID)
-
- if err := q.Scan(ctx, &blockIDs); err != nil {
- return r.conn.ProcessError(err)
- }
-
- for _, blockID := range blockIDs {
- if err := r.DeleteBlockByID(ctx, blockID); err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (r *relationshipDB) DeleteBlocksByTargetAccountID(ctx context.Context, targetAccountID string) db.Error {
- blockIDs := []string{}
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
- Column("block.id").
- Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID)
-
- if err := q.Scan(ctx, &blockIDs); err != nil {
- return r.conn.ProcessError(err)
- }
-
- for _, blockID := range blockIDs {
- if err := r.DeleteBlockByID(ctx, blockID); err != nil {
- return err
- }
- }
-
- return nil
-}
-
func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
- rel := &gtsmodel.Relationship{
- ID: targetAccount,
+ var rel gtsmodel.Relationship
+ rel.ID = targetAccount
+
+ // check if the requesting follows the target
+ follow, err := r.GetFollow(
+ gtscontext.SetBarebones(ctx),
+ requestingAccount,
+ targetAccount,
+ )
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return nil, fmt.Errorf("GetRelationship: error fetching follow: %w", err)
}
- // check if the requesting account follows the target account
- follow := &gtsmodel.Follow{}
- if err := r.conn.
- NewSelect().
- Model(follow).
- Column("follow.show_reblogs", "follow.notify").
- Where("? = ?", bun.Ident("follow.account_id"), requestingAccount).
- Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount).
- Limit(1).
- Scan(ctx); err != nil {
- if err := r.conn.ProcessError(err); err != db.ErrNoEntries {
- return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err)
- }
- // no follow exists so these are all false
- rel.Following = false
- rel.ShowingReblogs = false
- rel.Notifying = false
- } else {
+ if follow != nil {
// follow exists so we can fill these fields out...
rel.Following = true
rel.ShowingReblogs = *follow.ShowReblogs
rel.Notifying = *follow.Notify
}
- // check if the target account follows the requesting account
- followedByQ := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- Column("follow.id").
- Where("? = ?", bun.Ident("follow.account_id"), targetAccount).
- Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount)
- followedBy, err := r.conn.Exists(ctx, followedByQ)
+ // check if the target follows the requesting
+ rel.FollowedBy, err = r.IsFollowing(ctx,
+ targetAccount,
+ requestingAccount,
+ )
if err != nil {
- return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err)
+ return nil, fmt.Errorf("GetRelationship: error checking followedBy: %w", err)
}
- rel.FollowedBy = followedBy
- // check if there's a pending following request from requesting account to target account
- requestedQ := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id").
- Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount).
- Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount)
- requested, err := r.conn.Exists(ctx, requestedQ)
+ // check if requesting has follow requested target
+ rel.Requested, err = r.IsFollowRequested(ctx,
+ requestingAccount,
+ targetAccount,
+ )
if err != nil {
- return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err)
+ return nil, fmt.Errorf("GetRelationship: error checking requested: %w", err)
}
- rel.Requested = requested
// check if the requesting account is blocking the target account
- blockA2T, err := r.getBlock(ctx, requestingAccount, targetAccount)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err)
- }
- rel.Blocking = (blockA2T != nil)
-
- // check if the requesting account is blocked by the target account
- blockT2A, err := r.getBlock(ctx, targetAccount, requestingAccount)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err)
- }
- rel.BlockedBy = (blockT2A != nil)
-
- return rel, nil
-}
-
-func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
- if sourceAccount == nil || targetAccount == nil {
- return false, nil
- }
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- Column("follow.id").
- Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID).
- Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID)
-
- return r.conn.Exists(ctx, q)
-}
-
-func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
- if sourceAccount == nil || targetAccount == nil {
- return false, nil
- }
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id").
- Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID).
- Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID)
-
- return r.conn.Exists(ctx, q)
-}
-
-func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
- if account1 == nil || account2 == nil {
- return false, nil
- }
-
- // make sure account 1 follows account 2
- f1, err := r.IsFollowing(ctx, account1, account2)
+ rel.Blocking, err = r.IsBlocked(ctx, requestingAccount, targetAccount)
if err != nil {
- return false, err
+ return nil, fmt.Errorf("GetRelationship: error checking blocking: %w", err)
}
- // make sure account 2 follows account 1
- f2, err := r.IsFollowing(ctx, account2, account1)
+ // check if the requesting account is blocked by the target account
+ rel.BlockedBy, err = r.IsBlocked(ctx, targetAccount, requestingAccount)
if err != nil {
- return false, err
+ return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %w", err)
}
- return f1 && f2, nil
+ return &rel, nil
}
-func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
- // Get original follow request.
- var followRequestID string
- if err := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id").
- Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
- Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
- Scan(ctx, &followRequestID); err != nil {
+func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
+ var followIDs []string
+ if err := newSelectFollows(r.conn, accountID).
+ Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
-
- followRequest, err := r.getFollowRequest(ctx, followRequestID)
- if err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- // Create a new follow to 'replace'
- // the original follow request with.
- follow := &gtsmodel.Follow{
- ID: followRequest.ID,
- AccountID: originAccountID,
- Account: followRequest.Account,
- TargetAccountID: targetAccountID,
- TargetAccount: followRequest.TargetAccount,
- URI: followRequest.URI,
- }
-
- // If the follow already exists, just
- // replace the URI with the new one.
- if _, err := r.conn.
- NewInsert().
- Model(follow).
- On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
- Exec(ctx); err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- // Delete original follow request.
- if _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
- Exec(ctx); err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- // Delete original follow request notification.
- if err := r.deleteFollowRequestNotif(ctx, originAccountID, targetAccountID); err != nil {
- return nil, err
- }
-
- // return the new follow
- return follow, nil
+ return r.GetFollowsByIDs(ctx, followIDs)
}
-func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) {
- // Get original follow request.
- var followRequestID string
- if err := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id").
- Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
- Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
- Scan(ctx, &followRequestID); err != nil {
+func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
+ var followIDs []string
+ if err := newSelectLocalFollows(r.conn, accountID).
+ Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
+ return r.GetFollowsByIDs(ctx, followIDs)
+}
- followRequest, err := r.getFollowRequest(ctx, followRequestID)
- if err != nil {
+func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
+ var followIDs []string
+ if err := newSelectFollowers(r.conn, accountID).
+ Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
+ return r.GetFollowsByIDs(ctx, followIDs)
+}
- // Delete original follow request.
- if _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
- Exec(ctx); err != nil {
+func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
+ var followIDs []string
+ if err := newSelectLocalFollowers(r.conn, accountID).
+ Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
-
- // Delete original follow request notification.
- if err := r.deleteFollowRequestNotif(ctx, originAccountID, targetAccountID); err != nil {
- return nil, err
- }
-
- // Return the now deleted follow request.
- return followRequest, nil
+ return r.GetFollowsByIDs(ctx, followIDs)
}
-func (r *relationshipDB) deleteFollowRequestNotif(ctx context.Context, originAccountID string, targetAccountID string) db.Error {
- var id string
- if err := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
- Column("notification.id").
- Where("? = ?", bun.Ident("notification.origin_account_id"), originAccountID).
- Where("? = ?", bun.Ident("notification.target_account_id"), targetAccountID).
- Where("? = ?", bun.Ident("notification.notification_type"), gtsmodel.NotificationFollowRequest).
- Limit(1). // There should only be one!
- Scan(ctx, &id); err != nil {
- err = r.conn.ProcessError(err)
- if errors.Is(err, db.ErrNoEntries) {
- // If no entries, the notif didn't
- // exist anyway so nothing to do here.
- return nil
- }
- // Return on real error.
- return err
- }
-
- return r.state.DB.DeleteNotification(ctx, id)
+func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectFollows(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
}
-func (r *relationshipDB) getFollow(ctx context.Context, id string) (*gtsmodel.Follow, db.Error) {
- follow := &gtsmodel.Follow{}
-
- err := r.conn.
- NewSelect().
- Model(follow).
- Where("? = ?", bun.Ident("follow.id"), id).
- Scan(ctx)
- if err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- follow.Account, err = r.state.DB.GetAccountByID(ctx, follow.AccountID)
- if err != nil {
- log.Errorf(ctx, "error getting follow account %q: %v", follow.AccountID, err)
- }
-
- follow.TargetAccount, err = r.state.DB.GetAccountByID(ctx, follow.TargetAccountID)
- if err != nil {
- log.Errorf(ctx, "error getting follow target account %q: %v", follow.TargetAccountID, err)
- }
-
- return follow, nil
+func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectLocalFollows(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
}
-func (r *relationshipDB) GetLocalFollowersIDs(ctx context.Context, targetAccountID string) ([]string, db.Error) {
- accountIDs := []string{}
-
- // Select only the account ID of each follow.
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- ColumnExpr("? AS ?", bun.Ident("follow.account_id"), bun.Ident("account_id")).
- Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
-
- // Join on accounts table to select only
- // those with NULL domain (local accounts).
- q = q.
- Join("JOIN ? AS ? ON ? = ?",
- bun.Ident("accounts"),
- bun.Ident("account"),
- bun.Ident("follow.account_id"),
- bun.Ident("account.id"),
- ).
- Where("? IS NULL", bun.Ident("account.domain"))
+func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectFollowers(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
+}
- // We don't *really* need to order these,
- // but it makes it more consistent to do so.
- q = q.Order("account_id DESC")
+func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectLocalFollowers(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
+}
- if err := q.Scan(ctx, &accountIDs); err != nil {
+func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
+ var followReqIDs []string
+ if err := newSelectFollowRequests(r.conn, accountID).
+ Scan(ctx, &followReqIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
-
- return accountIDs, nil
+ return r.GetFollowRequestsByIDs(ctx, followReqIDs)
}
-func (r *relationshipDB) GetFollows(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.Follow, db.Error) {
- ids := []string{}
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- Column("follow.id").
- Order("follow.updated_at DESC")
-
- if accountID != "" {
- q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
- }
-
- if targetAccountID != "" {
- q = q.Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
- }
-
- if err := q.Scan(ctx, &ids); err != nil {
+func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
+ var followReqIDs []string
+ if err := newSelectFollowRequesting(r.conn, accountID).
+ Scan(ctx, &followReqIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
-
- follows := make([]*gtsmodel.Follow, 0, len(ids))
- for _, id := range ids {
- follow, err := r.getFollow(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting follow %q: %v", id, err)
- continue
- }
-
- follows = append(follows, follow)
- }
-
- return follows, nil
+ return r.GetFollowRequestsByIDs(ctx, followReqIDs)
}
-func (r *relationshipDB) CountFollows(ctx context.Context, accountID string, targetAccountID string) (int, db.Error) {
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- Column("follow.id")
-
- if accountID != "" {
- q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
- }
-
- if targetAccountID != "" {
- q = q.Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
- }
-
- return q.Count(ctx)
+func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectFollowRequests(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
}
-func (r *relationshipDB) getFollowRequest(ctx context.Context, id string) (*gtsmodel.FollowRequest, db.Error) {
- followRequest := &gtsmodel.FollowRequest{}
-
- err := r.conn.
- NewSelect().
- Model(followRequest).
- Where("? = ?", bun.Ident("follow_request.id"), id).
- Scan(ctx)
- if err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- followRequest.Account, err = r.state.DB.GetAccountByID(ctx, followRequest.AccountID)
- if err != nil {
- log.Errorf(ctx, "error getting follow request account %q: %v", followRequest.AccountID, err)
- }
-
- followRequest.TargetAccount, err = r.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
- if err != nil {
- log.Errorf(ctx, "error getting follow request target account %q: %v", followRequest.TargetAccountID, err)
- }
-
- return followRequest, nil
+func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectFollowRequesting(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
}
-func (r *relationshipDB) GetFollowRequests(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.FollowRequest, db.Error) {
- ids := []string{}
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id")
-
- if accountID != "" {
- q = q.Where("? = ?", bun.Ident("follow_request.account_id"), accountID)
- }
-
- if targetAccountID != "" {
- q = q.Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID)
- }
-
- if err := q.Scan(ctx, &ids); err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- followRequests := make([]*gtsmodel.FollowRequest, 0, len(ids))
- for _, id := range ids {
- followRequest, err := r.getFollowRequest(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting follow request %q: %v", id, err)
- continue
- }
-
- followRequests = append(followRequests, followRequest)
- }
-
- return followRequests, nil
+// newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID.
+func newSelectFollowRequests(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ TableExpr("?", bun.Ident("follow_requests")).
+ ColumnExpr("?", bun.Ident("id")).
+ Where("? = ?", bun.Ident("target_account_id"), accountID).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
}
-func (r *relationshipDB) CountFollowRequests(ctx context.Context, accountID string, targetAccountID string) (int, db.Error) {
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id").
- Order("follow_request.updated_at DESC")
-
- if accountID != "" {
- q = q.Where("? = ?", bun.Ident("follow_request.account_id"), accountID)
- }
-
- if targetAccountID != "" {
- q = q.Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID)
- }
-
- return q.Count(ctx)
+// newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID.
+func newSelectFollowRequesting(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ TableExpr("?", bun.Ident("follow_requests")).
+ ColumnExpr("?", bun.Ident("id")).
+ Where("? = ?", bun.Ident("target_account_id"), accountID).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
}
-func (r *relationshipDB) Unfollow(ctx context.Context, originAccountID string, targetAccountID string) (string, db.Error) {
- uri := new(string)
-
- _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID).
- Where("? = ?", bun.Ident("follow.account_id"), originAccountID).
- Returning("?", bun.Ident("uri")).Exec(ctx, uri)
-
- // Only return proper errors.
- if err = r.conn.ProcessError(err); err != db.ErrNoEntries {
- return *uri, err
- }
-
- return *uri, nil
+// newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID.
+func newSelectFollows(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ Table("follows").
+ Column("id").
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
}
-func (r *relationshipDB) UnfollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (string, db.Error) {
- uri := new(string)
-
- _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
- Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
- Returning("?", bun.Ident("uri")).Exec(ctx, uri)
-
- // Only return proper errors.
- if err = r.conn.ProcessError(err); err != db.ErrNoEntries {
- return *uri, err
- }
-
- return *uri, nil
+// newSelectLocalFollows returns a new select query for all rows in the follows table with
+// account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
+func newSelectLocalFollows(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ Table("follows").
+ Column("id").
+ Where("? = ? AND ? IN (?)",
+ bun.Ident("account_id"),
+ accountID,
+ bun.Ident("target_account_id"),
+ conn.NewSelect().
+ Table("accounts").
+ Column("id").
+ Where("? IS NULL", bun.Ident("domain")),
+ ).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
+}
+
+// newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID.
+func newSelectFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ Table("follows").
+ Column("id").
+ Where("? = ?", bun.Ident("target_account_id"), accountID).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
+}
+
+// newSelectLocalFollowers returns a new select query for all rows in the follows table with
+// target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
+func newSelectLocalFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ Table("follows").
+ Column("id").
+ Where("? = ? AND ? IN (?)",
+ bun.Ident("target_account_id"),
+ accountID,
+ bun.Ident("account_id"),
+ conn.NewSelect().
+ Table("accounts").
+ Column("id").
+ Where("? IS NULL", bun.Ident("domain")),
+ ).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
}