diff options
author | 2022-11-15 18:45:15 +0000 | |
---|---|---|
committer | 2022-11-15 18:45:15 +0000 | |
commit | 8598dea98b872647393117704659878d9b38d4fc (patch) | |
tree | 1940168912dc7f54af723439dbc9f6e0a42f30ae /internal/db | |
parent | [docs] Both HTTP proxies and NAT can cause rate limiting issues (#1053) (diff) | |
download | gotosocial-8598dea98b872647393117704659878d9b38d4fc.tar.xz |
[chore] update database caching library (#1040)
* convert most of the caches to use result.Cache{}
* add caching of emojis
* fix issues causing failing tests
* update go-cache/v2 instances with v3
* fix getnotification
* add a note about the left-in StatusCreate comment
* update EmojiCategory db access to use new result.Cache{}
* fix possible panic in getstatusparents
* further proof that kim is not stinky
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/account.go | 4 | ||||
-rw-r--r-- | internal/db/bundb/account.go | 214 | ||||
-rw-r--r-- | internal/db/bundb/account_test.go | 4 | ||||
-rw-r--r-- | internal/db/bundb/admin.go | 33 | ||||
-rw-r--r-- | internal/db/bundb/admin_test.go | 2 | ||||
-rw-r--r-- | internal/db/bundb/bundb.go | 76 | ||||
-rw-r--r-- | internal/db/bundb/domain.go | 94 | ||||
-rw-r--r-- | internal/db/bundb/emoji.go | 127 | ||||
-rw-r--r-- | internal/db/bundb/mention.go | 48 | ||||
-rw-r--r-- | internal/db/bundb/notification.go | 54 | ||||
-rw-r--r-- | internal/db/bundb/status.go | 251 | ||||
-rw-r--r-- | internal/db/bundb/timeline_test.go | 26 | ||||
-rw-r--r-- | internal/db/bundb/tombstone.go | 2 | ||||
-rw-r--r-- | internal/db/bundb/user.go | 175 | ||||
-rw-r--r-- | internal/db/bundb/user_test.go | 17 | ||||
-rw-r--r-- | internal/db/status.go | 2 | ||||
-rw-r--r-- | internal/db/user.go | 7 |
17 files changed, 578 insertions, 558 deletions
diff --git a/internal/db/account.go b/internal/db/account.go index a58aa9dd3..7e7d1de43 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -43,10 +43,10 @@ type Account interface { GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, Error) // PutAccount puts one account in the database. - PutAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) + PutAccount(ctx context.Context, account *gtsmodel.Account) Error // UpdateAccount updates one account by ID. - UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) + UpdateAccount(ctx context.Context, account *gtsmodel.Account) Error // DeleteAccount deletes one account from the database by its ID. // DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 4813f4e17..1e9c390d8 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -24,7 +24,7 @@ import ( "strings" "time" - "github.com/superseriousbusiness/gotosocial/internal/cache" + "codeberg.org/gruf/go-cache/v3/result" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -35,10 +35,29 @@ import ( type accountDB struct { conn *DBConn - cache *cache.AccountCache + cache *result.Cache[*gtsmodel.Account] status *statusDB } +func (a *accountDB) init() { + // Initialize account result cache + a.cache = result.NewSized([]result.Lookup{ + {Name: "ID"}, + {Name: "URI"}, + {Name: "URL"}, + {Name: "Username.Domain"}, + {Name: "PublicKeyURI"}, + }, func(a1 *gtsmodel.Account) *gtsmodel.Account { + a2 := new(gtsmodel.Account) + *a2 = *a1 + return a2 + }, 1000) + + // Set cache TTL and start sweep routine + a.cache.SetTTL(time.Minute*5, false) + a.cache.Start(time.Second * 10) +} + func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { return a.conn. NewSelect(). @@ -51,45 +70,41 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { return a.getAccount( ctx, - func() (*gtsmodel.Account, bool) { - return a.cache.GetByID(id) - }, + "ID", func(account *gtsmodel.Account) error { return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx) }, + id, ) } func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { return a.getAccount( ctx, - func() (*gtsmodel.Account, bool) { - return a.cache.GetByURI(uri) - }, + "URI", func(account *gtsmodel.Account) error { return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx) }, + uri, ) } func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) { return a.getAccount( ctx, - func() (*gtsmodel.Account, bool) { - return a.cache.GetByURL(url) - }, + "URL", func(account *gtsmodel.Account) error { return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx) }, + url, ) } func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) { + username = strings.ToLower(username) return a.getAccount( ctx, - func() (*gtsmodel.Account, bool) { - return a.cache.GetByUsernameDomain(username, domain) - }, + "Username.Domain", func(account *gtsmodel.Account) error { q := a.newAccountQ(account) @@ -97,113 +112,117 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str q = q.Where("? = ?", bun.Ident("account.username"), username) q = q.Where("? = ?", bun.Ident("account.domain"), domain) } else { - q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username)) + q = q.Where("? = ?", bun.Ident("account.username"), username) q = q.Where("? IS NULL", bun.Ident("account.domain")) } return q.Scan(ctx) }, + username, + domain, ) } func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { return a.getAccount( ctx, - func() (*gtsmodel.Account, bool) { - return a.cache.GetByPubkeyID(id) - }, + "PublicKeyURI", func(account *gtsmodel.Account) error { return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx) }, + id, ) } -func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) { - // Attempt to fetch cached account - account, cached := cacheGet() +func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { + var username string - if !cached { - account = >smodel.Account{} + if domain == "" { + // I.e. our local instance account + username = config.GetHost() + } else { + // A remote instance account + username = domain + } + + return a.GetAccountByUsernameDomain(ctx, username, domain) +} + +func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, db.Error) { + return a.cache.Load(lookup, func() (*gtsmodel.Account, error) { + var account gtsmodel.Account // Not cached! Perform database query - err := dbQuery(account) - if err != nil { + if err := dbQuery(&account); err != nil { return nil, a.conn.ProcessError(err) } - // Place in the cache - a.cache.Put(account) - } - - return account, nil + return &account, nil + }, keyParts...) } -func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { - if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { - // create links between this account and any emojis it uses - for _, i := range account.EmojiIDs { - if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ - AccountID: account.ID, - EmojiID: i, - }).Exec(ctx); err != nil { - return err +func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error { + return a.cache.Store(account, func() error { + // It is safe to run this database transaction within cache.Store + // as the cache does not attempt a mutex lock until AFTER hook. + // + return a.conn.RunInTx(ctx, func(tx bun.Tx) error { + // create links between this account and any emojis it uses + for _, i := range account.EmojiIDs { + if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ + AccountID: account.ID, + EmojiID: i, + }).Exec(ctx); err != nil { + return err + } } - } - // insert the account - _, err := tx.NewInsert().Model(account).Exec(ctx) - return err - }); err != nil { - return nil, a.conn.ProcessError(err) - } - - a.cache.Put(account) - return account, nil + // insert the account + _, err := tx.NewInsert().Model(account).Exec(ctx) + return err + }) + }) } -func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { +func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) db.Error { // Update the account's last-updated account.UpdatedAt = time.Now() - if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { - // create links between this account and any emojis it uses - // first clear out any old emoji links - if _, err := tx. - NewDelete(). - TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). - Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID). - Exec(ctx); err != nil { - return err - } - - // now populate new emoji links - for _, i := range account.EmojiIDs { + return a.cache.Store(account, func() error { + // It is safe to run this database transaction within cache.Store + // as the cache does not attempt a mutex lock until AFTER hook. + // + return a.conn.RunInTx(ctx, func(tx bun.Tx) error { + // create links between this account and any emojis it uses + // first clear out any old emoji links if _, err := tx. - NewInsert(). - Model(>smodel.AccountToEmoji{ - AccountID: account.ID, - EmojiID: i, - }).Exec(ctx); err != nil { + NewDelete(). + TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). + Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID). + Exec(ctx); err != nil { return err } - } - // update the account - if _, err := tx. - NewUpdate(). - Model(account). - Where("? = ?", bun.Ident("account.id"), account.ID). - Exec(ctx); err != nil { - return err - } - - return nil - }); err != nil { - return nil, a.conn.ProcessError(err) - } + // now populate new emoji links + for _, i := range account.EmojiIDs { + if _, err := tx. + NewInsert(). + Model(>smodel.AccountToEmoji{ + AccountID: account.ID, + EmojiID: i, + }).Exec(ctx); err != nil { + return err + } + } - a.cache.Put(account) - return account, nil + // update the account + _, err := tx.NewUpdate(). + Model(account). + Where("? = ?", bun.Ident("account.id"), account.ID). + Exec(ctx) + return err + }) + }) } func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { @@ -219,40 +238,19 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { // delete the account _, err := tx. - NewUpdate(). + NewDelete(). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). Where("? = ?", bun.Ident("account.id"), id). Exec(ctx) return err }); err != nil { - return a.conn.ProcessError(err) + return err } - a.cache.Invalidate(id) + a.cache.Invalidate("ID", id) return nil } -func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { - account := new(gtsmodel.Account) - - q := a.newAccountQ(account) - - if domain != "" { - q = q. - Where("? = ?", bun.Ident("account.username"), domain). - Where("? = ?", bun.Ident("account.domain"), domain) - } else { - q = q. - Where("? = ?", bun.Ident("account.username"), config.GetHost()). - WhereGroup(" AND ", whereEmptyOrNull("domain")) - } - - if err := q.Scan(ctx); err != nil { - return nil, a.conn.ProcessError(err) - } - return account, nil -} - func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, db.Error) { createdAt := time.Time{} diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index 29594a740..50603623f 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -92,7 +92,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() { testAccount.DisplayName = "new display name!" testAccount.EmojiIDs = []string{"01GD36ZKWTKY3T1JJ24JR7KY1Q", "01GD36ZV904SHBHNAYV6DX5QEF"} - _, err := suite.db.UpdateAccount(ctx, testAccount) + err := suite.db.UpdateAccount(ctx, testAccount) suite.NoError(err) updated, err := suite.db.GetAccountByID(ctx, testAccount.ID) @@ -127,7 +127,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() { // update again to remove emoji associations testAccount.EmojiIDs = []string{} - _, err = suite.db.UpdateAccount(ctx, testAccount) + err = suite.db.UpdateAccount(ctx, testAccount) suite.NoError(err) updated, err = suite.db.GetAccountByID(ctx, testAccount.ID) diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index 44861a4bb..4d750581c 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -29,7 +29,6 @@ import ( "time" "github.com/superseriousbusiness/gotosocial/internal/ap" - "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -44,9 +43,9 @@ import ( const rsaKeyBits = 2048 type adminDB struct { - conn *DBConn - userCache *cache.UserCache - accountCache *cache.AccountCache + conn *DBConn + accounts *accountDB + users *userDB } func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { @@ -140,13 +139,9 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, } // insert the new account! - if _, err = a.conn. - NewInsert(). - Model(acct). - Exec(ctx); err != nil { - return nil, a.conn.ProcessError(err) + if err := a.accounts.PutAccount(ctx, acct); err != nil { + return nil, err } - a.accountCache.Put(acct) } // we either created or already had an account by now, @@ -190,13 +185,9 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, } // insert the user! - if _, err = a.conn. - NewInsert(). - Model(u). - Exec(ctx); err != nil { - return nil, a.conn.ProcessError(err) + if err := a.users.PutUser(ctx, u); err != nil { + return nil, err } - a.userCache.Put(u) return u, nil } @@ -249,15 +240,11 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { FeaturedCollectionURI: newAccountURIs.CollectionURI, } - insertQ := a.conn. - NewInsert(). - Model(acct) - - if _, err := insertQ.Exec(ctx); err != nil { - return a.conn.ProcessError(err) + // insert the new account! + if err := a.accounts.PutAccount(ctx, acct); err != nil { + return err } - a.accountCache.Put(acct) log.Infof("instance account %s CREATED with id %s", username, acct.ID) return nil } diff --git a/internal/db/bundb/admin_test.go b/internal/db/bundb/admin_test.go index f0a869a9b..18e1f67e2 100644 --- a/internal/db/bundb/admin_test.go +++ b/internal/db/bundb/admin_test.go @@ -70,6 +70,8 @@ func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() { } func (suite *AdminTestSuite) TestCreateInstanceAccount() { + // reinitialize test DB to clear caches + suite.db = testrig.NewTestDB() // we need to take an empty db for this... testrig.StandardDBTeardown(suite.db) // ...with tables created but no data diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index cf6643f6b..de6749ca4 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -34,7 +34,6 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/stdlib" - "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations" @@ -46,7 +45,6 @@ import ( "github.com/uptrace/bun/dialect/sqlitedialect" "github.com/uptrace/bun/migrate" - grufcache "codeberg.org/gruf/go-cache/v2" "modernc.org/sqlite" ) @@ -160,79 +158,63 @@ func NewBunDBService(ctx context.Context) (db.DB, error) { return nil, fmt.Errorf("db migration error: %s", err) } - // Prepare caches required by more than one struct - userCache := cache.NewUserCache() - accountCache := cache.NewAccountCache() - - // Prepare other caches - // Prepare mentions cache - // TODO: move into internal/cache - mentionCache := grufcache.New[string, *gtsmodel.Mention]() - mentionCache.SetTTL(time.Minute*5, false) - mentionCache.Start(time.Second * 10) - - // Prepare notifications cache - // TODO: move into internal/cache - notifCache := grufcache.New[string, *gtsmodel.Notification]() - notifCache.SetTTL(time.Minute*5, false) - notifCache.Start(time.Second * 10) - // Create DB structs that require ptrs to each other - accounts := &accountDB{conn: conn, cache: accountCache} - status := &statusDB{conn: conn, cache: cache.NewStatusCache()} - emoji := &emojiDB{conn: conn, emojiCache: cache.NewEmojiCache(), categoryCache: cache.NewEmojiCategoryCache()} + account := &accountDB{conn: conn} + admin := &adminDB{conn: conn} + domain := &domainDB{conn: conn} + mention := &mentionDB{conn: conn} + notif := ¬ificationDB{conn: conn} + status := &statusDB{conn: conn} + emoji := &emojiDB{conn: conn} timeline := &timelineDB{conn: conn} tombstone := &tombstoneDB{conn: conn} + user := &userDB{conn: conn} // Setup DB cross-referencing - accounts.status = status - status.accounts = accounts + account.status = status + admin.users = user + status.accounts = account timeline.status = status // Initialize db structs + account.init() + domain.init() + emoji.init() + mention.init() + notif.init() + status.init() tombstone.init() + user.init() ps := &DBService{ - Account: accounts, + Account: account, Admin: &adminDB{ - conn: conn, - userCache: userCache, - accountCache: accountCache, + conn: conn, + accounts: account, + users: user, }, Basic: &basicDB{ conn: conn, }, - Domain: &domainDB{ - conn: conn, - cache: cache.NewDomainBlockCache(), - }, - Emoji: emoji, + Domain: domain, + Emoji: emoji, Instance: &instanceDB{ conn: conn, }, Media: &mediaDB{ conn: conn, }, - Mention: &mentionDB{ - conn: conn, - cache: mentionCache, - }, - Notification: ¬ificationDB{ - conn: conn, - cache: notifCache, - }, + Mention: mention, + Notification: notif, Relationship: &relationshipDB{ conn: conn, }, Session: &sessionDB{ conn: conn, }, - Status: status, - Timeline: timeline, - User: &userDB{ - conn: conn, - cache: userCache, - }, + Status: status, + Timeline: timeline, + User: user, Tombstone: tombstone, conn: conn, } diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index 0a752d3f3..3fca8501b 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -20,11 +20,11 @@ package bundb import ( "context" - "database/sql" "net/url" "strings" + "time" - "github.com/superseriousbusiness/gotosocial/internal/cache" + "codeberg.org/gruf/go-cache/v3/result" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -34,7 +34,22 @@ import ( type domainDB struct { conn *DBConn - cache *cache.DomainBlockCache + cache *result.Cache[*gtsmodel.DomainBlock] +} + +func (d *domainDB) init() { + // Initialize domain block result cache + d.cache = result.NewSized([]result.Lookup{ + {Name: "Domain"}, + }, func(d1 *gtsmodel.DomainBlock) *gtsmodel.DomainBlock { + d2 := new(gtsmodel.DomainBlock) + *d2 = *d1 + return d2 + }, 1000) + + // Set cache TTL and start sweep routine + d.cache.SetTTL(time.Minute*5, false) + d.cache.Start(time.Second * 10) } // normalizeDomain converts the given domain to lowercase @@ -49,76 +64,53 @@ func normalizeDomain(domain string) (out string, err error) { } func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) db.Error { - domain, err := normalizeDomain(block.Domain) + var err error + + block.Domain, err = normalizeDomain(block.Domain) if err != nil { return err } - block.Domain = domain - // Attempt to insert new domain block - if _, err := d.conn.NewInsert(). - Model(block). - Exec(ctx); err != nil { + return d.cache.Store(block, func() error { + _, err := d.conn.NewInsert(). + Model(block). + Exec(ctx) return d.conn.ProcessError(err) - } - - // Cache this domain block - d.cache.Put(block.Domain, block) - - return nil + }) } func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) { var err error + domain, err = normalizeDomain(domain) if err != nil { return nil, err } - // Check for easy case, domain referencing *us* - if domain == "" || domain == config.GetAccountDomain() { - return nil, db.ErrNoEntries - } - - // Check for already cached rblock - if block, ok := d.cache.GetByDomain(domain); ok { - // A 'nil' return value is a sentinel value for no block - if block == nil { + return d.cache.Load("Domain", func() (*gtsmodel.DomainBlock, error) { + // Check for easy case, domain referencing *us* + if domain == "" || domain == config.GetAccountDomain() { return nil, db.ErrNoEntries } - // Else, this block exists - return block, nil - } + var block gtsmodel.DomainBlock - block := >smodel.DomainBlock{} + q := d.conn. + NewSelect(). + Model(&block). + Where("? = ?", bun.Ident("domain_block.domain"), domain). + Limit(1) + if err := q.Scan(ctx); err != nil { + return nil, d.conn.ProcessError(err) + } - q := d.conn. - NewSelect(). - Model(block). - Where("? = ?", bun.Ident("domain_block.domain"), domain). - Limit(1) - - // Query database for domain block - switch err := q.Scan(ctx); err { - // No error, block found - case nil: - d.cache.Put(domain, block) - return block, nil - - // No error, simply not found - case sql.ErrNoRows: - d.cache.Put(domain, nil) - return nil, db.ErrNoEntries - - // Any other db error - default: - return nil, d.conn.ProcessError(err) - } + return &block, nil + }, domain) } func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error { var err error + domain, err = normalizeDomain(domain) if err != nil { return err @@ -133,7 +125,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro } // Clear domain from cache - d.cache.InvalidateByDomain(domain) + d.cache.Invalidate(domain) return nil } diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index 81374ce78..55e0ee3ff 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -23,7 +23,7 @@ import ( "strings" "time" - "github.com/superseriousbusiness/gotosocial/internal/cache" + "codeberg.org/gruf/go-cache/v3/result" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" @@ -33,8 +33,40 @@ import ( type emojiDB struct { conn *DBConn - emojiCache *cache.EmojiCache - categoryCache *cache.EmojiCategoryCache + emojiCache *result.Cache[*gtsmodel.Emoji] + categoryCache *result.Cache[*gtsmodel.EmojiCategory] +} + +func (e *emojiDB) init() { + // Initialize emoji result cache + e.emojiCache = result.NewSized([]result.Lookup{ + {Name: "ID"}, + {Name: "URI"}, + {Name: "Shortcode.Domain"}, + {Name: "ImageStaticURL"}, + }, func(e1 *gtsmodel.Emoji) *gtsmodel.Emoji { + e2 := new(gtsmodel.Emoji) + *e2 = *e1 + return e2 + }, 1000) + + // Set cache TTL and start sweep routine + e.emojiCache.SetTTL(time.Minute*5, false) + e.emojiCache.Start(time.Second * 10) + + // Initialize category result cache + e.categoryCache = result.NewSized([]result.Lookup{ + {Name: "ID"}, + {Name: "Name"}, + }, func(c1 *gtsmodel.EmojiCategory) *gtsmodel.EmojiCategory { + c2 := new(gtsmodel.EmojiCategory) + *c2 = *c1 + return c2 + }, 1000) + + // Set cache TTL and start sweep routine + e.categoryCache.SetTTL(time.Minute*5, false) + e.categoryCache.Start(time.Second * 10) } func (e *emojiDB) newEmojiQ(emoji *gtsmodel.Emoji) *bun.SelectQuery { @@ -51,12 +83,10 @@ func (e *emojiDB) newEmojiCategoryQ(emojiCategory *gtsmodel.EmojiCategory) *bun. } func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) db.Error { - if _, err := e.conn.NewInsert().Model(emoji).Exec(ctx); err != nil { + return e.emojiCache.Store(emoji, func() error { + _, err := e.conn.NewInsert().Model(emoji).Exec(ctx) return e.conn.ProcessError(err) - } - - e.emojiCache.Put(emoji) - return nil + }) } func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, columns ...string) (*gtsmodel.Emoji, db.Error) { @@ -72,7 +102,7 @@ func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, column return nil, e.conn.ProcessError(err) } - e.emojiCache.Invalidate(emoji.ID) + e.emojiCache.Invalidate("ID", emoji.ID) return emoji, nil } @@ -109,7 +139,7 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) db.Error { return err } - e.emojiCache.Invalidate(id) + e.emojiCache.Invalidate("ID", id) return nil } @@ -252,33 +282,29 @@ func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.E func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, db.Error) { return e.getEmoji( ctx, - func() (*gtsmodel.Emoji, bool) { - return e.emojiCache.GetByID(id) - }, + "ID", func(emoji *gtsmodel.Emoji) error { return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx) }, + id, ) } func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, db.Error) { return e.getEmoji( ctx, - func() (*gtsmodel.Emoji, bool) { - return e.emojiCache.GetByURI(uri) - }, + "URI", func(emoji *gtsmodel.Emoji) error { return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx) }, + uri, ) } func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, db.Error) { return e.getEmoji( ctx, - func() (*gtsmodel.Emoji, bool) { - return e.emojiCache.GetByShortcodeDomain(shortcode, domain) - }, + "Shortcode.Domain", func(emoji *gtsmodel.Emoji) error { q := e.newEmojiQ(emoji) @@ -292,31 +318,30 @@ func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode strin return q.Scan(ctx) }, + shortcode, + domain, ) } func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, db.Error) { return e.getEmoji( ctx, - func() (*gtsmodel.Emoji, bool) { - return e.emojiCache.GetByImageStaticURL(imageStaticURL) - }, + "ImageStaticURL", func(emoji *gtsmodel.Emoji) error { return e. newEmojiQ(emoji). Where("? = ?", bun.Ident("emoji.image_static_url"), imageStaticURL). Scan(ctx) }, + imageStaticURL, ) } func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) db.Error { - if _, err := e.conn.NewInsert().Model(emojiCategory).Exec(ctx); err != nil { + return e.categoryCache.Store(emojiCategory, func() error { + _, err := e.conn.NewInsert().Model(emojiCategory).Exec(ctx) return e.conn.ProcessError(err) - } - - e.categoryCache.Put(emojiCategory) - return nil + }) } func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, db.Error) { @@ -338,45 +363,36 @@ func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCate func (e *emojiDB) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, db.Error) { return e.getEmojiCategory( ctx, - func() (*gtsmodel.EmojiCategory, bool) { - return e.categoryCache.GetByID(id) - }, + "ID", func(emojiCategory *gtsmodel.EmojiCategory) error { return e.newEmojiCategoryQ(emojiCategory).Where("? = ?", bun.Ident("emoji_category.id"), id).Scan(ctx) }, + id, ) } func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, db.Error) { return e.getEmojiCategory( ctx, - func() (*gtsmodel.EmojiCategory, bool) { - return e.categoryCache.GetByName(name) - }, + "Name", func(emojiCategory *gtsmodel.EmojiCategory) error { return e.newEmojiCategoryQ(emojiCategory).Where("LOWER(?) = ?", bun.Ident("emoji_category.name"), strings.ToLower(name)).Scan(ctx) }, + name, ) } -func (e *emojiDB) getEmoji(ctx context.Context, cacheGet func() (*gtsmodel.Emoji, bool), dbQuery func(*gtsmodel.Emoji) error) (*gtsmodel.Emoji, db.Error) { - // Attempt to fetch cached emoji - emoji, cached := cacheGet() - - if !cached { - emoji = >smodel.Emoji{} +func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, db.Error) { + return e.emojiCache.Load(lookup, func() (*gtsmodel.Emoji, error) { + var emoji gtsmodel.Emoji // Not cached! Perform database query - err := dbQuery(emoji) - if err != nil { + if err := dbQuery(&emoji); err != nil { return nil, e.conn.ProcessError(err) } - // Place in the cache - e.emojiCache.Put(emoji) - } - - return emoji, nil + return &emoji, nil + }, keyParts...) } func (e *emojiDB) emojisFromIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, db.Error) { @@ -399,24 +415,17 @@ func (e *emojiDB) emojisFromIDs(ctx context.Context, emojiIDs []string) ([]*gtsm return emojis, nil } -func (e *emojiDB) getEmojiCategory(ctx context.Context, cacheGet func() (*gtsmodel.EmojiCategory, bool), dbQuery func(*gtsmodel.EmojiCategory) error) (*gtsmodel.EmojiCategory, db.Error) { - // Attempt to fetch cached emoji categories - emojiCategory, cached := cacheGet() - - if !cached { - emojiCategory = >smodel.EmojiCategory{} +func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, db.Error) { + return e.categoryCache.Load(lookup, func() (*gtsmodel.EmojiCategory, error) { + var category gtsmodel.EmojiCategory // Not cached! Perform database query - err := dbQuery(emojiCategory) - if err != nil { + if err := dbQuery(&category); err != nil { return nil, e.conn.ProcessError(err) } - // Place in the cache - e.categoryCache.Put(emojiCategory) - } - - return emojiCategory, nil + return &category, nil + }, keyParts...) } func (e *emojiDB) emojiCategoriesFromIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, db.Error) { diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go index 355078021..303e16484 100644 --- a/internal/db/bundb/mention.go +++ b/internal/db/bundb/mention.go @@ -20,8 +20,9 @@ package bundb import ( "context" + "time" - "codeberg.org/gruf/go-cache/v2" + "codeberg.org/gruf/go-cache/v3/result" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" @@ -30,7 +31,22 @@ import ( type mentionDB struct { conn *DBConn - cache cache.Cache[string, *gtsmodel.Mention] + cache *result.Cache[*gtsmodel.Mention] +} + +func (m *mentionDB) init() { + // Initialize notification result cache + m.cache = result.NewSized([]result.Lookup{ + {Name: "ID"}, + }, func(m1 *gtsmodel.Mention) *gtsmodel.Mention { + m2 := new(gtsmodel.Mention) + *m2 = *m1 + return m2 + }, 1000) + + // Set cache TTL and start sweep routine + m.cache.SetTTL(time.Minute*5, false) + m.cache.Start(time.Second * 10) } func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery { @@ -42,27 +58,19 @@ func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery { Relation("TargetAccount") } -func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { - mention := gtsmodel.Mention{} - - q := m.newMentionQ(&mention). - Where("? = ?", bun.Ident("mention.id"), id) - - if err := q.Scan(ctx); err != nil { - return nil, m.conn.ProcessError(err) - } +func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { + return m.cache.Load("ID", func() (*gtsmodel.Mention, error) { + var mention gtsmodel.Mention - copy := mention - m.cache.Set(mention.ID, ©) + q := m.newMentionQ(&mention). + Where("? = ?", bun.Ident("mention.id"), id) - return &mention, nil -} + if err := q.Scan(ctx); err != nil { + return nil, m.conn.ProcessError(err) + } -func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { - if mention, ok := m.cache.Get(id); ok { - return mention, nil - } - return m.getMentionDB(ctx, id) + return &mention, nil + }, id) } func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) { diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index 69e3cf39f..1874f81ea 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -20,8 +20,9 @@ package bundb import ( "context" + "time" - "codeberg.org/gruf/go-cache/v2" + "codeberg.org/gruf/go-cache/v3/result" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" @@ -30,31 +31,40 @@ import ( type notificationDB struct { conn *DBConn - cache cache.Cache[string, *gtsmodel.Notification] + cache *result.Cache[*gtsmodel.Notification] } -func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) { - if notification, ok := n.cache.Get(id); ok { - return notification, nil - } - - dst := gtsmodel.Notification{ID: id} - - q := n.conn.NewSelect(). - Model(&dst). - Relation("OriginAccount"). - Relation("TargetAccount"). - Relation("Status"). - Where("? = ?", bun.Ident("notification.id"), id) - - if err := q.Scan(ctx); err != nil { - return nil, n.conn.ProcessError(err) - } +func (n *notificationDB) init() { + // Initialize notification result cache + n.cache = result.NewSized([]result.Lookup{ + {Name: "ID"}, + }, func(n1 *gtsmodel.Notification) *gtsmodel.Notification { + n2 := new(gtsmodel.Notification) + *n2 = *n1 + return n2 + }, 1000) + + // Set cache TTL and start sweep routine + n.cache.SetTTL(time.Minute*5, false) + n.cache.Start(time.Second * 10) +} - copy := dst - n.cache.Set(id, ©) +func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) { + return n.cache.Load("ID", func() (*gtsmodel.Notification, error) { + var notif gtsmodel.Notification + + q := n.conn.NewSelect(). + Model(¬if). + Relation("OriginAccount"). + Relation("TargetAccount"). + Relation("Status"). + Where("? = ?", bun.Ident("notification.id"), id) + if err := q.Scan(ctx); err != nil { + return nil, n.conn.ProcessError(err) + } - return &dst, nil + return ¬if, nil + }, id) } func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) { diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index bc72c2849..b4ae40607 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -25,7 +25,7 @@ import ( "errors" "time" - "github.com/superseriousbusiness/gotosocial/internal/cache" + "codeberg.org/gruf/go-cache/v3/result" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" @@ -33,15 +33,28 @@ import ( ) type statusDB struct { - conn *DBConn - cache *cache.StatusCache - - // TODO: keep method definitions in same place but instead have receiver - // all point to one single "db" type, so they can all share methods - // and caches where necessary + conn *DBConn + cache *result.Cache[*gtsmodel.Status] accounts *accountDB } +func (s *statusDB) init() { + // Initialize status result cache + s.cache = result.NewSized([]result.Lookup{ + {Name: "ID"}, + {Name: "URI"}, + {Name: "URL"}, + }, func(s1 *gtsmodel.Status) *gtsmodel.Status { + s2 := new(gtsmodel.Status) + *s2 = *s1 + return s2 + }, 1000) + + // Set cache TTL and start sweep routine + s.cache.SetTTL(time.Minute*5, false) + s.cache.Start(time.Second * 10) +} + func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { return s.conn. NewSelect(). @@ -68,61 +81,62 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery { func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) { return s.getStatus( ctx, - func() (*gtsmodel.Status, bool) { - return s.cache.GetByID(id) - }, + "ID", func(status *gtsmodel.Status) error { return s.newStatusQ(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx) }, + id, ) } func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { return s.getStatus( ctx, - func() (*gtsmodel.Status, bool) { - return s.cache.GetByURI(uri) - }, + "URI", func(status *gtsmodel.Status) error { return s.newStatusQ(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx) }, + uri, ) } func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) { return s.getStatus( ctx, - func() (*gtsmodel.Status, bool) { - return s.cache.GetByURL(url) - }, + "URL", func(status *gtsmodel.Status) error { return s.newStatusQ(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx) }, + url, ) } -func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Status, bool), dbQuery func(*gtsmodel.Status) error) (*gtsmodel.Status, db.Error) { - // Attempt to fetch cached status - status, cached := cacheGet() - - if !cached { - status = >smodel.Status{} +func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, db.Error) { + // Fetch status from database cache with loader callback + status, err := s.cache.Load(lookup, func() (*gtsmodel.Status, error) { + var status gtsmodel.Status // Not cached! Perform database query - if err := dbQuery(status); err != nil { + if err := dbQuery(&status); err != nil { return nil, s.conn.ProcessError(err) } // If there is boosted, fetch from DB also if status.BoostOfID != "" { - boostOf, err := s.GetStatusByID(ctx, status.BoostOfID) - if err == nil { - status.BoostOf = boostOf + status.BoostOf = >smodel.Status{} + err := s.newStatusQ(status.BoostOf). + Where("? = ?", bun.Ident("status.id"), status.BoostOfID). + Scan(ctx) + if err != nil { + return nil, s.conn.ProcessError(err) } } - // Place in the cache - s.cache.Put(status) + return &status, nil + }, keyParts...) + if err != nil { + // error already processed + return nil, err } // Set the status author account @@ -137,73 +151,66 @@ func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Sta } func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { - err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { - // create links between this status and any emojis it uses - for _, i := range status.EmojiIDs { - if _, err := tx. - NewInsert(). - Model(>smodel.StatusToEmoji{ - StatusID: status.ID, - EmojiID: i, - }).Exec(ctx); err != nil { - err = s.conn.errProc(err) - if !errors.Is(err, db.ErrAlreadyExists) { - return err + return s.cache.Store(status, func() error { + // It is safe to run this database transaction within cache.Store + // as the cache does not attempt a mutex lock until AFTER hook. + // + return s.conn.RunInTx(ctx, func(tx bun.Tx) error { + // create links between this status and any emojis it uses + for _, i := range status.EmojiIDs { + if _, err := tx. + NewInsert(). + Model(>smodel.StatusToEmoji{ + StatusID: status.ID, + EmojiID: i, + }).Exec(ctx); err != nil { + err = s.conn.ProcessError(err) + if !errors.Is(err, db.ErrAlreadyExists) { + return err + } } } - } - // create links between this status and any tags it uses - for _, i := range status.TagIDs { - if _, err := tx. - NewInsert(). - Model(>smodel.StatusToTag{ - StatusID: status.ID, - TagID: i, - }).Exec(ctx); err != nil { - err = s.conn.errProc(err) - if !errors.Is(err, db.ErrAlreadyExists) { - return err + // create links between this status and any tags it uses + for _, i := range status.TagIDs { + if _, err := tx. + NewInsert(). + Model(>smodel.StatusToTag{ + StatusID: status.ID, + TagID: i, + }).Exec(ctx); err != nil { + err = s.conn.ProcessError(err) + if !errors.Is(err, db.ErrAlreadyExists) { + return err + } } } - } - // change the status ID of the media attachments to the new status - for _, a := range status.Attachments { - a.StatusID = status.ID - a.UpdatedAt = time.Now() - if _, err := tx. - NewUpdate(). - Model(a). - Where("? = ?", bun.Ident("media_attachment.id"), a.ID). - Exec(ctx); err != nil { - err = s.conn.errProc(err) - if !errors.Is(err, db.ErrAlreadyExists) { - return err + // change the status ID of the media attachments to the new status + for _, a := range status.Attachments { + a.StatusID = status.ID + a.UpdatedAt = time.Now() + if _, err := tx. + NewUpdate(). + Model(a). + Where("? = ?", bun.Ident("media_attachment.id"), a.ID). + Exec(ctx); err != nil { + err = s.conn.ProcessError(err) + if !errors.Is(err, db.ErrAlreadyExists) { + return err + } } } - } - // Finally, insert the status - if _, err := tx. - NewInsert(). - Model(status). - Exec(ctx); err != nil { + // Finally, insert the status + _, err := tx.NewInsert().Model(status).Exec(ctx) return err - } - - return nil + }) }) - if err != nil { - return s.conn.ProcessError(err) - } - - s.cache.Put(status) - return nil } -func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*gtsmodel.Status, db.Error) { - err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { +func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) db.Error { + if err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { // create links between this status and any emojis it uses for _, i := range status.EmojiIDs { if _, err := tx. @@ -212,7 +219,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (* StatusID: status.ID, EmojiID: i, }).Exec(ctx); err != nil { - err = s.conn.errProc(err) + err = s.conn.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } @@ -227,14 +234,14 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (* StatusID: status.ID, TagID: i, }).Exec(ctx); err != nil { - err = s.conn.errProc(err) + err = s.conn.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } } } - // change the status ID of the media attachments to this status + // change the status ID of the media attachments to the new status for _, a := range status.Attachments { a.StatusID = status.ID a.UpdatedAt = time.Now() @@ -243,31 +250,31 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (* Model(a). Where("? = ?", bun.Ident("media_attachment.id"), a.ID). Exec(ctx); err != nil { - return err + err = s.conn.ProcessError(err) + if !errors.Is(err, db.ErrAlreadyExists) { + return err + } } } - // Finally, update the status itself - if _, err := tx. + // Finally, insert the status + _, err := tx. NewUpdate(). Model(status). Where("? = ?", bun.Ident("status.id"), status.ID). - Exec(ctx); err != nil { - return err - } - - return nil - }) - if err != nil { - return nil, s.conn.ProcessError(err) + Exec(ctx) + return err + }); err != nil { + return err } - s.cache.Put(status) - return status, nil + // Drop any old value from cache by this ID + s.cache.Invalidate("ID", status.ID) + return nil } func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { - err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { + if err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { // delete links between this status and any emojis it uses if _, err := tx. NewDelete(). @@ -296,36 +303,41 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { } return nil - }) - if err != nil { - return s.conn.ProcessError(err) + }); err != nil { + return err } - s.cache.Invalidate(id) + // Drop any old value from cache by this ID + s.cache.Invalidate("ID", id) return nil } func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { - parents := []*gtsmodel.Status{} - s.statusParent(ctx, status, &parents, onlyDirect) - return parents, nil -} - -func (s *statusDB) statusParent(ctx context.Context, status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) { - if status.InReplyToID == "" { - return + if onlyDirect { + // Only want the direct parent, no further than first level + parent, err := s.GetStatusByID(ctx, status.InReplyToID) + if err != nil { + return nil, err + } + return []*gtsmodel.Status{parent}, nil } - parentStatus, err := s.GetStatusByID(ctx, status.InReplyToID) - if err == nil { - *foundStatuses = append(*foundStatuses, parentStatus) - } + var parents []*gtsmodel.Status - if onlyDirect { - return + for id := status.InReplyToID; id != ""; { + parent, err := s.GetStatusByID(ctx, id) + if err != nil { + return nil, err + } + + // Append parent to slice + parents = append(parents, parent) + + // Set the next parent ID + id = parent.InReplyToID } - s.statusParent(ctx, parentStatus, foundStatuses, false) + return parents, nil } func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) { @@ -350,7 +362,7 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu } func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { - childIDs := []string{} + var childIDs []string q := s.conn. NewSelect(). @@ -471,6 +483,7 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) if err := q.Scan(ctx); err != nil { return nil, s.conn.ProcessError(err) } + return faves, nil } diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go index 9b6365621..066f55234 100644 --- a/internal/db/bundb/timeline_test.go +++ b/internal/db/bundb/timeline_test.go @@ -35,44 +35,52 @@ type TimelineTestSuite struct { } func (suite *TimelineTestSuite) TestGetPublicTimeline() { - s, err := suite.db.GetPublicTimeline(context.Background(), "", "", "", 20, false) + ctx := context.Background() + + s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false) suite.NoError(err) suite.Len(s, 6) } func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() { + ctx := context.Background() + futureStatus := getFutureStatus() - if err := suite.db.Put(context.Background(), futureStatus); err != nil { - suite.FailNow(err.Error()) - } + err := suite.db.PutStatus(ctx, futureStatus) + suite.NoError(err) - s, err := suite.db.GetPublicTimeline(context.Background(), "", "", "", 20, false) + s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false) suite.NoError(err) + suite.NotContains(s, futureStatus) suite.Len(s, 6) } func (suite *TimelineTestSuite) TestGetHomeTimeline() { + ctx := context.Background() + viewingAccount := suite.testAccounts["local_account_1"] - s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false) + s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false) suite.NoError(err) suite.Len(s, 16) } func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() { + ctx := context.Background() + viewingAccount := suite.testAccounts["local_account_1"] futureStatus := getFutureStatus() - if err := suite.db.Put(context.Background(), futureStatus); err != nil { - suite.FailNow(err.Error()) - } + err := suite.db.PutStatus(ctx, futureStatus) + suite.NoError(err) s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false) suite.NoError(err) + suite.NotContains(s, futureStatus) suite.Len(s, 16) } diff --git a/internal/db/bundb/tombstone.go b/internal/db/bundb/tombstone.go index 7ce3327a7..309a39fd3 100644 --- a/internal/db/bundb/tombstone.go +++ b/internal/db/bundb/tombstone.go @@ -43,7 +43,7 @@ func (t *tombstoneDB) init() { t2 := new(gtsmodel.Tombstone) *t2 = *t1 return t2 - }, 1000) + }, 100) // Set cache TTL and start sweep routine t.cache.SetTTL(time.Minute*5, false) diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go index aa2f4c2c8..d9b281a6f 100644 --- a/internal/db/bundb/user.go +++ b/internal/db/bundb/user.go @@ -22,7 +22,7 @@ import ( "context" "time" - "github.com/superseriousbusiness/gotosocial/internal/cache" + "codeberg.org/gruf/go-cache/v3/result" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/uptrace/bun" @@ -30,111 +30,121 @@ import ( type userDB struct { conn *DBConn - cache *cache.UserCache + cache *result.Cache[*gtsmodel.User] } -func (u *userDB) newUserQ(user *gtsmodel.User) *bun.SelectQuery { - return u.conn. - NewSelect(). - Model(user). - Relation("Account") +func (u *userDB) init() { + // Initialize user result cache + u.cache = result.NewSized([]result.Lookup{ + {Name: "ID"}, + {Name: "AccountID"}, + {Name: "Email"}, + {Name: "ConfirmationToken"}, + }, func(u1 *gtsmodel.User) *gtsmodel.User { + u2 := new(gtsmodel.User) + *u2 = *u1 + return u2 + }, 1000) + + // Set cache TTL and start sweep routine + u.cache.SetTTL(time.Minute*5, false) + u.cache.Start(time.Second * 10) } -func (u *userDB) getUser(ctx context.Context, cacheGet func() (*gtsmodel.User, bool), dbQuery func(*gtsmodel.User) error) (*gtsmodel.User, db.Error) { - // Attempt to fetch cached user - user, cached := cacheGet() +func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) { + return u.cache.Load("ID", func() (*gtsmodel.User, error) { + var user gtsmodel.User - if !cached { - user = >smodel.User{} + q := u.conn. + NewSelect(). + Model(&user). + Relation("Account"). + Where("? = ?", bun.Ident("user.id"), id) - // Not cached! Perform database query - err := dbQuery(user) - if err != nil { + if err := q.Scan(ctx); err != nil { return nil, u.conn.ProcessError(err) } - // Place in the cache - u.cache.Put(user) - } - - return user, nil -} - -func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) { - return u.getUser( - ctx, - func() (*gtsmodel.User, bool) { - return u.cache.GetByID(id) - }, - func(user *gtsmodel.User) error { - return u.newUserQ(user).Where("? = ?", bun.Ident("user.id"), id).Scan(ctx) - }, - ) + return &user, nil + }, id) } func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, db.Error) { - return u.getUser( - ctx, - func() (*gtsmodel.User, bool) { - return u.cache.GetByAccountID(accountID) - }, - func(user *gtsmodel.User) error { - return u.newUserQ(user).Where("? = ?", bun.Ident("user.account_id"), accountID).Scan(ctx) - }, - ) + return u.cache.Load("AccountID", func() (*gtsmodel.User, error) { + var user gtsmodel.User + + q := u.conn. + NewSelect(). + Model(&user). + Relation("Account"). + Where("? = ?", bun.Ident("user.account_id"), accountID) + + if err := q.Scan(ctx); err != nil { + return nil, u.conn.ProcessError(err) + } + + return &user, nil + }, accountID) } func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, db.Error) { - return u.getUser( - ctx, - func() (*gtsmodel.User, bool) { - return u.cache.GetByEmail(emailAddress) - }, - func(user *gtsmodel.User) error { - return u.newUserQ(user).Where("? = ?", bun.Ident("user.email"), emailAddress).Scan(ctx) - }, - ) + return u.cache.Load("Email", func() (*gtsmodel.User, error) { + var user gtsmodel.User + + q := u.conn. + NewSelect(). + Model(&user). + Relation("Account"). + Where("? = ?", bun.Ident("user.email"), emailAddress) + + if err := q.Scan(ctx); err != nil { + return nil, u.conn.ProcessError(err) + } + + return &user, nil + }, emailAddress) } func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, db.Error) { - return u.getUser( - ctx, - func() (*gtsmodel.User, bool) { - return u.cache.GetByConfirmationToken(confirmationToken) - }, - func(user *gtsmodel.User) error { - return u.newUserQ(user).Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken).Scan(ctx) - }, - ) -} + return u.cache.Load("ConfirmationToken", func() (*gtsmodel.User, error) { + var user gtsmodel.User -func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) (*gtsmodel.User, db.Error) { - if _, err := u.conn. - NewInsert(). - Model(user). - Exec(ctx); err != nil { - return nil, u.conn.ProcessError(err) - } + q := u.conn. + NewSelect(). + Model(&user). + Relation("Account"). + Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken) - u.cache.Put(user) - return user, nil + if err := q.Scan(ctx); err != nil { + return nil, u.conn.ProcessError(err) + } + + return &user, nil + }, confirmationToken) } -func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, db.Error) { +func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) db.Error { + return u.cache.Store(user, func() error { + _, err := u.conn. + NewInsert(). + Model(user). + Exec(ctx) + return u.conn.ProcessError(err) + }) +} + +func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User) db.Error { // Update the user's last-updated user.UpdatedAt = time.Now() - if _, err := u.conn. - NewUpdate(). - Model(user). - Where("? = ?", bun.Ident("user.id"), user.ID). - Column(columns...). - Exec(ctx); err != nil { - return nil, u.conn.ProcessError(err) - } - - u.cache.Invalidate(user.ID) - return user, nil + return u.cache.Store(user, func() error { + _, err := u.conn. + NewUpdate(). + Model(user). + Where("? = ?", bun.Ident("user.id"), user.ID). + Exec(ctx) + return u.conn.ProcessError(err) + }) } func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error { @@ -146,6 +156,7 @@ func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error { return u.conn.ProcessError(err) } - u.cache.Invalidate(userID) + // Invalidate user from cache + u.cache.Invalidate("ID", userID) return nil } diff --git a/internal/db/bundb/user_test.go b/internal/db/bundb/user_test.go index 6ad59fc8e..18f67dde5 100644 --- a/internal/db/bundb/user_test.go +++ b/internal/db/bundb/user_test.go @@ -50,21 +50,20 @@ func (suite *UserTestSuite) TestGetUserByAccountID() { func (suite *UserTestSuite) TestUpdateUserSelectedColumns() { testUser := suite.testUsers["local_account_1"] - user := >smodel.User{ - ID: testUser.ID, - Email: "whatever", - Locale: "es", - } - user, err := suite.db.UpdateUser(context.Background(), user, "email", "locale") + updateUser := new(gtsmodel.User) + *updateUser = *testUser + updateUser.Email = "whatever" + updateUser.Locale = "es" + + err := suite.db.UpdateUser(context.Background(), updateUser) suite.NoError(err) - suite.NotNil(user) dbUser, err := suite.db.GetUserByID(context.Background(), testUser.ID) suite.NoError(err) suite.NotNil(dbUser) - suite.Equal("whatever", dbUser.Email) - suite.Equal("es", dbUser.Locale) + suite.Equal(updateUser.Email, dbUser.Email) + suite.Equal(updateUser.Locale, dbUser.Locale) suite.Equal(testUser.AccountID, dbUser.AccountID) } diff --git a/internal/db/status.go b/internal/db/status.go index 55cec5beb..d0983122b 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -39,7 +39,7 @@ type Status interface { PutStatus(ctx context.Context, status *gtsmodel.Status) Error // UpdateStatus updates one status in the database and returns it to the caller. - UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*gtsmodel.Status, Error) + UpdateStatus(ctx context.Context, status *gtsmodel.Status) Error // DeleteStatusByID deletes one status from the database. DeleteStatusByID(ctx context.Context, id string) Error diff --git a/internal/db/user.go b/internal/db/user.go index a4d48db56..d01a8862a 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -34,9 +34,10 @@ type User interface { GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, Error) // GetUserByConfirmationToken returns one user by its confirmation token, or an error if something goes wrong. GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, Error) - // UpdateUser updates one user by its primary key. If columns is set, only given columns - // will be updated. If not set, all columns will be updated. - UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, Error) + // PutUser will attempt to place user in the database + PutUser(ctx context.Context, user *gtsmodel.User) Error + // UpdateUser updates one user by its primary key. + UpdateUser(ctx context.Context, user *gtsmodel.User) Error // DeleteUserByID deletes one user by its ID. DeleteUserByID(ctx context.Context, userID string) Error } |