summaryrefslogtreecommitdiff
path: root/internal/db/bundb/relationship_follow_req.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/relationship_follow_req.go')
-rw-r--r--internal/db/bundb/relationship_follow_req.go269
1 files changed, 152 insertions, 117 deletions
diff --git a/internal/db/bundb/relationship_follow_req.go b/internal/db/bundb/relationship_follow_req.go
index 513f5abb0..6a9fbea09 100644
--- a/internal/db/bundb/relationship_follow_req.go
+++ b/internal/db/bundb/relationship_follow_req.go
@@ -167,7 +167,7 @@ func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, db
func (r *relationshipDB) PopulateFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error {
var (
err error
- errs = gtserror.NewMultiError(2)
+ errs gtserror.MultiError
)
if follow.Account == nil {
@@ -196,8 +196,10 @@ func (r *relationshipDB) PopulateFollowRequest(ctx context.Context, follow *gtsm
}
func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error {
- return r.state.Caches.DB.FollowRequest.Store(follow, func() error {
- _, err := r.db.NewInsert().Model(follow).Exec(ctx)
+ return r.insertFollowRequest(ctx, follow, func(tx bun.Tx) error {
+ _, err := tx.NewInsert().
+ Model(follow).
+ Exec(ctx)
return err
})
}
@@ -208,22 +210,17 @@ func (r *relationshipDB) UpdateFollowRequest(ctx context.Context, followRequest
// If we're updating by column, ensure "updated_at" is included.
columns = append(columns, "updated_at")
}
-
return r.state.Caches.DB.FollowRequest.Store(followRequest, func() error {
- if _, err := r.db.NewUpdate().
+ _, err := r.db.NewUpdate().
Model(followRequest).
Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
Column(columns...).
- Exec(ctx); err != nil {
- return err
- }
-
- return nil
+ Exec(ctx)
+ return err
})
}
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
- // Get original follow request.
followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID)
if err != nil {
return nil, err
@@ -242,11 +239,9 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI
Notify: followReq.Notify,
}
- if err := r.state.Caches.DB.Follow.Store(follow, func() error {
- // If the follow already exists, just
- // replace the URI with the new one.
- _, err := r.db.
- NewInsert().
+ // Insert the new follow modelled after request into database.
+ if err := r.insertFollow(ctx, follow, func(tx bun.Tx) error {
+ _, err := tx.NewInsert().
Model(follow).
On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
Exec(ctx)
@@ -255,12 +250,8 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI
return nil, err
}
- // Delete original follow request.
- if _, err := r.db.
- NewDelete().
- Table("follow_requests").
- Where("? = ?", bun.Ident("id"), followReq.ID).
- Exec(ctx); err != nil {
+ // Delete the follow request now that it's accepted and not needed.
+ if err := r.DeleteFollowRequestByID(ctx, followReq.ID); err != nil {
return nil, err
}
@@ -275,12 +266,9 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI
}
func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) error {
- // Delete follow request first.
if err := r.DeleteFollowRequest(ctx, sourceAccountID, targetAccountID); err != nil {
return err
}
-
- // Delete follow request notification
return r.state.DB.DeleteNotifications(ctx, []gtsmodel.NotificationType{
gtsmodel.NotificationFollowRequest,
}, targetAccountID, sourceAccountID)
@@ -291,89 +279,63 @@ func (r *relationshipDB) DeleteFollowRequest(
sourceAccountID string,
targetAccountID string,
) error {
+ return r.deleteFollowRequest(ctx, func(tx bun.Tx) (*gtsmodel.FollowRequest, error) {
+ var deleted gtsmodel.FollowRequest
+ deleted.AccountID = sourceAccountID
+ deleted.TargetAccountID = targetAccountID
+
+ if _, err := tx.NewDelete().
+ Model(&deleted).
+ Where("? = ?", bun.Ident("account_id"), sourceAccountID).
+ Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
+ Returning("?", bun.Ident("id")).
+ Exec(ctx); err != nil {
+ return nil, err
+ }
- // Gather necessary fields from
- // deleted for cache invaliation.
- var deleted gtsmodel.FollowRequest
- deleted.AccountID = sourceAccountID
- deleted.TargetAccountID = targetAccountID
-
- // Delete all follow reqs either
- // from account, or targeting account,
- // returning the deleted models.
- if _, err := r.db.NewDelete().
- Model(&deleted).
- Where("? = ?", bun.Ident("account_id"), sourceAccountID).
- Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
- Returning("?", bun.Ident("id")).
- Exec(ctx); err != nil &&
- !errors.Is(err, db.ErrNoEntries) {
- return err
- }
-
- // Invalidate cached follow with source / target account IDs,
- // manually calling invalidate hook in case it isn't cached.
- r.state.Caches.DB.FollowRequest.Invalidate("AccountID,TargetAccountID",
- sourceAccountID, targetAccountID)
- r.state.Caches.OnInvalidateFollowRequest(&deleted)
-
- return nil
+ return &deleted, nil
+ })
}
func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error {
- // Gather necessary fields from
- // deleted for cache invaliation.
- var deleted gtsmodel.FollowRequest
- deleted.ID = id
-
- // Delete follow with given URI,
- // returning the deleted models.
- if _, err := r.db.NewDelete().
- Model(&deleted).
- Where("? = ?", bun.Ident("id"), id).
- Returning("?, ?",
- bun.Ident("account_id"),
- bun.Ident("target_account_id"),
- ).
- Exec(ctx); err != nil &&
- !errors.Is(err, db.ErrNoEntries) {
- return err
- }
-
- // Invalidate cached follow with URI, manually
- // call invalidate hook in case not cached.
- r.state.Caches.DB.FollowRequest.Invalidate("ID", id)
- r.state.Caches.OnInvalidateFollowRequest(&deleted)
+ return r.deleteFollowRequest(ctx, func(tx bun.Tx) (*gtsmodel.FollowRequest, error) {
+ var deleted gtsmodel.FollowRequest
+ deleted.ID = id
+
+ if _, err := tx.NewDelete().
+ Model(&deleted).
+ Where("? = ?", bun.Ident("id"), id).
+ Returning("?, ?",
+ bun.Ident("account_id"),
+ bun.Ident("target_account_id"),
+ ).
+ Exec(ctx); err != nil {
+ return nil, err
+ }
- return nil
+ return &deleted, nil
+ })
}
func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error {
- // Gather necessary fields from
- // deleted for cache invaliation.
- var deleted gtsmodel.FollowRequest
-
- // Delete follow with given URI,
- // returning the deleted models.
- if _, err := r.db.NewDelete().
- Model(&deleted).
- Where("? = ?", bun.Ident("uri"), uri).
- Returning("?, ?, ?",
- bun.Ident("id"),
- bun.Ident("account_id"),
- bun.Ident("target_account_id"),
- ).
- Exec(ctx); err != nil &&
- !errors.Is(err, db.ErrNoEntries) {
- return err
- }
-
- // Invalidate cached follow with URI, manually
- // call invalidate hook in case not cached.
- r.state.Caches.DB.FollowRequest.Invalidate("URI", uri)
- r.state.Caches.OnInvalidateFollowRequest(&deleted)
+ return r.deleteFollowRequest(ctx, func(tx bun.Tx) (*gtsmodel.FollowRequest, error) {
+ var deleted gtsmodel.FollowRequest
+ deleted.URI = uri
+
+ if _, err := tx.NewDelete().
+ Model(&deleted).
+ Where("? = ?", bun.Ident("uri"), uri).
+ Returning("?, ?, ?",
+ bun.Ident("id"),
+ bun.Ident("account_id"),
+ bun.Ident("target_account_id"),
+ ).
+ Exec(ctx); err != nil {
+ return nil, err
+ }
- return nil
+ return &deleted, nil
+ })
}
func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accountID string) error {
@@ -381,24 +343,44 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun
// deleted for cache invaliation.
var deleted []*gtsmodel.FollowRequest
- // Delete all follows either from
- // account, or targeting account,
- // returning the deleted models.
- if _, err := r.db.NewDelete().
- Model(&deleted).
- WhereOr("? = ? OR ? = ?",
- bun.Ident("account_id"),
- accountID,
- bun.Ident("target_account_id"),
- accountID,
- ).
- Returning("?, ?, ?",
- bun.Ident("id"),
- bun.Ident("account_id"),
- bun.Ident("target_account_id"),
- ).
- Exec(ctx); err != nil &&
- !errors.Is(err, db.ErrNoEntries) {
+ if err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ // Delete all follows either from
+ // account, or targeting account,
+ // returning the deleted models.
+ if _, err := tx.NewDelete().
+ Model(&deleted).
+ WhereOr("? = ? OR ? = ?",
+ bun.Ident("account_id"),
+ accountID,
+ bun.Ident("target_account_id"),
+ accountID,
+ ).
+ Returning("?, ?, ?",
+ bun.Ident("id"),
+ bun.Ident("account_id"),
+ bun.Ident("target_account_id"),
+ ).
+ Exec(ctx); err != nil {
+
+ // the RETURNING here will cause an ErrNoRows
+ // to be returned on DELETE, which is caught
+ // outside this RunInTx() func, and ensures we
+ // return early here to *not* update statistics.
+ return err
+ }
+
+ for _, follow := range deleted {
+ // Decrement target follow requests count.
+ if err := decrementAccountStats(ctx, tx,
+ "follow_requests_count",
+ follow.TargetAccountID,
+ ); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ }); err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
@@ -414,3 +396,56 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun
return nil
}
+
+func (r *relationshipDB) insertFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest, insert func(bun.Tx) error) error {
+ return r.state.Caches.DB.FollowRequest.Store(follow, func() error {
+ return r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ // Perform the insert operation.
+ if err := insert(tx); err != nil {
+ return gtserror.Newf("error inserting follow request: %w", err)
+ }
+
+ // Increment target follow requests count.
+ return incrementAccountStats(ctx, tx,
+ "follow_requests_count",
+ follow.TargetAccountID,
+ )
+ })
+ })
+}
+
+func (r *relationshipDB) deleteFollowRequest(ctx context.Context, delete func(bun.Tx) (*gtsmodel.FollowRequest, error)) error {
+ var follow *gtsmodel.FollowRequest
+
+ if err := r.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) (err error) {
+ // Perform delete operation.
+ follow, err = delete(tx)
+ if err != nil {
+
+ // the RETURNING here will cause an ErrNoRows
+ // to be returned on DELETE, which is caught
+ // outside this RunInTx() func, and ensures we
+ // return early here to *not* update statistics.
+ return err
+ }
+
+ // Decrement target follow requests count.
+ return decrementAccountStats(ctx, tx,
+ "follow_requests_count",
+ follow.TargetAccountID,
+ )
+ }); err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return err
+ }
+
+ if follow == nil {
+ return nil
+ }
+
+ // Invalidate cached follow with ID, manually
+ // call invalidate hook in case not cached.
+ r.state.Caches.DB.FollowRequest.Invalidate("ID", follow.ID)
+ r.state.Caches.OnInvalidateFollowRequest(follow)
+
+ return nil
+}