summaryrefslogtreecommitdiff
path: root/internal/db/bundb/account.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/account.go')
-rw-r--r--internal/db/bundb/account.go58
1 files changed, 47 insertions, 11 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index 4b4c78726..e0d574f62 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -532,20 +532,56 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
}
func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, error) {
- return a.db.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
- Where("? = ?", bun.Ident("status.account_id"), accountID).
- Count(ctx)
+ counts, err := a.getAccountStatusCounts(ctx, accountID)
+ return counts.Statuses, err
}
func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, error) {
- return a.db.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
- Where("? = ?", bun.Ident("status.account_id"), accountID).
- Where("? IS NOT NULL", bun.Ident("status.pinned_at")).
- Count(ctx)
+ counts, err := a.getAccountStatusCounts(ctx, accountID)
+ return counts.Pinned, err
+}
+
+func (a *accountDB) getAccountStatusCounts(ctx context.Context, accountID string) (struct {
+ Statuses int
+ Pinned int
+}, error) {
+ // Check for an already cached copy of account status counts.
+ counts, ok := a.state.Caches.GTS.AccountCounts.Get(accountID)
+ if ok {
+ return counts, nil
+ }
+
+ if err := a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ var err error
+
+ // Scan database for account statuses.
+ counts.Statuses, err = tx.NewSelect().
+ Table("statuses").
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Count(ctx)
+ if err != nil {
+ return err
+ }
+
+ // Scan database for pinned statuses.
+ counts.Pinned, err = tx.NewSelect().
+ Table("statuses").
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Where("? IS NOT NULL", bun.Ident("pinned_at")).
+ Count(ctx)
+ if err != nil {
+ return err
+ }
+
+ return nil
+ }); err != nil {
+ return counts, err
+ }
+
+ // Store this account counts result in the cache.
+ a.state.Caches.GTS.AccountCounts.Set(accountID, counts)
+
+ return counts, nil
}
func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error) {