diff options
Diffstat (limited to 'internal/db/bundb/account.go')
-rw-r--r-- | internal/db/bundb/account.go | 86 |
1 files changed, 54 insertions, 32 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index d7d45a739..32a70f7cd 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -25,6 +25,7 @@ import ( "strings" "time" + "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -34,6 +35,7 @@ import ( type accountDB struct { config *config.Config conn *DBConn + cache *cache.AccountCache } func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { @@ -45,60 +47,80 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { } func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { - account := new(gtsmodel.Account) - - q := a.newAccountQ(account). - Where("account.id = ?", id) - - err := q.Scan(ctx) - if err != nil { - return nil, a.conn.ProcessError(err) - } - return account, nil + return a.getAccount( + ctx, + func() (*gtsmodel.Account, bool) { + return a.cache.GetByID(id) + }, + func(account *gtsmodel.Account) error { + return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx) + }, + ) } func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { - account := new(gtsmodel.Account) - - q := a.newAccountQ(account). - Where("account.uri = ?", uri) + return a.getAccount( + ctx, + func() (*gtsmodel.Account, bool) { + return a.cache.GetByURI(uri) + }, + func(account *gtsmodel.Account) error { + return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx) + }, + ) +} - err := q.Scan(ctx) - if err != nil { - return nil, a.conn.ProcessError(err) - } - return account, nil +func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) { + return a.getAccount( + ctx, + func() (*gtsmodel.Account, bool) { + return a.cache.GetByURL(url) + }, + func(account *gtsmodel.Account) error { + return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx) + }, + ) } -func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { - account := new(gtsmodel.Account) +func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) { + // Attempt to fetch cached account + account, cached := cacheGet() - q := a.newAccountQ(account). - Where("account.url = ?", uri) + if !cached { + account = >smodel.Account{} - err := q.Scan(ctx) - if err != nil { - return nil, a.conn.ProcessError(err) + // Not cached! Perform database query + err := dbQuery(account) + if err != nil { + return nil, a.conn.ProcessError(err) + } + + // Place in the cache + a.cache.Put(account) } + return account, nil } func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { if strings.TrimSpace(account.ID) == "" { + // TODO: we should not need this check here return nil, errors.New("account had no ID") } + // Update the account's last-used account.UpdatedAt = time.Now() - q := a.conn. - NewUpdate(). - Model(account). - WherePK() - - _, err := q.Exec(ctx) + // Update the account model in the DB + _, err := a.conn.NewUpdate().Model(account).WherePK().Exec(ctx) if err != nil { return nil, a.conn.ProcessError(err) } + + // Place updated account in cache + // (this will replace existing, i.e. invalidating) + a.cache.Put(account) + return account, nil } |