summaryrefslogtreecommitdiff
path: root/internal/db/bundb/util.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/util.go')
-rw-r--r--internal/db/bundb/util.go27
1 files changed, 27 insertions, 0 deletions
diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go
index 39849ba73..ac58dd6f4 100644
--- a/internal/db/bundb/util.go
+++ b/internal/db/bundb/util.go
@@ -25,6 +25,8 @@ import (
"code.superseriousbusiness.org/gotosocial/internal/cache"
"code.superseriousbusiness.org/gotosocial/internal/db"
+ "code.superseriousbusiness.org/gotosocial/internal/gtserror"
+ "code.superseriousbusiness.org/gotosocial/internal/gtsmodel"
"code.superseriousbusiness.org/gotosocial/internal/log"
"code.superseriousbusiness.org/gotosocial/internal/paging"
"github.com/uptrace/bun"
@@ -151,6 +153,7 @@ func notExists(ctx context.Context, query *bun.SelectQuery) (bool, error) {
// loadPagedIDs loads a page of IDs from given SliceCache by `key`, resorting to `loadDESC` if required. Uses `page` to sort + page resulting IDs.
// NOTE: IDs returned from `cache` / `loadDESC` MUST be in descending order, otherwise paging will not work correctly / return things out of order.
func loadPagedIDs(cache *cache.SliceCache[string], key string, page *paging.Page, loadDESC func() ([]string, error)) ([]string, error) {
+
// Check cache for IDs, else load.
ids, err := cache.Load(key, loadDESC)
if err != nil {
@@ -171,6 +174,30 @@ func loadPagedIDs(cache *cache.SliceCache[string], key string, page *paging.Page
return ids, nil
}
+// incrementAccountStats will increment the given column in the `account_stats` table matching `account_id`.
+func incrementAccountStats(ctx context.Context, tx bun.Tx, col bun.Ident, accountID string) error {
+ if _, err := tx.NewUpdate().
+ Model((*gtsmodel.AccountStats)(nil)).
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Set("? = (? + 1)", col, col).
+ Exec(ctx); err != nil {
+ return gtserror.Newf("error updating %s: %w", col, err)
+ }
+ return nil
+}
+
+// decrementAccountStats will decrement the given column in the `account_stats` table matching `account_id`.
+func decrementAccountStats(ctx context.Context, tx bun.Tx, col bun.Ident, accountID string) error {
+ if _, err := tx.NewUpdate().
+ Model((*gtsmodel.AccountStats)(nil)).
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Set("? = (? - 1)", col, col).
+ Exec(ctx); err != nil {
+ return gtserror.Newf("error updating %s: %w", col, err)
+ }
+ return nil
+}
+
// updateWhere parses []db.Where and adds it to the given update query.
func updateWhere(q *bun.UpdateQuery, where []db.Where) {
for _, w := range where {