diff options
Diffstat (limited to 'internal/db/bundb/relationship.go')
-rw-r--r-- | internal/db/bundb/relationship.go | 280 |
1 files changed, 158 insertions, 122 deletions
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index ba72a053a..66e48e441 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -51,26 +51,25 @@ func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery { func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) { q := r.conn. NewSelect(). - Model(>smodel.Block{}). - ExcludeColumn("id", "created_at", "updated_at", "uri"). - Limit(1) + TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). + Column("block.id") if eitherDirection { q = q. WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery { return inner. - Where("account_id = ?", account1). - Where("target_account_id = ?", account2) + Where("? = ?", bun.Ident("block.account_id"), account1). + Where("? = ?", bun.Ident("block.target_account_id"), account2) }). WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery { return inner. - Where("account_id = ?", account2). - Where("target_account_id = ?", account1) + Where("? = ?", bun.Ident("block.account_id"), account2). + Where("? = ?", bun.Ident("block.target_account_id"), account1) }) } else { q = q. - Where("account_id = ?", account1). - Where("target_account_id = ?", account2) + Where("? = ?", bun.Ident("block.account_id"), account1). + Where("? = ?", bun.Ident("block.target_account_id"), account2) } return r.conn.Exists(ctx, q) @@ -80,8 +79,8 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 block := >smodel.Block{} q := r.newBlockQ(block). - Where("block.account_id = ?", account1). - Where("block.target_account_id = ?", account2) + Where("? = ?", bun.Ident("block.account_id"), account1). + Where("? = ?", bun.Ident("block.target_account_id"), account2) if err := q.Scan(ctx); err != nil { return nil, r.conn.ProcessError(err) @@ -99,13 +98,13 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount if err := r.conn. NewSelect(). Model(follow). - Where("account_id = ?", requestingAccount). - Where("target_account_id = ?", targetAccount). + Column("follow.show_reblogs", "follow.notify"). + Where("? = ?", bun.Ident("follow.account_id"), requestingAccount). + Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount). Limit(1). Scan(ctx); err != nil { - if err != sql.ErrNoRows { - // a proper error - return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err) + if err := r.conn.ProcessError(err); err != db.ErrNoEntries { + return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err) } // no follow exists so these are all false rel.Following = false @@ -119,55 +118,56 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount } // check if the target account follows the requesting account - count, err := r.conn. + followedByQ := r.conn. NewSelect(). - Model(>smodel.Follow{}). - Where("account_id = ?", targetAccount). - Where("target_account_id = ?", requestingAccount). - Limit(1). - Count(ctx) + TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). + Column("follow.id"). + Where("? = ?", bun.Ident("follow.account_id"), targetAccount). + Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount) + followedBy, err := r.conn.Exists(ctx, followedByQ) if err != nil { - return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err) + return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err) } - rel.FollowedBy = count > 0 + rel.FollowedBy = followedBy - // check if the requesting account blocks the target account - count, err = r.conn.NewSelect(). - Model(>smodel.Block{}). - Where("account_id = ?", requestingAccount). - Where("target_account_id = ?", targetAccount). - Limit(1). - Count(ctx) + // check if there's a pending following request from requesting account to target account + requestedQ := r.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). + Column("follow_request.id"). + Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount). + Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount) + requested, err := r.conn.Exists(ctx, requestedQ) if err != nil { - return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err) + return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err) } - rel.Blocking = count > 0 + rel.Requested = requested - // check if the target account blocks the requesting account - count, err = r.conn. + // check if the requesting account is blocking the target account + blockingQ := r.conn. NewSelect(). - Model(>smodel.Block{}). - Where("account_id = ?", targetAccount). - Where("target_account_id = ?", requestingAccount). - Limit(1). - Count(ctx) + TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). + Column("block.id"). + Where("? = ?", bun.Ident("block.account_id"), requestingAccount). + Where("? = ?", bun.Ident("block.target_account_id"), targetAccount) + blocking, err := r.conn.Exists(ctx, blockingQ) if err != nil { - return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) + return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err) } - rel.BlockedBy = count > 0 + rel.Blocking = blocking - // check if there's a pending following request from requesting account to target account - count, err = r.conn. + // check if the requesting account is blocked by the target account + blockedByQ := r.conn. NewSelect(). - Model(>smodel.FollowRequest{}). - Where("account_id = ?", requestingAccount). - Where("target_account_id = ?", targetAccount). - Limit(1). - Count(ctx) + TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). + Column("block.id"). + Where("? = ?", bun.Ident("block.account_id"), targetAccount). + Where("? = ?", bun.Ident("block.target_account_id"), requestingAccount) + blockedBy, err := r.conn.Exists(ctx, blockedByQ) if err != nil { - return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) + return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err) } - rel.Requested = count > 0 + rel.BlockedBy = blockedBy return rel, nil } @@ -179,10 +179,10 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode q := r.conn. NewSelect(). - Model(>smodel.Follow{}). - Where("account_id = ?", sourceAccount.ID). - Where("target_account_id = ?", targetAccount.ID). - Limit(1) + TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). + Column("follow.id"). + Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID). + Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID) return r.conn.Exists(ctx, q) } @@ -194,9 +194,10 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g q := r.conn. NewSelect(). - Model(>smodel.FollowRequest{}). - Where("account_id = ?", sourceAccount.ID). - Where("target_account_id = ?", targetAccount.ID) + TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). + Column("follow_request.id"). + Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID). + Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID) return r.conn.Exists(ctx, q) } @@ -222,82 +223,98 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod } func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { - // make sure the original follow request exists - fr := >smodel.FollowRequest{} - if err := r.conn. - NewSelect(). - Model(fr). - Where("account_id = ?", originAccountID). - Where("target_account_id = ?", targetAccountID). - Scan(ctx); err != nil { - return nil, r.conn.ProcessError(err) - } + var follow *gtsmodel.Follow + + if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error { + // get original follow request + followRequest := >smodel.FollowRequest{} + if err := tx. + NewSelect(). + Model(followRequest). + Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID). + Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID). + Scan(ctx); err != nil { + return err + } - // create a new follow to 'replace' the request with - follow := >smodel.Follow{ - ID: fr.ID, - AccountID: originAccountID, - TargetAccountID: targetAccountID, - URI: fr.URI, - } + // create a new follow to 'replace' the request with + follow = >smodel.Follow{ + ID: followRequest.ID, + AccountID: originAccountID, + TargetAccountID: targetAccountID, + URI: followRequest.URI, + } - // if the follow already exists, just update the URI -- we don't need to do anything else - if _, err := r.conn. - NewInsert(). - Model(follow). - On("CONFLICT (account_id,target_account_id) DO UPDATE set uri = ?", follow.URI). - Exec(ctx); err != nil { - return nil, r.conn.ProcessError(err) - } + // if the follow already exists, just update the URI -- we don't need to do anything else + if _, err := tx. + NewInsert(). + Model(follow). + On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI). + Exec(ctx); err != nil { + return err + } + + // now remove the follow request + if _, err := tx. + NewDelete(). + TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). + Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). + Exec(ctx); err != nil { + return err + } - // now remove the follow request - if _, err := r.conn. - NewDelete(). - Model(>smodel.FollowRequest{}). - Where("account_id = ?", originAccountID). - Where("target_account_id = ?", targetAccountID). - Exec(ctx); err != nil { + return nil + }); err != nil { return nil, r.conn.ProcessError(err) } + // return the new follow return follow, nil } func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) { - // first get the follow request out of the database - fr := >smodel.FollowRequest{} - if err := r.conn. - NewSelect(). - Model(fr). - Where("account_id = ?", originAccountID). - Where("target_account_id = ?", targetAccountID). - Scan(ctx); err != nil { - return nil, r.conn.ProcessError(err) - } + followRequest := >smodel.FollowRequest{} + + if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error { + // get original follow request + if err := tx. + NewSelect(). + Model(followRequest). + Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID). + Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID). + Scan(ctx); err != nil { + return err + } - // now delete it from the database by ID - if _, err := r.conn. - NewDelete(). - Model(>smodel.FollowRequest{ID: fr.ID}). - WherePK(). - Exec(ctx); err != nil { + // now delete it from the database by ID + if _, err := tx. + NewDelete(). + TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). + Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). + Exec(ctx); err != nil { + return err + } + + return nil + }); err != nil { return nil, r.conn.ProcessError(err) } // return the deleted follow request - return fr, nil + return followRequest, nil } func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) { followRequests := []*gtsmodel.FollowRequest{} q := r.newFollowQ(&followRequests). - Where("target_account_id = ?", accountID). + Where("? = ?", bun.Ident("follow_request.target_account_id"), accountID). Order("follow_request.updated_at DESC") if err := q.Scan(ctx); err != nil { return nil, r.conn.ProcessError(err) } + return followRequests, nil } @@ -305,21 +322,31 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string follows := []*gtsmodel.Follow{} q := r.newFollowQ(&follows). - Where("account_id = ?", accountID). + Where("? = ?", bun.Ident("follow.account_id"), accountID). Order("follow.updated_at DESC") if err := q.Scan(ctx); err != nil { return nil, r.conn.ProcessError(err) } + return follows, nil } func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { - return r.conn. + q := r.conn. NewSelect(). - Model(&[]*gtsmodel.Follow{}). - Where("account_id = ?", accountID). - Count(ctx) + TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")) + + if localOnly { + q = q. + Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.target_account_id"), bun.Ident("account.id")). + Where("? = ?", bun.Ident("follow.account_id"), accountID). + Where("? IS NULL", bun.Ident("account.domain")) + } else { + q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID) + } + + return q.Count(ctx) } func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { @@ -331,12 +358,12 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str Order("follow.updated_at DESC") if localOnly { - q = q.ColumnExpr("follow.*"). - Join("JOIN accounts AS a ON follow.account_id = CAST(a.id as TEXT)"). - Where("follow.target_account_id = ?", accountID). - WhereGroup(" AND ", whereEmptyOrNull("a.domain")) + q = q. + Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")). + Where("? = ?", bun.Ident("follow.target_account_id"), accountID). + Where("? IS NULL", bun.Ident("account.domain")) } else { - q = q.Where("target_account_id = ?", accountID) + q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID) } err := q.Scan(ctx) @@ -347,9 +374,18 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str } func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { - return r.conn. + q := r.conn. NewSelect(). - Model(&[]*gtsmodel.Follow{}). - Where("target_account_id = ?", accountID). - Count(ctx) + TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")) + + if localOnly { + q = q. + Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")). + Where("? = ?", bun.Ident("follow.target_account_id"), accountID). + Where("? IS NULL", bun.Ident("account.domain")) + } else { + q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID) + } + + return q.Count(ctx) } |