summaryrefslogtreecommitdiff
path: root/internal/db/bundb/account.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/account.go')
-rw-r--r--internal/db/bundb/account.go267
1 files changed, 189 insertions, 78 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index 45e67c10b..2b3c78aff 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -630,6 +630,13 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou
}
}
+ if account.Stats == nil {
+ // Get / Create stats for this account.
+ if err := a.state.DB.PopulateAccountStats(ctx, account); err != nil {
+ errs.Appendf("error populating account stats: %w", err)
+ }
+ }
+
return errs.Combine()
}
@@ -735,31 +742,6 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) error {
})
}
-func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, error) {
- createdAt := time.Time{}
-
- q := a.db.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
- Column("status.created_at").
- Where("? = ?", bun.Ident("status.account_id"), accountID).
- Order("status.id DESC").
- Limit(1)
-
- if webOnly {
- q = q.
- Where("? IS NULL", bun.Ident("status.in_reply_to_uri")).
- Where("? IS NULL", bun.Ident("status.boost_of_id")).
- Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
- Where("? = ?", bun.Ident("status.federated"), true)
- }
-
- if err := q.Scan(ctx, &createdAt); err != nil {
- return time.Time{}, err
- }
- return createdAt, nil
-}
-
func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error {
if *mediaAttachment.Avatar && *mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar")
@@ -845,59 +827,6 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
return *faves, nil
}
-func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, error) {
- counts, err := a.getAccountStatusCounts(ctx, accountID)
- return counts.Statuses, err
-}
-
-func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, error) {
- counts, err := a.getAccountStatusCounts(ctx, accountID)
- return counts.Pinned, err
-}
-
-func (a *accountDB) getAccountStatusCounts(ctx context.Context, accountID string) (struct {
- Statuses int
- Pinned int
-}, error) {
- // Check for an already cached copy of account status counts.
- counts, ok := a.state.Caches.GTS.AccountCounts.Get(accountID)
- if ok {
- return counts, nil
- }
-
- if err := a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
- var err error
-
- // Scan database for account statuses.
- counts.Statuses, err = tx.NewSelect().
- Table("statuses").
- Where("? = ?", bun.Ident("account_id"), accountID).
- Count(ctx)
- if err != nil {
- return err
- }
-
- // Scan database for pinned statuses.
- counts.Pinned, err = tx.NewSelect().
- Table("statuses").
- Where("? = ?", bun.Ident("account_id"), accountID).
- Where("? IS NOT NULL", bun.Ident("pinned_at")).
- Count(ctx)
- if err != nil {
- return err
- }
-
- return nil
- }); err != nil {
- return counts, err
- }
-
- // Store this account counts result in the cache.
- a.state.Caches.GTS.AccountCounts.Set(accountID, counts)
-
- return counts, nil
-}
-
func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error) {
// Ensure reasonable
if limit < 0 {
@@ -1147,3 +1076,185 @@ func (a *accountDB) UpdateAccountSettings(
return nil
})
}
+
+func (a *accountDB) PopulateAccountStats(ctx context.Context, account *gtsmodel.Account) error {
+ // Fetch stats from db cache with loader callback.
+ stats, err := a.state.Caches.GTS.AccountStats.LoadOne(
+ "AccountID",
+ func() (*gtsmodel.AccountStats, error) {
+ // Not cached! Perform database query.
+ var stats gtsmodel.AccountStats
+ if err := a.db.
+ NewSelect().
+ Model(&stats).
+ Where("? = ?", bun.Ident("account_stats.account_id"), account.ID).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+ return &stats, nil
+ },
+ account.ID,
+ )
+
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ // Real error.
+ return err
+ }
+
+ if stats == nil {
+ // Don't have stats yet, generate them.
+ return a.RegenerateAccountStats(ctx, account)
+ }
+
+ // We have a stats, attach
+ // it to the account.
+ account.Stats = stats
+
+ // Check if this is a local
+ // stats by looking at the
+ // account they pertain to.
+ if account.IsRemote() {
+ // Account is remote. Updating
+ // stats for remote accounts is
+ // handled in the dereferencer.
+ //
+ // Nothing more to do!
+ return nil
+ }
+
+ // Stats account is local, check
+ // if we need to regenerate.
+ const statsFreshness = 48 * time.Hour
+ expiry := stats.RegeneratedAt.Add(statsFreshness)
+ if time.Now().After(expiry) {
+ // Stats have expired, regenerate them.
+ return a.RegenerateAccountStats(ctx, account)
+ }
+
+ // Stats are still fresh.
+ return nil
+}
+
+func (a *accountDB) RegenerateAccountStats(ctx context.Context, account *gtsmodel.Account) error {
+ // Initialize a new stats struct.
+ stats := &gtsmodel.AccountStats{
+ AccountID: account.ID,
+ RegeneratedAt: time.Now(),
+ }
+
+ // Count followers outside of transaction since
+ // it uses a cache + requires its own db calls.
+ followerIDs, err := a.state.DB.GetAccountFollowerIDs(ctx, account.ID, nil)
+ if err != nil {
+ return err
+ }
+ stats.FollowersCount = util.Ptr(len(followerIDs))
+
+ // Count following outside of transaction since
+ // it uses a cache + requires its own db calls.
+ followIDs, err := a.state.DB.GetAccountFollowIDs(ctx, account.ID, nil)
+ if err != nil {
+ return err
+ }
+ stats.FollowingCount = util.Ptr(len(followIDs))
+
+ // Count follow requests outside of transaction since
+ // it uses a cache + requires its own db calls.
+ followRequestIDs, err := a.state.DB.GetAccountFollowRequestIDs(ctx, account.ID, nil)
+ if err != nil {
+ return err
+ }
+ stats.FollowRequestsCount = util.Ptr(len(followRequestIDs))
+
+ // Populate remaining stats struct fields.
+ // This can be done inside a transaction.
+ if err := a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ var err error
+
+ // Scan database for account statuses.
+ statusesCount, err := tx.NewSelect().
+ Table("statuses").
+ Where("? = ?", bun.Ident("account_id"), account.ID).
+ Count(ctx)
+ if err != nil {
+ return err
+ }
+ stats.StatusesCount = &statusesCount
+
+ // Scan database for pinned statuses.
+ statusesPinnedCount, err := tx.NewSelect().
+ Table("statuses").
+ Where("? = ?", bun.Ident("account_id"), account.ID).
+ Where("? IS NOT NULL", bun.Ident("pinned_at")).
+ Count(ctx)
+ if err != nil {
+ return err
+ }
+ stats.StatusesPinnedCount = &statusesPinnedCount
+
+ // Scan database for last status.
+ lastStatusAt := time.Time{}
+ err = tx.
+ NewSelect().
+ TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
+ Column("status.created_at").
+ Where("? = ?", bun.Ident("status.account_id"), account.ID).
+ Order("status.id DESC").
+ Limit(1).
+ Scan(ctx, &lastStatusAt)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return err
+ }
+ stats.LastStatusAt = lastStatusAt
+
+ return nil
+ }); err != nil {
+ return err
+ }
+
+ // Upsert this stats in case a race
+ // meant someone else inserted it first.
+ if err := a.state.Caches.GTS.AccountStats.Store(stats, func() error {
+ if _, err := NewUpsert(a.db).
+ Model(stats).
+ Constraint("account_id").
+ Exec(ctx); err != nil {
+ return err
+ }
+ return nil
+ }); err != nil {
+ return err
+ }
+
+ account.Stats = stats
+ return nil
+}
+
+func (a *accountDB) UpdateAccountStats(ctx context.Context, stats *gtsmodel.AccountStats, columns ...string) error {
+ return a.state.Caches.GTS.AccountStats.Store(stats, func() error {
+ if _, err := a.db.
+ NewUpdate().
+ Model(stats).
+ Column(columns...).
+ Where("? = ?", bun.Ident("account_stats.account_id"), stats.AccountID).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ return nil
+ })
+}
+
+func (a *accountDB) DeleteAccountStats(ctx context.Context, accountID string) error {
+ defer a.state.Caches.GTS.AccountStats.Invalidate("AccountID", accountID)
+
+ if _, err := a.db.
+ NewDelete().
+ Table("account_stats").
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ return nil
+}