summaryrefslogtreecommitdiff
path: root/internal/db/bundb/account.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/account.go')
-rw-r--r--internal/db/bundb/account.go214
1 files changed, 106 insertions, 108 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index 4813f4e17..1e9c390d8 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -24,7 +24,7 @@ import (
"strings"
"time"
- "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "codeberg.org/gruf/go-cache/v3/result"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -35,10 +35,29 @@ import (
type accountDB struct {
conn *DBConn
- cache *cache.AccountCache
+ cache *result.Cache[*gtsmodel.Account]
status *statusDB
}
+func (a *accountDB) init() {
+ // Initialize account result cache
+ a.cache = result.NewSized([]result.Lookup{
+ {Name: "ID"},
+ {Name: "URI"},
+ {Name: "URL"},
+ {Name: "Username.Domain"},
+ {Name: "PublicKeyURI"},
+ }, func(a1 *gtsmodel.Account) *gtsmodel.Account {
+ a2 := new(gtsmodel.Account)
+ *a2 = *a1
+ return a2
+ }, 1000)
+
+ // Set cache TTL and start sweep routine
+ a.cache.SetTTL(time.Minute*5, false)
+ a.cache.Start(time.Second * 10)
+}
+
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
return a.conn.
NewSelect().
@@ -51,45 +70,41 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
- func() (*gtsmodel.Account, bool) {
- return a.cache.GetByID(id)
- },
+ "ID",
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx)
},
+ id,
)
}
func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
- func() (*gtsmodel.Account, bool) {
- return a.cache.GetByURI(uri)
- },
+ "URI",
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx)
},
+ uri,
)
}
func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
- func() (*gtsmodel.Account, bool) {
- return a.cache.GetByURL(url)
- },
+ "URL",
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx)
},
+ url,
)
}
func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) {
+ username = strings.ToLower(username)
return a.getAccount(
ctx,
- func() (*gtsmodel.Account, bool) {
- return a.cache.GetByUsernameDomain(username, domain)
- },
+ "Username.Domain",
func(account *gtsmodel.Account) error {
q := a.newAccountQ(account)
@@ -97,113 +112,117 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
q = q.Where("? = ?", bun.Ident("account.username"), username)
q = q.Where("? = ?", bun.Ident("account.domain"), domain)
} else {
- q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username))
+ q = q.Where("? = ?", bun.Ident("account.username"), username)
q = q.Where("? IS NULL", bun.Ident("account.domain"))
}
return q.Scan(ctx)
},
+ username,
+ domain,
)
}
func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
- func() (*gtsmodel.Account, bool) {
- return a.cache.GetByPubkeyID(id)
- },
+ "PublicKeyURI",
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx)
},
+ id,
)
}
-func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) {
- // Attempt to fetch cached account
- account, cached := cacheGet()
+func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
+ var username string
- if !cached {
- account = &gtsmodel.Account{}
+ if domain == "" {
+ // I.e. our local instance account
+ username = config.GetHost()
+ } else {
+ // A remote instance account
+ username = domain
+ }
+
+ return a.GetAccountByUsernameDomain(ctx, username, domain)
+}
+
+func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, db.Error) {
+ return a.cache.Load(lookup, func() (*gtsmodel.Account, error) {
+ var account gtsmodel.Account
// Not cached! Perform database query
- err := dbQuery(account)
- if err != nil {
+ if err := dbQuery(&account); err != nil {
return nil, a.conn.ProcessError(err)
}
- // Place in the cache
- a.cache.Put(account)
- }
-
- return account, nil
+ return &account, nil
+ }, keyParts...)
}
-func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
- if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error {
- // create links between this account and any emojis it uses
- for _, i := range account.EmojiIDs {
- if _, err := tx.NewInsert().Model(&gtsmodel.AccountToEmoji{
- AccountID: account.ID,
- EmojiID: i,
- }).Exec(ctx); err != nil {
- return err
+func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error {
+ return a.cache.Store(account, func() error {
+ // It is safe to run this database transaction within cache.Store
+ // as the cache does not attempt a mutex lock until AFTER hook.
+ //
+ return a.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ // create links between this account and any emojis it uses
+ for _, i := range account.EmojiIDs {
+ if _, err := tx.NewInsert().Model(&gtsmodel.AccountToEmoji{
+ AccountID: account.ID,
+ EmojiID: i,
+ }).Exec(ctx); err != nil {
+ return err
+ }
}
- }
- // insert the account
- _, err := tx.NewInsert().Model(account).Exec(ctx)
- return err
- }); err != nil {
- return nil, a.conn.ProcessError(err)
- }
-
- a.cache.Put(account)
- return account, nil
+ // insert the account
+ _, err := tx.NewInsert().Model(account).Exec(ctx)
+ return err
+ })
+ })
}
-func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) db.Error {
// Update the account's last-updated
account.UpdatedAt = time.Now()
- if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error {
- // create links between this account and any emojis it uses
- // first clear out any old emoji links
- if _, err := tx.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")).
- Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID).
- Exec(ctx); err != nil {
- return err
- }
-
- // now populate new emoji links
- for _, i := range account.EmojiIDs {
+ return a.cache.Store(account, func() error {
+ // It is safe to run this database transaction within cache.Store
+ // as the cache does not attempt a mutex lock until AFTER hook.
+ //
+ return a.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ // create links between this account and any emojis it uses
+ // first clear out any old emoji links
if _, err := tx.
- NewInsert().
- Model(&gtsmodel.AccountToEmoji{
- AccountID: account.ID,
- EmojiID: i,
- }).Exec(ctx); err != nil {
+ NewDelete().
+ TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")).
+ Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID).
+ Exec(ctx); err != nil {
return err
}
- }
- // update the account
- if _, err := tx.
- NewUpdate().
- Model(account).
- Where("? = ?", bun.Ident("account.id"), account.ID).
- Exec(ctx); err != nil {
- return err
- }
-
- return nil
- }); err != nil {
- return nil, a.conn.ProcessError(err)
- }
+ // now populate new emoji links
+ for _, i := range account.EmojiIDs {
+ if _, err := tx.
+ NewInsert().
+ Model(&gtsmodel.AccountToEmoji{
+ AccountID: account.ID,
+ EmojiID: i,
+ }).Exec(ctx); err != nil {
+ return err
+ }
+ }
- a.cache.Put(account)
- return account, nil
+ // update the account
+ _, err := tx.NewUpdate().
+ Model(account).
+ Where("? = ?", bun.Ident("account.id"), account.ID).
+ Exec(ctx)
+ return err
+ })
+ })
}
func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
@@ -219,40 +238,19 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
// delete the account
_, err := tx.
- NewUpdate().
+ NewDelete().
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
Where("? = ?", bun.Ident("account.id"), id).
Exec(ctx)
return err
}); err != nil {
- return a.conn.ProcessError(err)
+ return err
}
- a.cache.Invalidate(id)
+ a.cache.Invalidate("ID", id)
return nil
}
-func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
- account := new(gtsmodel.Account)
-
- q := a.newAccountQ(account)
-
- if domain != "" {
- q = q.
- Where("? = ?", bun.Ident("account.username"), domain).
- Where("? = ?", bun.Ident("account.domain"), domain)
- } else {
- q = q.
- Where("? = ?", bun.Ident("account.username"), config.GetHost()).
- WhereGroup(" AND ", whereEmptyOrNull("domain"))
- }
-
- if err := q.Scan(ctx); err != nil {
- return nil, a.conn.ProcessError(err)
- }
- return account, nil
-}
-
func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, db.Error) {
createdAt := time.Time{}