diff options
Diffstat (limited to 'internal/db/bundb/account.go')
-rw-r--r-- | internal/db/bundb/account.go | 214 |
1 files changed, 106 insertions, 108 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 4813f4e17..1e9c390d8 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -24,7 +24,7 @@ import ( "strings" "time" - "github.com/superseriousbusiness/gotosocial/internal/cache" + "codeberg.org/gruf/go-cache/v3/result" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -35,10 +35,29 @@ import ( type accountDB struct { conn *DBConn - cache *cache.AccountCache + cache *result.Cache[*gtsmodel.Account] status *statusDB } +func (a *accountDB) init() { + // Initialize account result cache + a.cache = result.NewSized([]result.Lookup{ + {Name: "ID"}, + {Name: "URI"}, + {Name: "URL"}, + {Name: "Username.Domain"}, + {Name: "PublicKeyURI"}, + }, func(a1 *gtsmodel.Account) *gtsmodel.Account { + a2 := new(gtsmodel.Account) + *a2 = *a1 + return a2 + }, 1000) + + // Set cache TTL and start sweep routine + a.cache.SetTTL(time.Minute*5, false) + a.cache.Start(time.Second * 10) +} + func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { return a.conn. NewSelect(). @@ -51,45 +70,41 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { return a.getAccount( ctx, - func() (*gtsmodel.Account, bool) { - return a.cache.GetByID(id) - }, + "ID", func(account *gtsmodel.Account) error { return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx) }, + id, ) } func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { return a.getAccount( ctx, - func() (*gtsmodel.Account, bool) { - return a.cache.GetByURI(uri) - }, + "URI", func(account *gtsmodel.Account) error { return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx) }, + uri, ) } func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) { return a.getAccount( ctx, - func() (*gtsmodel.Account, bool) { - return a.cache.GetByURL(url) - }, + "URL", func(account *gtsmodel.Account) error { return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx) }, + url, ) } func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) { + username = strings.ToLower(username) return a.getAccount( ctx, - func() (*gtsmodel.Account, bool) { - return a.cache.GetByUsernameDomain(username, domain) - }, + "Username.Domain", func(account *gtsmodel.Account) error { q := a.newAccountQ(account) @@ -97,113 +112,117 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str q = q.Where("? = ?", bun.Ident("account.username"), username) q = q.Where("? = ?", bun.Ident("account.domain"), domain) } else { - q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username)) + q = q.Where("? = ?", bun.Ident("account.username"), username) q = q.Where("? IS NULL", bun.Ident("account.domain")) } return q.Scan(ctx) }, + username, + domain, ) } func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { return a.getAccount( ctx, - func() (*gtsmodel.Account, bool) { - return a.cache.GetByPubkeyID(id) - }, + "PublicKeyURI", func(account *gtsmodel.Account) error { return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx) }, + id, ) } -func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) { - // Attempt to fetch cached account - account, cached := cacheGet() +func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { + var username string - if !cached { - account = >smodel.Account{} + if domain == "" { + // I.e. our local instance account + username = config.GetHost() + } else { + // A remote instance account + username = domain + } + + return a.GetAccountByUsernameDomain(ctx, username, domain) +} + +func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, db.Error) { + return a.cache.Load(lookup, func() (*gtsmodel.Account, error) { + var account gtsmodel.Account // Not cached! Perform database query - err := dbQuery(account) - if err != nil { + if err := dbQuery(&account); err != nil { return nil, a.conn.ProcessError(err) } - // Place in the cache - a.cache.Put(account) - } - - return account, nil + return &account, nil + }, keyParts...) } -func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { - if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { - // create links between this account and any emojis it uses - for _, i := range account.EmojiIDs { - if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ - AccountID: account.ID, - EmojiID: i, - }).Exec(ctx); err != nil { - return err +func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error { + return a.cache.Store(account, func() error { + // It is safe to run this database transaction within cache.Store + // as the cache does not attempt a mutex lock until AFTER hook. + // + return a.conn.RunInTx(ctx, func(tx bun.Tx) error { + // create links between this account and any emojis it uses + for _, i := range account.EmojiIDs { + if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ + AccountID: account.ID, + EmojiID: i, + }).Exec(ctx); err != nil { + return err + } } - } - // insert the account - _, err := tx.NewInsert().Model(account).Exec(ctx) - return err - }); err != nil { - return nil, a.conn.ProcessError(err) - } - - a.cache.Put(account) - return account, nil + // insert the account + _, err := tx.NewInsert().Model(account).Exec(ctx) + return err + }) + }) } -func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { +func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) db.Error { // Update the account's last-updated account.UpdatedAt = time.Now() - if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { - // create links between this account and any emojis it uses - // first clear out any old emoji links - if _, err := tx. - NewDelete(). - TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). - Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID). - Exec(ctx); err != nil { - return err - } - - // now populate new emoji links - for _, i := range account.EmojiIDs { + return a.cache.Store(account, func() error { + // It is safe to run this database transaction within cache.Store + // as the cache does not attempt a mutex lock until AFTER hook. + // + return a.conn.RunInTx(ctx, func(tx bun.Tx) error { + // create links between this account and any emojis it uses + // first clear out any old emoji links if _, err := tx. - NewInsert(). - Model(>smodel.AccountToEmoji{ - AccountID: account.ID, - EmojiID: i, - }).Exec(ctx); err != nil { + NewDelete(). + TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). + Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID). + Exec(ctx); err != nil { return err } - } - // update the account - if _, err := tx. - NewUpdate(). - Model(account). - Where("? = ?", bun.Ident("account.id"), account.ID). - Exec(ctx); err != nil { - return err - } - - return nil - }); err != nil { - return nil, a.conn.ProcessError(err) - } + // now populate new emoji links + for _, i := range account.EmojiIDs { + if _, err := tx. + NewInsert(). + Model(>smodel.AccountToEmoji{ + AccountID: account.ID, + EmojiID: i, + }).Exec(ctx); err != nil { + return err + } + } - a.cache.Put(account) - return account, nil + // update the account + _, err := tx.NewUpdate(). + Model(account). + Where("? = ?", bun.Ident("account.id"), account.ID). + Exec(ctx) + return err + }) + }) } func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { @@ -219,40 +238,19 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { // delete the account _, err := tx. - NewUpdate(). + NewDelete(). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). Where("? = ?", bun.Ident("account.id"), id). Exec(ctx) return err }); err != nil { - return a.conn.ProcessError(err) + return err } - a.cache.Invalidate(id) + a.cache.Invalidate("ID", id) return nil } -func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { - account := new(gtsmodel.Account) - - q := a.newAccountQ(account) - - if domain != "" { - q = q. - Where("? = ?", bun.Ident("account.username"), domain). - Where("? = ?", bun.Ident("account.domain"), domain) - } else { - q = q. - Where("? = ?", bun.Ident("account.username"), config.GetHost()). - WhereGroup(" AND ", whereEmptyOrNull("domain")) - } - - if err := q.Scan(ctx); err != nil { - return nil, a.conn.ProcessError(err) - } - return account, nil -} - func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, db.Error) { createdAt := time.Time{} |