summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/bundb/relationship.go61
-rw-r--r--internal/db/relationship.go18
2 files changed, 65 insertions, 14 deletions
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go
index 867282376..687e29f81 100644
--- a/internal/db/bundb/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -20,6 +20,7 @@ package bundb
import (
"context"
"errors"
+ "slices"
"time"
"code.superseriousbusiness.org/gotosocial/internal/db"
@@ -170,16 +171,24 @@ func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, account
return r.GetFollowRequestsByIDs(ctx, followReqIDs)
}
-func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) {
- blockIDs, err := r.GetAccountBlockIDs(ctx, accountID, page)
+func (r *relationshipDB) GetAccountBlocking(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) {
+ blockIDs, err := r.GetAccountBlockingIDs(ctx, accountID, page)
if err != nil {
return nil, err
}
return r.GetBlocksByIDs(ctx, blockIDs)
}
-func (r *relationshipDB) CountAccountBlocks(ctx context.Context, accountID string) (int, error) {
- blockIDs, err := r.GetAccountBlockIDs(ctx, accountID, nil)
+func (r *relationshipDB) GetAccountBlockedBy(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) {
+ blockIDs, err := r.GetAccountBlockedByIDs(ctx, accountID, page)
+ if err != nil {
+ return nil, err
+ }
+ return r.GetBlocksByIDs(ctx, blockIDs)
+}
+
+func (r *relationshipDB) CountAccountBlocking(ctx context.Context, accountID string) (int, error) {
+ blockIDs, err := r.GetAccountBlockingIDs(ctx, accountID, nil)
return len(blockIDs), err
}
@@ -273,12 +282,12 @@ func (r *relationshipDB) GetAccountFollowRequestingIDs(ctx context.Context, acco
})
}
-func (r *relationshipDB) GetAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
+func (r *relationshipDB) GetAccountBlockingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(&r.state.Caches.DB.BlockIDs, accountID, page, func() ([]string, error) {
var blockIDs []string
// Block IDs not in cache, perform DB query!
- q := newSelectBlocks(r.db, accountID)
+ q := newSelectBlocking(r.db, accountID)
if _, err := q.Exec(ctx, &blockIDs); // nocollapse
err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
@@ -288,6 +297,33 @@ func (r *relationshipDB) GetAccountBlockIDs(ctx context.Context, accountID strin
})
}
+func (r *relationshipDB) GetAccountBlockedByIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
+ var blockIDs []string
+
+ // NOTE that we are specifically not using
+ // any caching here, as this is only called
+ // when deleting an account, i.e. pointless!
+
+ // Block IDs not in cache, perform DB query!
+ q := newSelectBlockedBy(r.db, accountID)
+ if _, err := q.Exec(ctx, &blockIDs); // nocollapse
+ err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return nil, err
+ }
+
+ // Our selected IDs are ALWAYS fetched
+ // from `loadDESC` in descending order.
+ // Depending on the paging requested
+ // this may be an unexpected order.
+ if page.GetOrder().Ascending() {
+ slices.Reverse(blockIDs)
+ }
+
+ // Page the resulting block IDs.
+ blockIDs = page.Page(blockIDs)
+ return blockIDs, nil
+}
+
// newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID.
func newSelectFollowRequests(db *bun.DB, accountID string) *bun.SelectQuery {
return db.NewSelect().
@@ -360,11 +396,20 @@ func newSelectLocalFollowers(db *bun.DB, accountID string) *bun.SelectQuery {
OrderExpr("? DESC", bun.Ident("created_at"))
}
-// newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID.
-func newSelectBlocks(db *bun.DB, accountID string) *bun.SelectQuery {
+// newSelectBlocking returns a new select query for all rows in the blocks table with account_id = accountID.
+func newSelectBlocking(db *bun.DB, accountID string) *bun.SelectQuery {
return db.NewSelect().
TableExpr("?", bun.Ident("blocks")).
ColumnExpr("?", bun.Ident("id")).
Where("? = ?", bun.Ident("account_id"), accountID).
OrderExpr("? DESC", bun.Ident("id"))
}
+
+// newSelectBlocking returns a new select query for all rows in the blocks table with target_account_id = accountID.
+func newSelectBlockedBy(db *bun.DB, accountID string) *bun.SelectQuery {
+ return db.NewSelect().
+ TableExpr("?", bun.Ident("blocks")).
+ ColumnExpr("?", bun.Ident("id")).
+ Where("? = ?", bun.Ident("target_account_id"), accountID).
+ OrderExpr("? DESC", bun.Ident("id"))
+}
diff --git a/internal/db/relationship.go b/internal/db/relationship.go
index b63e911e6..c00b8f233 100644
--- a/internal/db/relationship.go
+++ b/internal/db/relationship.go
@@ -176,14 +176,20 @@ type Relationship interface {
// GetAccountFollowRequestingIDs is like GetAccountFollowRequesting, but returns just IDs.
GetAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error)
- // GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters.
- GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Page) ([]*gtsmodel.Block, error)
+ // GetAccountBlocking returns all blocks originating from the given account, with given optional paging parameters.
+ GetAccountBlocking(ctx context.Context, accountID string, paging *paging.Page) ([]*gtsmodel.Block, error)
- // GetAccountBlockIDs is like GetAccountBlocks, but returns just IDs.
- GetAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error)
+ // GetAccountBlockingIDs is like GetAccountBlocking, but returns just IDs.
+ GetAccountBlockingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error)
- // CountAccountBlocks counts the number of blocks owned by the given account.
- CountAccountBlocks(ctx context.Context, accountID string) (int, error)
+ // GetAccountBlockedBy returns all blocks targeting the given account, with optional paging parameters.
+ GetAccountBlockedBy(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error)
+
+ // GetAccountBlockedByIDs is like GetAccountBlockedBy, but returns just IDs.
+ GetAccountBlockedByIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error)
+
+ // CountAccountBlocking counts the number of blocks owned by the given account.
+ CountAccountBlocking(ctx context.Context, accountID string) (int, error)
// GetNote gets a private note from a source account on a target account, if it exists.
GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error)