diff options
Diffstat (limited to 'internal/db/bundb')
| -rw-r--r-- | internal/db/bundb/relationship.go | 61 |
1 files changed, 53 insertions, 8 deletions
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index 867282376..687e29f81 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -20,6 +20,7 @@ package bundb import ( "context" "errors" + "slices" "time" "code.superseriousbusiness.org/gotosocial/internal/db" @@ -170,16 +171,24 @@ func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, account return r.GetFollowRequestsByIDs(ctx, followReqIDs) } -func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) { - blockIDs, err := r.GetAccountBlockIDs(ctx, accountID, page) +func (r *relationshipDB) GetAccountBlocking(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) { + blockIDs, err := r.GetAccountBlockingIDs(ctx, accountID, page) if err != nil { return nil, err } return r.GetBlocksByIDs(ctx, blockIDs) } -func (r *relationshipDB) CountAccountBlocks(ctx context.Context, accountID string) (int, error) { - blockIDs, err := r.GetAccountBlockIDs(ctx, accountID, nil) +func (r *relationshipDB) GetAccountBlockedBy(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) { + blockIDs, err := r.GetAccountBlockedByIDs(ctx, accountID, page) + if err != nil { + return nil, err + } + return r.GetBlocksByIDs(ctx, blockIDs) +} + +func (r *relationshipDB) CountAccountBlocking(ctx context.Context, accountID string) (int, error) { + blockIDs, err := r.GetAccountBlockingIDs(ctx, accountID, nil) return len(blockIDs), err } @@ -273,12 +282,12 @@ func (r *relationshipDB) GetAccountFollowRequestingIDs(ctx context.Context, acco }) } -func (r *relationshipDB) GetAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { +func (r *relationshipDB) GetAccountBlockingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { return loadPagedIDs(&r.state.Caches.DB.BlockIDs, accountID, page, func() ([]string, error) { var blockIDs []string // Block IDs not in cache, perform DB query! - q := newSelectBlocks(r.db, accountID) + q := newSelectBlocking(r.db, accountID) if _, err := q.Exec(ctx, &blockIDs); // nocollapse err != nil && !errors.Is(err, db.ErrNoEntries) { return nil, err @@ -288,6 +297,33 @@ func (r *relationshipDB) GetAccountBlockIDs(ctx context.Context, accountID strin }) } +func (r *relationshipDB) GetAccountBlockedByIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { + var blockIDs []string + + // NOTE that we are specifically not using + // any caching here, as this is only called + // when deleting an account, i.e. pointless! + + // Block IDs not in cache, perform DB query! + q := newSelectBlockedBy(r.db, accountID) + if _, err := q.Exec(ctx, &blockIDs); // nocollapse + err != nil && !errors.Is(err, db.ErrNoEntries) { + return nil, err + } + + // Our selected IDs are ALWAYS fetched + // from `loadDESC` in descending order. + // Depending on the paging requested + // this may be an unexpected order. + if page.GetOrder().Ascending() { + slices.Reverse(blockIDs) + } + + // Page the resulting block IDs. + blockIDs = page.Page(blockIDs) + return blockIDs, nil +} + // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID. func newSelectFollowRequests(db *bun.DB, accountID string) *bun.SelectQuery { return db.NewSelect(). @@ -360,11 +396,20 @@ func newSelectLocalFollowers(db *bun.DB, accountID string) *bun.SelectQuery { OrderExpr("? DESC", bun.Ident("created_at")) } -// newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID. -func newSelectBlocks(db *bun.DB, accountID string) *bun.SelectQuery { +// newSelectBlocking returns a new select query for all rows in the blocks table with account_id = accountID. +func newSelectBlocking(db *bun.DB, accountID string) *bun.SelectQuery { return db.NewSelect(). TableExpr("?", bun.Ident("blocks")). ColumnExpr("?", bun.Ident("id")). Where("? = ?", bun.Ident("account_id"), accountID). OrderExpr("? DESC", bun.Ident("id")) } + +// newSelectBlocking returns a new select query for all rows in the blocks table with target_account_id = accountID. +func newSelectBlockedBy(db *bun.DB, accountID string) *bun.SelectQuery { + return db.NewSelect(). + TableExpr("?", bun.Ident("blocks")). + ColumnExpr("?", bun.Ident("id")). + Where("? = ?", bun.Ident("target_account_id"), accountID). + OrderExpr("? DESC", bun.Ident("id")) +} |
