summaryrefslogtreecommitdiff
path: root/internal/db/bundb/relationship.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/relationship.go')
-rw-r--r--internal/db/bundb/relationship.go280
1 files changed, 158 insertions, 122 deletions
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go
index ba72a053a..66e48e441 100644
--- a/internal/db/bundb/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -51,26 +51,25 @@ func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery {
func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
q := r.conn.
NewSelect().
- Model(&gtsmodel.Block{}).
- ExcludeColumn("id", "created_at", "updated_at", "uri").
- Limit(1)
+ TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
+ Column("block.id")
if eitherDirection {
q = q.
WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {
return inner.
- Where("account_id = ?", account1).
- Where("target_account_id = ?", account2)
+ Where("? = ?", bun.Ident("block.account_id"), account1).
+ Where("? = ?", bun.Ident("block.target_account_id"), account2)
}).
WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {
return inner.
- Where("account_id = ?", account2).
- Where("target_account_id = ?", account1)
+ Where("? = ?", bun.Ident("block.account_id"), account2).
+ Where("? = ?", bun.Ident("block.target_account_id"), account1)
})
} else {
q = q.
- Where("account_id = ?", account1).
- Where("target_account_id = ?", account2)
+ Where("? = ?", bun.Ident("block.account_id"), account1).
+ Where("? = ?", bun.Ident("block.target_account_id"), account2)
}
return r.conn.Exists(ctx, q)
@@ -80,8 +79,8 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2
block := &gtsmodel.Block{}
q := r.newBlockQ(block).
- Where("block.account_id = ?", account1).
- Where("block.target_account_id = ?", account2)
+ Where("? = ?", bun.Ident("block.account_id"), account1).
+ Where("? = ?", bun.Ident("block.target_account_id"), account2)
if err := q.Scan(ctx); err != nil {
return nil, r.conn.ProcessError(err)
@@ -99,13 +98,13 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
if err := r.conn.
NewSelect().
Model(follow).
- Where("account_id = ?", requestingAccount).
- Where("target_account_id = ?", targetAccount).
+ Column("follow.show_reblogs", "follow.notify").
+ Where("? = ?", bun.Ident("follow.account_id"), requestingAccount).
+ Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount).
Limit(1).
Scan(ctx); err != nil {
- if err != sql.ErrNoRows {
- // a proper error
- return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
+ if err := r.conn.ProcessError(err); err != db.ErrNoEntries {
+ return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err)
}
// no follow exists so these are all false
rel.Following = false
@@ -119,55 +118,56 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
}
// check if the target account follows the requesting account
- count, err := r.conn.
+ followedByQ := r.conn.
NewSelect().
- Model(&gtsmodel.Follow{}).
- Where("account_id = ?", targetAccount).
- Where("target_account_id = ?", requestingAccount).
- Limit(1).
- Count(ctx)
+ TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
+ Column("follow.id").
+ Where("? = ?", bun.Ident("follow.account_id"), targetAccount).
+ Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount)
+ followedBy, err := r.conn.Exists(ctx, followedByQ)
if err != nil {
- return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
+ return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err)
}
- rel.FollowedBy = count > 0
+ rel.FollowedBy = followedBy
- // check if the requesting account blocks the target account
- count, err = r.conn.NewSelect().
- Model(&gtsmodel.Block{}).
- Where("account_id = ?", requestingAccount).
- Where("target_account_id = ?", targetAccount).
- Limit(1).
- Count(ctx)
+ // check if there's a pending following request from requesting account to target account
+ requestedQ := r.conn.
+ NewSelect().
+ TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
+ Column("follow_request.id").
+ Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount).
+ Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount)
+ requested, err := r.conn.Exists(ctx, requestedQ)
if err != nil {
- return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
+ return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err)
}
- rel.Blocking = count > 0
+ rel.Requested = requested
- // check if the target account blocks the requesting account
- count, err = r.conn.
+ // check if the requesting account is blocking the target account
+ blockingQ := r.conn.
NewSelect().
- Model(&gtsmodel.Block{}).
- Where("account_id = ?", targetAccount).
- Where("target_account_id = ?", requestingAccount).
- Limit(1).
- Count(ctx)
+ TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
+ Column("block.id").
+ Where("? = ?", bun.Ident("block.account_id"), requestingAccount).
+ Where("? = ?", bun.Ident("block.target_account_id"), targetAccount)
+ blocking, err := r.conn.Exists(ctx, blockingQ)
if err != nil {
- return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
+ return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err)
}
- rel.BlockedBy = count > 0
+ rel.Blocking = blocking
- // check if there's a pending following request from requesting account to target account
- count, err = r.conn.
+ // check if the requesting account is blocked by the target account
+ blockedByQ := r.conn.
NewSelect().
- Model(&gtsmodel.FollowRequest{}).
- Where("account_id = ?", requestingAccount).
- Where("target_account_id = ?", targetAccount).
- Limit(1).
- Count(ctx)
+ TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
+ Column("block.id").
+ Where("? = ?", bun.Ident("block.account_id"), targetAccount).
+ Where("? = ?", bun.Ident("block.target_account_id"), requestingAccount)
+ blockedBy, err := r.conn.Exists(ctx, blockedByQ)
if err != nil {
- return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
+ return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err)
}
- rel.Requested = count > 0
+ rel.BlockedBy = blockedBy
return rel, nil
}
@@ -179,10 +179,10 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode
q := r.conn.
NewSelect().
- Model(&gtsmodel.Follow{}).
- Where("account_id = ?", sourceAccount.ID).
- Where("target_account_id = ?", targetAccount.ID).
- Limit(1)
+ TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
+ Column("follow.id").
+ Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID).
+ Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID)
return r.conn.Exists(ctx, q)
}
@@ -194,9 +194,10 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g
q := r.conn.
NewSelect().
- Model(&gtsmodel.FollowRequest{}).
- Where("account_id = ?", sourceAccount.ID).
- Where("target_account_id = ?", targetAccount.ID)
+ TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
+ Column("follow_request.id").
+ Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID).
+ Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID)
return r.conn.Exists(ctx, q)
}
@@ -222,82 +223,98 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod
}
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
- // make sure the original follow request exists
- fr := &gtsmodel.FollowRequest{}
- if err := r.conn.
- NewSelect().
- Model(fr).
- Where("account_id = ?", originAccountID).
- Where("target_account_id = ?", targetAccountID).
- Scan(ctx); err != nil {
- return nil, r.conn.ProcessError(err)
- }
+ var follow *gtsmodel.Follow
+
+ if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ // get original follow request
+ followRequest := &gtsmodel.FollowRequest{}
+ if err := tx.
+ NewSelect().
+ Model(followRequest).
+ Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
+ Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
+ Scan(ctx); err != nil {
+ return err
+ }
- // create a new follow to 'replace' the request with
- follow := &gtsmodel.Follow{
- ID: fr.ID,
- AccountID: originAccountID,
- TargetAccountID: targetAccountID,
- URI: fr.URI,
- }
+ // create a new follow to 'replace' the request with
+ follow = &gtsmodel.Follow{
+ ID: followRequest.ID,
+ AccountID: originAccountID,
+ TargetAccountID: targetAccountID,
+ URI: followRequest.URI,
+ }
- // if the follow already exists, just update the URI -- we don't need to do anything else
- if _, err := r.conn.
- NewInsert().
- Model(follow).
- On("CONFLICT (account_id,target_account_id) DO UPDATE set uri = ?", follow.URI).
- Exec(ctx); err != nil {
- return nil, r.conn.ProcessError(err)
- }
+ // if the follow already exists, just update the URI -- we don't need to do anything else
+ if _, err := tx.
+ NewInsert().
+ Model(follow).
+ On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // now remove the follow request
+ if _, err := tx.
+ NewDelete().
+ TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
+ Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
+ Exec(ctx); err != nil {
+ return err
+ }
- // now remove the follow request
- if _, err := r.conn.
- NewDelete().
- Model(&gtsmodel.FollowRequest{}).
- Where("account_id = ?", originAccountID).
- Where("target_account_id = ?", targetAccountID).
- Exec(ctx); err != nil {
+ return nil
+ }); err != nil {
return nil, r.conn.ProcessError(err)
}
+ // return the new follow
return follow, nil
}
func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) {
- // first get the follow request out of the database
- fr := &gtsmodel.FollowRequest{}
- if err := r.conn.
- NewSelect().
- Model(fr).
- Where("account_id = ?", originAccountID).
- Where("target_account_id = ?", targetAccountID).
- Scan(ctx); err != nil {
- return nil, r.conn.ProcessError(err)
- }
+ followRequest := &gtsmodel.FollowRequest{}
+
+ if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ // get original follow request
+ if err := tx.
+ NewSelect().
+ Model(followRequest).
+ Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
+ Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
+ Scan(ctx); err != nil {
+ return err
+ }
- // now delete it from the database by ID
- if _, err := r.conn.
- NewDelete().
- Model(&gtsmodel.FollowRequest{ID: fr.ID}).
- WherePK().
- Exec(ctx); err != nil {
+ // now delete it from the database by ID
+ if _, err := tx.
+ NewDelete().
+ TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
+ Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ return nil
+ }); err != nil {
return nil, r.conn.ProcessError(err)
}
// return the deleted follow request
- return fr, nil
+ return followRequest, nil
}
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
followRequests := []*gtsmodel.FollowRequest{}
q := r.newFollowQ(&followRequests).
- Where("target_account_id = ?", accountID).
+ Where("? = ?", bun.Ident("follow_request.target_account_id"), accountID).
Order("follow_request.updated_at DESC")
if err := q.Scan(ctx); err != nil {
return nil, r.conn.ProcessError(err)
}
+
return followRequests, nil
}
@@ -305,21 +322,31 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string
follows := []*gtsmodel.Follow{}
q := r.newFollowQ(&follows).
- Where("account_id = ?", accountID).
+ Where("? = ?", bun.Ident("follow.account_id"), accountID).
Order("follow.updated_at DESC")
if err := q.Scan(ctx); err != nil {
return nil, r.conn.ProcessError(err)
}
+
return follows, nil
}
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
- return r.conn.
+ q := r.conn.
NewSelect().
- Model(&[]*gtsmodel.Follow{}).
- Where("account_id = ?", accountID).
- Count(ctx)
+ TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow"))
+
+ if localOnly {
+ q = q.
+ Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.target_account_id"), bun.Ident("account.id")).
+ Where("? = ?", bun.Ident("follow.account_id"), accountID).
+ Where("? IS NULL", bun.Ident("account.domain"))
+ } else {
+ q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
+ }
+
+ return q.Count(ctx)
}
func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
@@ -331,12 +358,12 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
Order("follow.updated_at DESC")
if localOnly {
- q = q.ColumnExpr("follow.*").
- Join("JOIN accounts AS a ON follow.account_id = CAST(a.id as TEXT)").
- Where("follow.target_account_id = ?", accountID).
- WhereGroup(" AND ", whereEmptyOrNull("a.domain"))
+ q = q.
+ Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")).
+ Where("? = ?", bun.Ident("follow.target_account_id"), accountID).
+ Where("? IS NULL", bun.Ident("account.domain"))
} else {
- q = q.Where("target_account_id = ?", accountID)
+ q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID)
}
err := q.Scan(ctx)
@@ -347,9 +374,18 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
}
func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
- return r.conn.
+ q := r.conn.
NewSelect().
- Model(&[]*gtsmodel.Follow{}).
- Where("target_account_id = ?", accountID).
- Count(ctx)
+ TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow"))
+
+ if localOnly {
+ q = q.
+ Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")).
+ Where("? = ?", bun.Ident("follow.target_account_id"), accountID).
+ Where("? IS NULL", bun.Ident("account.domain"))
+ } else {
+ q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID)
+ }
+
+ return q.Count(ctx)
}