diff options
Diffstat (limited to 'internal/db/bundb/account.go')
-rw-r--r-- | internal/db/bundb/account.go | 58 |
1 files changed, 47 insertions, 11 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 4b4c78726..e0d574f62 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -532,20 +532,56 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g } func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, error) { - return a.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). - Where("? = ?", bun.Ident("status.account_id"), accountID). - Count(ctx) + counts, err := a.getAccountStatusCounts(ctx, accountID) + return counts.Statuses, err } func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, error) { - return a.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). - Where("? = ?", bun.Ident("status.account_id"), accountID). - Where("? IS NOT NULL", bun.Ident("status.pinned_at")). - Count(ctx) + counts, err := a.getAccountStatusCounts(ctx, accountID) + return counts.Pinned, err +} + +func (a *accountDB) getAccountStatusCounts(ctx context.Context, accountID string) (struct { + Statuses int + Pinned int +}, error) { + // Check for an already cached copy of account status counts. + counts, ok := a.state.Caches.GTS.AccountCounts.Get(accountID) + if ok { + return counts, nil + } + + if err := a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + var err error + + // Scan database for account statuses. + counts.Statuses, err = tx.NewSelect(). + Table("statuses"). + Where("? = ?", bun.Ident("account_id"), accountID). + Count(ctx) + if err != nil { + return err + } + + // Scan database for pinned statuses. + counts.Pinned, err = tx.NewSelect(). + Table("statuses"). + Where("? = ?", bun.Ident("account_id"), accountID). + Where("? IS NOT NULL", bun.Ident("pinned_at")). + Count(ctx) + if err != nil { + return err + } + + return nil + }); err != nil { + return counts, err + } + + // Store this account counts result in the cache. + a.state.Caches.GTS.AccountCounts.Set(accountID, counts) + + return counts, nil } func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error) { |