summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
authorLibravatar kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com>2022-11-15 18:45:15 +0000
committerLibravatar GitHub <noreply@github.com>2022-11-15 18:45:15 +0000
commit8598dea98b872647393117704659878d9b38d4fc (patch)
tree1940168912dc7f54af723439dbc9f6e0a42f30ae /internal/db
parent[docs] Both HTTP proxies and NAT can cause rate limiting issues (#1053) (diff)
downloadgotosocial-8598dea98b872647393117704659878d9b38d4fc.tar.xz
[chore] update database caching library (#1040)
* convert most of the caches to use result.Cache{} * add caching of emojis * fix issues causing failing tests * update go-cache/v2 instances with v3 * fix getnotification * add a note about the left-in StatusCreate comment * update EmojiCategory db access to use new result.Cache{} * fix possible panic in getstatusparents * further proof that kim is not stinky
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/account.go4
-rw-r--r--internal/db/bundb/account.go214
-rw-r--r--internal/db/bundb/account_test.go4
-rw-r--r--internal/db/bundb/admin.go33
-rw-r--r--internal/db/bundb/admin_test.go2
-rw-r--r--internal/db/bundb/bundb.go76
-rw-r--r--internal/db/bundb/domain.go94
-rw-r--r--internal/db/bundb/emoji.go127
-rw-r--r--internal/db/bundb/mention.go48
-rw-r--r--internal/db/bundb/notification.go54
-rw-r--r--internal/db/bundb/status.go251
-rw-r--r--internal/db/bundb/timeline_test.go26
-rw-r--r--internal/db/bundb/tombstone.go2
-rw-r--r--internal/db/bundb/user.go175
-rw-r--r--internal/db/bundb/user_test.go17
-rw-r--r--internal/db/status.go2
-rw-r--r--internal/db/user.go7
17 files changed, 578 insertions, 558 deletions
diff --git a/internal/db/account.go b/internal/db/account.go
index a58aa9dd3..7e7d1de43 100644
--- a/internal/db/account.go
+++ b/internal/db/account.go
@@ -43,10 +43,10 @@ type Account interface {
GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, Error)
// PutAccount puts one account in the database.
- PutAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
+ PutAccount(ctx context.Context, account *gtsmodel.Account) Error
// UpdateAccount updates one account by ID.
- UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
+ UpdateAccount(ctx context.Context, account *gtsmodel.Account) Error
// DeleteAccount deletes one account from the database by its ID.
// DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index 4813f4e17..1e9c390d8 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -24,7 +24,7 @@ import (
"strings"
"time"
- "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "codeberg.org/gruf/go-cache/v3/result"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -35,10 +35,29 @@ import (
type accountDB struct {
conn *DBConn
- cache *cache.AccountCache
+ cache *result.Cache[*gtsmodel.Account]
status *statusDB
}
+func (a *accountDB) init() {
+ // Initialize account result cache
+ a.cache = result.NewSized([]result.Lookup{
+ {Name: "ID"},
+ {Name: "URI"},
+ {Name: "URL"},
+ {Name: "Username.Domain"},
+ {Name: "PublicKeyURI"},
+ }, func(a1 *gtsmodel.Account) *gtsmodel.Account {
+ a2 := new(gtsmodel.Account)
+ *a2 = *a1
+ return a2
+ }, 1000)
+
+ // Set cache TTL and start sweep routine
+ a.cache.SetTTL(time.Minute*5, false)
+ a.cache.Start(time.Second * 10)
+}
+
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
return a.conn.
NewSelect().
@@ -51,45 +70,41 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
- func() (*gtsmodel.Account, bool) {
- return a.cache.GetByID(id)
- },
+ "ID",
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx)
},
+ id,
)
}
func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
- func() (*gtsmodel.Account, bool) {
- return a.cache.GetByURI(uri)
- },
+ "URI",
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx)
},
+ uri,
)
}
func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
- func() (*gtsmodel.Account, bool) {
- return a.cache.GetByURL(url)
- },
+ "URL",
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx)
},
+ url,
)
}
func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) {
+ username = strings.ToLower(username)
return a.getAccount(
ctx,
- func() (*gtsmodel.Account, bool) {
- return a.cache.GetByUsernameDomain(username, domain)
- },
+ "Username.Domain",
func(account *gtsmodel.Account) error {
q := a.newAccountQ(account)
@@ -97,113 +112,117 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
q = q.Where("? = ?", bun.Ident("account.username"), username)
q = q.Where("? = ?", bun.Ident("account.domain"), domain)
} else {
- q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username))
+ q = q.Where("? = ?", bun.Ident("account.username"), username)
q = q.Where("? IS NULL", bun.Ident("account.domain"))
}
return q.Scan(ctx)
},
+ username,
+ domain,
)
}
func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
- func() (*gtsmodel.Account, bool) {
- return a.cache.GetByPubkeyID(id)
- },
+ "PublicKeyURI",
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx)
},
+ id,
)
}
-func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) {
- // Attempt to fetch cached account
- account, cached := cacheGet()
+func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
+ var username string
- if !cached {
- account = &gtsmodel.Account{}
+ if domain == "" {
+ // I.e. our local instance account
+ username = config.GetHost()
+ } else {
+ // A remote instance account
+ username = domain
+ }
+
+ return a.GetAccountByUsernameDomain(ctx, username, domain)
+}
+
+func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, db.Error) {
+ return a.cache.Load(lookup, func() (*gtsmodel.Account, error) {
+ var account gtsmodel.Account
// Not cached! Perform database query
- err := dbQuery(account)
- if err != nil {
+ if err := dbQuery(&account); err != nil {
return nil, a.conn.ProcessError(err)
}
- // Place in the cache
- a.cache.Put(account)
- }
-
- return account, nil
+ return &account, nil
+ }, keyParts...)
}
-func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
- if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error {
- // create links between this account and any emojis it uses
- for _, i := range account.EmojiIDs {
- if _, err := tx.NewInsert().Model(&gtsmodel.AccountToEmoji{
- AccountID: account.ID,
- EmojiID: i,
- }).Exec(ctx); err != nil {
- return err
+func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error {
+ return a.cache.Store(account, func() error {
+ // It is safe to run this database transaction within cache.Store
+ // as the cache does not attempt a mutex lock until AFTER hook.
+ //
+ return a.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ // create links between this account and any emojis it uses
+ for _, i := range account.EmojiIDs {
+ if _, err := tx.NewInsert().Model(&gtsmodel.AccountToEmoji{
+ AccountID: account.ID,
+ EmojiID: i,
+ }).Exec(ctx); err != nil {
+ return err
+ }
}
- }
- // insert the account
- _, err := tx.NewInsert().Model(account).Exec(ctx)
- return err
- }); err != nil {
- return nil, a.conn.ProcessError(err)
- }
-
- a.cache.Put(account)
- return account, nil
+ // insert the account
+ _, err := tx.NewInsert().Model(account).Exec(ctx)
+ return err
+ })
+ })
}
-func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) db.Error {
// Update the account's last-updated
account.UpdatedAt = time.Now()
- if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error {
- // create links between this account and any emojis it uses
- // first clear out any old emoji links
- if _, err := tx.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")).
- Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID).
- Exec(ctx); err != nil {
- return err
- }
-
- // now populate new emoji links
- for _, i := range account.EmojiIDs {
+ return a.cache.Store(account, func() error {
+ // It is safe to run this database transaction within cache.Store
+ // as the cache does not attempt a mutex lock until AFTER hook.
+ //
+ return a.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ // create links between this account and any emojis it uses
+ // first clear out any old emoji links
if _, err := tx.
- NewInsert().
- Model(&gtsmodel.AccountToEmoji{
- AccountID: account.ID,
- EmojiID: i,
- }).Exec(ctx); err != nil {
+ NewDelete().
+ TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")).
+ Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID).
+ Exec(ctx); err != nil {
return err
}
- }
- // update the account
- if _, err := tx.
- NewUpdate().
- Model(account).
- Where("? = ?", bun.Ident("account.id"), account.ID).
- Exec(ctx); err != nil {
- return err
- }
-
- return nil
- }); err != nil {
- return nil, a.conn.ProcessError(err)
- }
+ // now populate new emoji links
+ for _, i := range account.EmojiIDs {
+ if _, err := tx.
+ NewInsert().
+ Model(&gtsmodel.AccountToEmoji{
+ AccountID: account.ID,
+ EmojiID: i,
+ }).Exec(ctx); err != nil {
+ return err
+ }
+ }
- a.cache.Put(account)
- return account, nil
+ // update the account
+ _, err := tx.NewUpdate().
+ Model(account).
+ Where("? = ?", bun.Ident("account.id"), account.ID).
+ Exec(ctx)
+ return err
+ })
+ })
}
func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
@@ -219,40 +238,19 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
// delete the account
_, err := tx.
- NewUpdate().
+ NewDelete().
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
Where("? = ?", bun.Ident("account.id"), id).
Exec(ctx)
return err
}); err != nil {
- return a.conn.ProcessError(err)
+ return err
}
- a.cache.Invalidate(id)
+ a.cache.Invalidate("ID", id)
return nil
}
-func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
- account := new(gtsmodel.Account)
-
- q := a.newAccountQ(account)
-
- if domain != "" {
- q = q.
- Where("? = ?", bun.Ident("account.username"), domain).
- Where("? = ?", bun.Ident("account.domain"), domain)
- } else {
- q = q.
- Where("? = ?", bun.Ident("account.username"), config.GetHost()).
- WhereGroup(" AND ", whereEmptyOrNull("domain"))
- }
-
- if err := q.Scan(ctx); err != nil {
- return nil, a.conn.ProcessError(err)
- }
- return account, nil
-}
-
func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, db.Error) {
createdAt := time.Time{}
diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go
index 29594a740..50603623f 100644
--- a/internal/db/bundb/account_test.go
+++ b/internal/db/bundb/account_test.go
@@ -92,7 +92,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {
testAccount.DisplayName = "new display name!"
testAccount.EmojiIDs = []string{"01GD36ZKWTKY3T1JJ24JR7KY1Q", "01GD36ZV904SHBHNAYV6DX5QEF"}
- _, err := suite.db.UpdateAccount(ctx, testAccount)
+ err := suite.db.UpdateAccount(ctx, testAccount)
suite.NoError(err)
updated, err := suite.db.GetAccountByID(ctx, testAccount.ID)
@@ -127,7 +127,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {
// update again to remove emoji associations
testAccount.EmojiIDs = []string{}
- _, err = suite.db.UpdateAccount(ctx, testAccount)
+ err = suite.db.UpdateAccount(ctx, testAccount)
suite.NoError(err)
updated, err = suite.db.GetAccountByID(ctx, testAccount.ID)
diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go
index 44861a4bb..4d750581c 100644
--- a/internal/db/bundb/admin.go
+++ b/internal/db/bundb/admin.go
@@ -29,7 +29,6 @@ import (
"time"
"github.com/superseriousbusiness/gotosocial/internal/ap"
- "github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -44,9 +43,9 @@ import (
const rsaKeyBits = 2048
type adminDB struct {
- conn *DBConn
- userCache *cache.UserCache
- accountCache *cache.AccountCache
+ conn *DBConn
+ accounts *accountDB
+ users *userDB
}
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
@@ -140,13 +139,9 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
}
// insert the new account!
- if _, err = a.conn.
- NewInsert().
- Model(acct).
- Exec(ctx); err != nil {
- return nil, a.conn.ProcessError(err)
+ if err := a.accounts.PutAccount(ctx, acct); err != nil {
+ return nil, err
}
- a.accountCache.Put(acct)
}
// we either created or already had an account by now,
@@ -190,13 +185,9 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
}
// insert the user!
- if _, err = a.conn.
- NewInsert().
- Model(u).
- Exec(ctx); err != nil {
- return nil, a.conn.ProcessError(err)
+ if err := a.users.PutUser(ctx, u); err != nil {
+ return nil, err
}
- a.userCache.Put(u)
return u, nil
}
@@ -249,15 +240,11 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
- insertQ := a.conn.
- NewInsert().
- Model(acct)
-
- if _, err := insertQ.Exec(ctx); err != nil {
- return a.conn.ProcessError(err)
+ // insert the new account!
+ if err := a.accounts.PutAccount(ctx, acct); err != nil {
+ return err
}
- a.accountCache.Put(acct)
log.Infof("instance account %s CREATED with id %s", username, acct.ID)
return nil
}
diff --git a/internal/db/bundb/admin_test.go b/internal/db/bundb/admin_test.go
index f0a869a9b..18e1f67e2 100644
--- a/internal/db/bundb/admin_test.go
+++ b/internal/db/bundb/admin_test.go
@@ -70,6 +70,8 @@ func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() {
}
func (suite *AdminTestSuite) TestCreateInstanceAccount() {
+ // reinitialize test DB to clear caches
+ suite.db = testrig.NewTestDB()
// we need to take an empty db for this...
testrig.StandardDBTeardown(suite.db)
// ...with tables created but no data
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index cf6643f6b..de6749ca4 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -34,7 +34,6 @@ import (
"github.com/google/uuid"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib"
- "github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations"
@@ -46,7 +45,6 @@ import (
"github.com/uptrace/bun/dialect/sqlitedialect"
"github.com/uptrace/bun/migrate"
- grufcache "codeberg.org/gruf/go-cache/v2"
"modernc.org/sqlite"
)
@@ -160,79 +158,63 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
return nil, fmt.Errorf("db migration error: %s", err)
}
- // Prepare caches required by more than one struct
- userCache := cache.NewUserCache()
- accountCache := cache.NewAccountCache()
-
- // Prepare other caches
- // Prepare mentions cache
- // TODO: move into internal/cache
- mentionCache := grufcache.New[string, *gtsmodel.Mention]()
- mentionCache.SetTTL(time.Minute*5, false)
- mentionCache.Start(time.Second * 10)
-
- // Prepare notifications cache
- // TODO: move into internal/cache
- notifCache := grufcache.New[string, *gtsmodel.Notification]()
- notifCache.SetTTL(time.Minute*5, false)
- notifCache.Start(time.Second * 10)
-
// Create DB structs that require ptrs to each other
- accounts := &accountDB{conn: conn, cache: accountCache}
- status := &statusDB{conn: conn, cache: cache.NewStatusCache()}
- emoji := &emojiDB{conn: conn, emojiCache: cache.NewEmojiCache(), categoryCache: cache.NewEmojiCategoryCache()}
+ account := &accountDB{conn: conn}
+ admin := &adminDB{conn: conn}
+ domain := &domainDB{conn: conn}
+ mention := &mentionDB{conn: conn}
+ notif := &notificationDB{conn: conn}
+ status := &statusDB{conn: conn}
+ emoji := &emojiDB{conn: conn}
timeline := &timelineDB{conn: conn}
tombstone := &tombstoneDB{conn: conn}
+ user := &userDB{conn: conn}
// Setup DB cross-referencing
- accounts.status = status
- status.accounts = accounts
+ account.status = status
+ admin.users = user
+ status.accounts = account
timeline.status = status
// Initialize db structs
+ account.init()
+ domain.init()
+ emoji.init()
+ mention.init()
+ notif.init()
+ status.init()
tombstone.init()
+ user.init()
ps := &DBService{
- Account: accounts,
+ Account: account,
Admin: &adminDB{
- conn: conn,
- userCache: userCache,
- accountCache: accountCache,
+ conn: conn,
+ accounts: account,
+ users: user,
},
Basic: &basicDB{
conn: conn,
},
- Domain: &domainDB{
- conn: conn,
- cache: cache.NewDomainBlockCache(),
- },
- Emoji: emoji,
+ Domain: domain,
+ Emoji: emoji,
Instance: &instanceDB{
conn: conn,
},
Media: &mediaDB{
conn: conn,
},
- Mention: &mentionDB{
- conn: conn,
- cache: mentionCache,
- },
- Notification: &notificationDB{
- conn: conn,
- cache: notifCache,
- },
+ Mention: mention,
+ Notification: notif,
Relationship: &relationshipDB{
conn: conn,
},
Session: &sessionDB{
conn: conn,
},
- Status: status,
- Timeline: timeline,
- User: &userDB{
- conn: conn,
- cache: userCache,
- },
+ Status: status,
+ Timeline: timeline,
+ User: user,
Tombstone: tombstone,
conn: conn,
}
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go
index 0a752d3f3..3fca8501b 100644
--- a/internal/db/bundb/domain.go
+++ b/internal/db/bundb/domain.go
@@ -20,11 +20,11 @@ package bundb
import (
"context"
- "database/sql"
"net/url"
"strings"
+ "time"
- "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "codeberg.org/gruf/go-cache/v3/result"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -34,7 +34,22 @@ import (
type domainDB struct {
conn *DBConn
- cache *cache.DomainBlockCache
+ cache *result.Cache[*gtsmodel.DomainBlock]
+}
+
+func (d *domainDB) init() {
+ // Initialize domain block result cache
+ d.cache = result.NewSized([]result.Lookup{
+ {Name: "Domain"},
+ }, func(d1 *gtsmodel.DomainBlock) *gtsmodel.DomainBlock {
+ d2 := new(gtsmodel.DomainBlock)
+ *d2 = *d1
+ return d2
+ }, 1000)
+
+ // Set cache TTL and start sweep routine
+ d.cache.SetTTL(time.Minute*5, false)
+ d.cache.Start(time.Second * 10)
}
// normalizeDomain converts the given domain to lowercase
@@ -49,76 +64,53 @@ func normalizeDomain(domain string) (out string, err error) {
}
func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) db.Error {
- domain, err := normalizeDomain(block.Domain)
+ var err error
+
+ block.Domain, err = normalizeDomain(block.Domain)
if err != nil {
return err
}
- block.Domain = domain
- // Attempt to insert new domain block
- if _, err := d.conn.NewInsert().
- Model(block).
- Exec(ctx); err != nil {
+ return d.cache.Store(block, func() error {
+ _, err := d.conn.NewInsert().
+ Model(block).
+ Exec(ctx)
return d.conn.ProcessError(err)
- }
-
- // Cache this domain block
- d.cache.Put(block.Domain, block)
-
- return nil
+ })
}
func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) {
var err error
+
domain, err = normalizeDomain(domain)
if err != nil {
return nil, err
}
- // Check for easy case, domain referencing *us*
- if domain == "" || domain == config.GetAccountDomain() {
- return nil, db.ErrNoEntries
- }
-
- // Check for already cached rblock
- if block, ok := d.cache.GetByDomain(domain); ok {
- // A 'nil' return value is a sentinel value for no block
- if block == nil {
+ return d.cache.Load("Domain", func() (*gtsmodel.DomainBlock, error) {
+ // Check for easy case, domain referencing *us*
+ if domain == "" || domain == config.GetAccountDomain() {
return nil, db.ErrNoEntries
}
- // Else, this block exists
- return block, nil
- }
+ var block gtsmodel.DomainBlock
- block := &gtsmodel.DomainBlock{}
+ q := d.conn.
+ NewSelect().
+ Model(&block).
+ Where("? = ?", bun.Ident("domain_block.domain"), domain).
+ Limit(1)
+ if err := q.Scan(ctx); err != nil {
+ return nil, d.conn.ProcessError(err)
+ }
- q := d.conn.
- NewSelect().
- Model(block).
- Where("? = ?", bun.Ident("domain_block.domain"), domain).
- Limit(1)
-
- // Query database for domain block
- switch err := q.Scan(ctx); err {
- // No error, block found
- case nil:
- d.cache.Put(domain, block)
- return block, nil
-
- // No error, simply not found
- case sql.ErrNoRows:
- d.cache.Put(domain, nil)
- return nil, db.ErrNoEntries
-
- // Any other db error
- default:
- return nil, d.conn.ProcessError(err)
- }
+ return &block, nil
+ }, domain)
}
func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error {
var err error
+
domain, err = normalizeDomain(domain)
if err != nil {
return err
@@ -133,7 +125,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro
}
// Clear domain from cache
- d.cache.InvalidateByDomain(domain)
+ d.cache.Invalidate(domain)
return nil
}
diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go
index 81374ce78..55e0ee3ff 100644
--- a/internal/db/bundb/emoji.go
+++ b/internal/db/bundb/emoji.go
@@ -23,7 +23,7 @@ import (
"strings"
"time"
- "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "codeberg.org/gruf/go-cache/v3/result"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
@@ -33,8 +33,40 @@ import (
type emojiDB struct {
conn *DBConn
- emojiCache *cache.EmojiCache
- categoryCache *cache.EmojiCategoryCache
+ emojiCache *result.Cache[*gtsmodel.Emoji]
+ categoryCache *result.Cache[*gtsmodel.EmojiCategory]
+}
+
+func (e *emojiDB) init() {
+ // Initialize emoji result cache
+ e.emojiCache = result.NewSized([]result.Lookup{
+ {Name: "ID"},
+ {Name: "URI"},
+ {Name: "Shortcode.Domain"},
+ {Name: "ImageStaticURL"},
+ }, func(e1 *gtsmodel.Emoji) *gtsmodel.Emoji {
+ e2 := new(gtsmodel.Emoji)
+ *e2 = *e1
+ return e2
+ }, 1000)
+
+ // Set cache TTL and start sweep routine
+ e.emojiCache.SetTTL(time.Minute*5, false)
+ e.emojiCache.Start(time.Second * 10)
+
+ // Initialize category result cache
+ e.categoryCache = result.NewSized([]result.Lookup{
+ {Name: "ID"},
+ {Name: "Name"},
+ }, func(c1 *gtsmodel.EmojiCategory) *gtsmodel.EmojiCategory {
+ c2 := new(gtsmodel.EmojiCategory)
+ *c2 = *c1
+ return c2
+ }, 1000)
+
+ // Set cache TTL and start sweep routine
+ e.categoryCache.SetTTL(time.Minute*5, false)
+ e.categoryCache.Start(time.Second * 10)
}
func (e *emojiDB) newEmojiQ(emoji *gtsmodel.Emoji) *bun.SelectQuery {
@@ -51,12 +83,10 @@ func (e *emojiDB) newEmojiCategoryQ(emojiCategory *gtsmodel.EmojiCategory) *bun.
}
func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) db.Error {
- if _, err := e.conn.NewInsert().Model(emoji).Exec(ctx); err != nil {
+ return e.emojiCache.Store(emoji, func() error {
+ _, err := e.conn.NewInsert().Model(emoji).Exec(ctx)
return e.conn.ProcessError(err)
- }
-
- e.emojiCache.Put(emoji)
- return nil
+ })
}
func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, columns ...string) (*gtsmodel.Emoji, db.Error) {
@@ -72,7 +102,7 @@ func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, column
return nil, e.conn.ProcessError(err)
}
- e.emojiCache.Invalidate(emoji.ID)
+ e.emojiCache.Invalidate("ID", emoji.ID)
return emoji, nil
}
@@ -109,7 +139,7 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) db.Error {
return err
}
- e.emojiCache.Invalidate(id)
+ e.emojiCache.Invalidate("ID", id)
return nil
}
@@ -252,33 +282,29 @@ func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.E
func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, db.Error) {
return e.getEmoji(
ctx,
- func() (*gtsmodel.Emoji, bool) {
- return e.emojiCache.GetByID(id)
- },
+ "ID",
func(emoji *gtsmodel.Emoji) error {
return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx)
},
+ id,
)
}
func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, db.Error) {
return e.getEmoji(
ctx,
- func() (*gtsmodel.Emoji, bool) {
- return e.emojiCache.GetByURI(uri)
- },
+ "URI",
func(emoji *gtsmodel.Emoji) error {
return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx)
},
+ uri,
)
}
func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, db.Error) {
return e.getEmoji(
ctx,
- func() (*gtsmodel.Emoji, bool) {
- return e.emojiCache.GetByShortcodeDomain(shortcode, domain)
- },
+ "Shortcode.Domain",
func(emoji *gtsmodel.Emoji) error {
q := e.newEmojiQ(emoji)
@@ -292,31 +318,30 @@ func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode strin
return q.Scan(ctx)
},
+ shortcode,
+ domain,
)
}
func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, db.Error) {
return e.getEmoji(
ctx,
- func() (*gtsmodel.Emoji, bool) {
- return e.emojiCache.GetByImageStaticURL(imageStaticURL)
- },
+ "ImageStaticURL",
func(emoji *gtsmodel.Emoji) error {
return e.
newEmojiQ(emoji).
Where("? = ?", bun.Ident("emoji.image_static_url"), imageStaticURL).
Scan(ctx)
},
+ imageStaticURL,
)
}
func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) db.Error {
- if _, err := e.conn.NewInsert().Model(emojiCategory).Exec(ctx); err != nil {
+ return e.categoryCache.Store(emojiCategory, func() error {
+ _, err := e.conn.NewInsert().Model(emojiCategory).Exec(ctx)
return e.conn.ProcessError(err)
- }
-
- e.categoryCache.Put(emojiCategory)
- return nil
+ })
}
func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, db.Error) {
@@ -338,45 +363,36 @@ func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCate
func (e *emojiDB) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, db.Error) {
return e.getEmojiCategory(
ctx,
- func() (*gtsmodel.EmojiCategory, bool) {
- return e.categoryCache.GetByID(id)
- },
+ "ID",
func(emojiCategory *gtsmodel.EmojiCategory) error {
return e.newEmojiCategoryQ(emojiCategory).Where("? = ?", bun.Ident("emoji_category.id"), id).Scan(ctx)
},
+ id,
)
}
func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, db.Error) {
return e.getEmojiCategory(
ctx,
- func() (*gtsmodel.EmojiCategory, bool) {
- return e.categoryCache.GetByName(name)
- },
+ "Name",
func(emojiCategory *gtsmodel.EmojiCategory) error {
return e.newEmojiCategoryQ(emojiCategory).Where("LOWER(?) = ?", bun.Ident("emoji_category.name"), strings.ToLower(name)).Scan(ctx)
},
+ name,
)
}
-func (e *emojiDB) getEmoji(ctx context.Context, cacheGet func() (*gtsmodel.Emoji, bool), dbQuery func(*gtsmodel.Emoji) error) (*gtsmodel.Emoji, db.Error) {
- // Attempt to fetch cached emoji
- emoji, cached := cacheGet()
-
- if !cached {
- emoji = &gtsmodel.Emoji{}
+func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, db.Error) {
+ return e.emojiCache.Load(lookup, func() (*gtsmodel.Emoji, error) {
+ var emoji gtsmodel.Emoji
// Not cached! Perform database query
- err := dbQuery(emoji)
- if err != nil {
+ if err := dbQuery(&emoji); err != nil {
return nil, e.conn.ProcessError(err)
}
- // Place in the cache
- e.emojiCache.Put(emoji)
- }
-
- return emoji, nil
+ return &emoji, nil
+ }, keyParts...)
}
func (e *emojiDB) emojisFromIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, db.Error) {
@@ -399,24 +415,17 @@ func (e *emojiDB) emojisFromIDs(ctx context.Context, emojiIDs []string) ([]*gtsm
return emojis, nil
}
-func (e *emojiDB) getEmojiCategory(ctx context.Context, cacheGet func() (*gtsmodel.EmojiCategory, bool), dbQuery func(*gtsmodel.EmojiCategory) error) (*gtsmodel.EmojiCategory, db.Error) {
- // Attempt to fetch cached emoji categories
- emojiCategory, cached := cacheGet()
-
- if !cached {
- emojiCategory = &gtsmodel.EmojiCategory{}
+func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, db.Error) {
+ return e.categoryCache.Load(lookup, func() (*gtsmodel.EmojiCategory, error) {
+ var category gtsmodel.EmojiCategory
// Not cached! Perform database query
- err := dbQuery(emojiCategory)
- if err != nil {
+ if err := dbQuery(&category); err != nil {
return nil, e.conn.ProcessError(err)
}
- // Place in the cache
- e.categoryCache.Put(emojiCategory)
- }
-
- return emojiCategory, nil
+ return &category, nil
+ }, keyParts...)
}
func (e *emojiDB) emojiCategoriesFromIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, db.Error) {
diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go
index 355078021..303e16484 100644
--- a/internal/db/bundb/mention.go
+++ b/internal/db/bundb/mention.go
@@ -20,8 +20,9 @@ package bundb
import (
"context"
+ "time"
- "codeberg.org/gruf/go-cache/v2"
+ "codeberg.org/gruf/go-cache/v3/result"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
@@ -30,7 +31,22 @@ import (
type mentionDB struct {
conn *DBConn
- cache cache.Cache[string, *gtsmodel.Mention]
+ cache *result.Cache[*gtsmodel.Mention]
+}
+
+func (m *mentionDB) init() {
+ // Initialize notification result cache
+ m.cache = result.NewSized([]result.Lookup{
+ {Name: "ID"},
+ }, func(m1 *gtsmodel.Mention) *gtsmodel.Mention {
+ m2 := new(gtsmodel.Mention)
+ *m2 = *m1
+ return m2
+ }, 1000)
+
+ // Set cache TTL and start sweep routine
+ m.cache.SetTTL(time.Minute*5, false)
+ m.cache.Start(time.Second * 10)
}
func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
@@ -42,27 +58,19 @@ func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
Relation("TargetAccount")
}
-func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
- mention := gtsmodel.Mention{}
-
- q := m.newMentionQ(&mention).
- Where("? = ?", bun.Ident("mention.id"), id)
-
- if err := q.Scan(ctx); err != nil {
- return nil, m.conn.ProcessError(err)
- }
+func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
+ return m.cache.Load("ID", func() (*gtsmodel.Mention, error) {
+ var mention gtsmodel.Mention
- copy := mention
- m.cache.Set(mention.ID, &copy)
+ q := m.newMentionQ(&mention).
+ Where("? = ?", bun.Ident("mention.id"), id)
- return &mention, nil
-}
+ if err := q.Scan(ctx); err != nil {
+ return nil, m.conn.ProcessError(err)
+ }
-func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
- if mention, ok := m.cache.Get(id); ok {
- return mention, nil
- }
- return m.getMentionDB(ctx, id)
+ return &mention, nil
+ }, id)
}
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) {
diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go
index 69e3cf39f..1874f81ea 100644
--- a/internal/db/bundb/notification.go
+++ b/internal/db/bundb/notification.go
@@ -20,8 +20,9 @@ package bundb
import (
"context"
+ "time"
- "codeberg.org/gruf/go-cache/v2"
+ "codeberg.org/gruf/go-cache/v3/result"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
@@ -30,31 +31,40 @@ import (
type notificationDB struct {
conn *DBConn
- cache cache.Cache[string, *gtsmodel.Notification]
+ cache *result.Cache[*gtsmodel.Notification]
}
-func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
- if notification, ok := n.cache.Get(id); ok {
- return notification, nil
- }
-
- dst := gtsmodel.Notification{ID: id}
-
- q := n.conn.NewSelect().
- Model(&dst).
- Relation("OriginAccount").
- Relation("TargetAccount").
- Relation("Status").
- Where("? = ?", bun.Ident("notification.id"), id)
-
- if err := q.Scan(ctx); err != nil {
- return nil, n.conn.ProcessError(err)
- }
+func (n *notificationDB) init() {
+ // Initialize notification result cache
+ n.cache = result.NewSized([]result.Lookup{
+ {Name: "ID"},
+ }, func(n1 *gtsmodel.Notification) *gtsmodel.Notification {
+ n2 := new(gtsmodel.Notification)
+ *n2 = *n1
+ return n2
+ }, 1000)
+
+ // Set cache TTL and start sweep routine
+ n.cache.SetTTL(time.Minute*5, false)
+ n.cache.Start(time.Second * 10)
+}
- copy := dst
- n.cache.Set(id, &copy)
+func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
+ return n.cache.Load("ID", func() (*gtsmodel.Notification, error) {
+ var notif gtsmodel.Notification
+
+ q := n.conn.NewSelect().
+ Model(&notif).
+ Relation("OriginAccount").
+ Relation("TargetAccount").
+ Relation("Status").
+ Where("? = ?", bun.Ident("notification.id"), id)
+ if err := q.Scan(ctx); err != nil {
+ return nil, n.conn.ProcessError(err)
+ }
- return &dst, nil
+ return &notif, nil
+ }, id)
}
func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
index bc72c2849..b4ae40607 100644
--- a/internal/db/bundb/status.go
+++ b/internal/db/bundb/status.go
@@ -25,7 +25,7 @@ import (
"errors"
"time"
- "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "codeberg.org/gruf/go-cache/v3/result"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
@@ -33,15 +33,28 @@ import (
)
type statusDB struct {
- conn *DBConn
- cache *cache.StatusCache
-
- // TODO: keep method definitions in same place but instead have receiver
- // all point to one single "db" type, so they can all share methods
- // and caches where necessary
+ conn *DBConn
+ cache *result.Cache[*gtsmodel.Status]
accounts *accountDB
}
+func (s *statusDB) init() {
+ // Initialize status result cache
+ s.cache = result.NewSized([]result.Lookup{
+ {Name: "ID"},
+ {Name: "URI"},
+ {Name: "URL"},
+ }, func(s1 *gtsmodel.Status) *gtsmodel.Status {
+ s2 := new(gtsmodel.Status)
+ *s2 = *s1
+ return s2
+ }, 1000)
+
+ // Set cache TTL and start sweep routine
+ s.cache.SetTTL(time.Minute*5, false)
+ s.cache.Start(time.Second * 10)
+}
+
func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
return s.conn.
NewSelect().
@@ -68,61 +81,62 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
return s.getStatus(
ctx,
- func() (*gtsmodel.Status, bool) {
- return s.cache.GetByID(id)
- },
+ "ID",
func(status *gtsmodel.Status) error {
return s.newStatusQ(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx)
},
+ id,
)
}
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
return s.getStatus(
ctx,
- func() (*gtsmodel.Status, bool) {
- return s.cache.GetByURI(uri)
- },
+ "URI",
func(status *gtsmodel.Status) error {
return s.newStatusQ(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx)
},
+ uri,
)
}
func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) {
return s.getStatus(
ctx,
- func() (*gtsmodel.Status, bool) {
- return s.cache.GetByURL(url)
- },
+ "URL",
func(status *gtsmodel.Status) error {
return s.newStatusQ(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx)
},
+ url,
)
}
-func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Status, bool), dbQuery func(*gtsmodel.Status) error) (*gtsmodel.Status, db.Error) {
- // Attempt to fetch cached status
- status, cached := cacheGet()
-
- if !cached {
- status = &gtsmodel.Status{}
+func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, db.Error) {
+ // Fetch status from database cache with loader callback
+ status, err := s.cache.Load(lookup, func() (*gtsmodel.Status, error) {
+ var status gtsmodel.Status
// Not cached! Perform database query
- if err := dbQuery(status); err != nil {
+ if err := dbQuery(&status); err != nil {
return nil, s.conn.ProcessError(err)
}
// If there is boosted, fetch from DB also
if status.BoostOfID != "" {
- boostOf, err := s.GetStatusByID(ctx, status.BoostOfID)
- if err == nil {
- status.BoostOf = boostOf
+ status.BoostOf = &gtsmodel.Status{}
+ err := s.newStatusQ(status.BoostOf).
+ Where("? = ?", bun.Ident("status.id"), status.BoostOfID).
+ Scan(ctx)
+ if err != nil {
+ return nil, s.conn.ProcessError(err)
}
}
- // Place in the cache
- s.cache.Put(status)
+ return &status, nil
+ }, keyParts...)
+ if err != nil {
+ // error already processed
+ return nil, err
}
// Set the status author account
@@ -137,73 +151,66 @@ func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Sta
}
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
- err := s.conn.RunInTx(ctx, func(tx bun.Tx) error {
- // create links between this status and any emojis it uses
- for _, i := range status.EmojiIDs {
- if _, err := tx.
- NewInsert().
- Model(&gtsmodel.StatusToEmoji{
- StatusID: status.ID,
- EmojiID: i,
- }).Exec(ctx); err != nil {
- err = s.conn.errProc(err)
- if !errors.Is(err, db.ErrAlreadyExists) {
- return err
+ return s.cache.Store(status, func() error {
+ // It is safe to run this database transaction within cache.Store
+ // as the cache does not attempt a mutex lock until AFTER hook.
+ //
+ return s.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ // create links between this status and any emojis it uses
+ for _, i := range status.EmojiIDs {
+ if _, err := tx.
+ NewInsert().
+ Model(&gtsmodel.StatusToEmoji{
+ StatusID: status.ID,
+ EmojiID: i,
+ }).Exec(ctx); err != nil {
+ err = s.conn.ProcessError(err)
+ if !errors.Is(err, db.ErrAlreadyExists) {
+ return err
+ }
}
}
- }
- // create links between this status and any tags it uses
- for _, i := range status.TagIDs {
- if _, err := tx.
- NewInsert().
- Model(&gtsmodel.StatusToTag{
- StatusID: status.ID,
- TagID: i,
- }).Exec(ctx); err != nil {
- err = s.conn.errProc(err)
- if !errors.Is(err, db.ErrAlreadyExists) {
- return err
+ // create links between this status and any tags it uses
+ for _, i := range status.TagIDs {
+ if _, err := tx.
+ NewInsert().
+ Model(&gtsmodel.StatusToTag{
+ StatusID: status.ID,
+ TagID: i,
+ }).Exec(ctx); err != nil {
+ err = s.conn.ProcessError(err)
+ if !errors.Is(err, db.ErrAlreadyExists) {
+ return err
+ }
}
}
- }
- // change the status ID of the media attachments to the new status
- for _, a := range status.Attachments {
- a.StatusID = status.ID
- a.UpdatedAt = time.Now()
- if _, err := tx.
- NewUpdate().
- Model(a).
- Where("? = ?", bun.Ident("media_attachment.id"), a.ID).
- Exec(ctx); err != nil {
- err = s.conn.errProc(err)
- if !errors.Is(err, db.ErrAlreadyExists) {
- return err
+ // change the status ID of the media attachments to the new status
+ for _, a := range status.Attachments {
+ a.StatusID = status.ID
+ a.UpdatedAt = time.Now()
+ if _, err := tx.
+ NewUpdate().
+ Model(a).
+ Where("? = ?", bun.Ident("media_attachment.id"), a.ID).
+ Exec(ctx); err != nil {
+ err = s.conn.ProcessError(err)
+ if !errors.Is(err, db.ErrAlreadyExists) {
+ return err
+ }
}
}
- }
- // Finally, insert the status
- if _, err := tx.
- NewInsert().
- Model(status).
- Exec(ctx); err != nil {
+ // Finally, insert the status
+ _, err := tx.NewInsert().Model(status).Exec(ctx)
return err
- }
-
- return nil
+ })
})
- if err != nil {
- return s.conn.ProcessError(err)
- }
-
- s.cache.Put(status)
- return nil
}
-func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*gtsmodel.Status, db.Error) {
- err := s.conn.RunInTx(ctx, func(tx bun.Tx) error {
+func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
+ if err := s.conn.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this status and any emojis it uses
for _, i := range status.EmojiIDs {
if _, err := tx.
@@ -212,7 +219,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*
StatusID: status.ID,
EmojiID: i,
}).Exec(ctx); err != nil {
- err = s.conn.errProc(err)
+ err = s.conn.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) {
return err
}
@@ -227,14 +234,14 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*
StatusID: status.ID,
TagID: i,
}).Exec(ctx); err != nil {
- err = s.conn.errProc(err)
+ err = s.conn.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) {
return err
}
}
}
- // change the status ID of the media attachments to this status
+ // change the status ID of the media attachments to the new status
for _, a := range status.Attachments {
a.StatusID = status.ID
a.UpdatedAt = time.Now()
@@ -243,31 +250,31 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*
Model(a).
Where("? = ?", bun.Ident("media_attachment.id"), a.ID).
Exec(ctx); err != nil {
- return err
+ err = s.conn.ProcessError(err)
+ if !errors.Is(err, db.ErrAlreadyExists) {
+ return err
+ }
}
}
- // Finally, update the status itself
- if _, err := tx.
+ // Finally, insert the status
+ _, err := tx.
NewUpdate().
Model(status).
Where("? = ?", bun.Ident("status.id"), status.ID).
- Exec(ctx); err != nil {
- return err
- }
-
- return nil
- })
- if err != nil {
- return nil, s.conn.ProcessError(err)
+ Exec(ctx)
+ return err
+ }); err != nil {
+ return err
}
- s.cache.Put(status)
- return status, nil
+ // Drop any old value from cache by this ID
+ s.cache.Invalidate("ID", status.ID)
+ return nil
}
func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
- err := s.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ if err := s.conn.RunInTx(ctx, func(tx bun.Tx) error {
// delete links between this status and any emojis it uses
if _, err := tx.
NewDelete().
@@ -296,36 +303,41 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
}
return nil
- })
- if err != nil {
- return s.conn.ProcessError(err)
+ }); err != nil {
+ return err
}
- s.cache.Invalidate(id)
+ // Drop any old value from cache by this ID
+ s.cache.Invalidate("ID", id)
return nil
}
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
- parents := []*gtsmodel.Status{}
- s.statusParent(ctx, status, &parents, onlyDirect)
- return parents, nil
-}
-
-func (s *statusDB) statusParent(ctx context.Context, status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
- if status.InReplyToID == "" {
- return
+ if onlyDirect {
+ // Only want the direct parent, no further than first level
+ parent, err := s.GetStatusByID(ctx, status.InReplyToID)
+ if err != nil {
+ return nil, err
+ }
+ return []*gtsmodel.Status{parent}, nil
}
- parentStatus, err := s.GetStatusByID(ctx, status.InReplyToID)
- if err == nil {
- *foundStatuses = append(*foundStatuses, parentStatus)
- }
+ var parents []*gtsmodel.Status
- if onlyDirect {
- return
+ for id := status.InReplyToID; id != ""; {
+ parent, err := s.GetStatusByID(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+
+ // Append parent to slice
+ parents = append(parents, parent)
+
+ // Set the next parent ID
+ id = parent.InReplyToID
}
- s.statusParent(ctx, parentStatus, foundStatuses, false)
+ return parents, nil
}
func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
@@ -350,7 +362,7 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu
}
func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
- childIDs := []string{}
+ var childIDs []string
q := s.conn.
NewSelect().
@@ -471,6 +483,7 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status)
if err := q.Scan(ctx); err != nil {
return nil, s.conn.ProcessError(err)
}
+
return faves, nil
}
diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go
index 9b6365621..066f55234 100644
--- a/internal/db/bundb/timeline_test.go
+++ b/internal/db/bundb/timeline_test.go
@@ -35,44 +35,52 @@ type TimelineTestSuite struct {
}
func (suite *TimelineTestSuite) TestGetPublicTimeline() {
- s, err := suite.db.GetPublicTimeline(context.Background(), "", "", "", 20, false)
+ ctx := context.Background()
+
+ s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
suite.NoError(err)
suite.Len(s, 6)
}
func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
+ ctx := context.Background()
+
futureStatus := getFutureStatus()
- if err := suite.db.Put(context.Background(), futureStatus); err != nil {
- suite.FailNow(err.Error())
- }
+ err := suite.db.PutStatus(ctx, futureStatus)
+ suite.NoError(err)
- s, err := suite.db.GetPublicTimeline(context.Background(), "", "", "", 20, false)
+ s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
suite.NoError(err)
+ suite.NotContains(s, futureStatus)
suite.Len(s, 6)
}
func (suite *TimelineTestSuite) TestGetHomeTimeline() {
+ ctx := context.Background()
+
viewingAccount := suite.testAccounts["local_account_1"]
- s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false)
+ s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false)
suite.NoError(err)
suite.Len(s, 16)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() {
+ ctx := context.Background()
+
viewingAccount := suite.testAccounts["local_account_1"]
futureStatus := getFutureStatus()
- if err := suite.db.Put(context.Background(), futureStatus); err != nil {
- suite.FailNow(err.Error())
- }
+ err := suite.db.PutStatus(ctx, futureStatus)
+ suite.NoError(err)
s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false)
suite.NoError(err)
+ suite.NotContains(s, futureStatus)
suite.Len(s, 16)
}
diff --git a/internal/db/bundb/tombstone.go b/internal/db/bundb/tombstone.go
index 7ce3327a7..309a39fd3 100644
--- a/internal/db/bundb/tombstone.go
+++ b/internal/db/bundb/tombstone.go
@@ -43,7 +43,7 @@ func (t *tombstoneDB) init() {
t2 := new(gtsmodel.Tombstone)
*t2 = *t1
return t2
- }, 1000)
+ }, 100)
// Set cache TTL and start sweep routine
t.cache.SetTTL(time.Minute*5, false)
diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go
index aa2f4c2c8..d9b281a6f 100644
--- a/internal/db/bundb/user.go
+++ b/internal/db/bundb/user.go
@@ -22,7 +22,7 @@ import (
"context"
"time"
- "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "codeberg.org/gruf/go-cache/v3/result"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
@@ -30,111 +30,121 @@ import (
type userDB struct {
conn *DBConn
- cache *cache.UserCache
+ cache *result.Cache[*gtsmodel.User]
}
-func (u *userDB) newUserQ(user *gtsmodel.User) *bun.SelectQuery {
- return u.conn.
- NewSelect().
- Model(user).
- Relation("Account")
+func (u *userDB) init() {
+ // Initialize user result cache
+ u.cache = result.NewSized([]result.Lookup{
+ {Name: "ID"},
+ {Name: "AccountID"},
+ {Name: "Email"},
+ {Name: "ConfirmationToken"},
+ }, func(u1 *gtsmodel.User) *gtsmodel.User {
+ u2 := new(gtsmodel.User)
+ *u2 = *u1
+ return u2
+ }, 1000)
+
+ // Set cache TTL and start sweep routine
+ u.cache.SetTTL(time.Minute*5, false)
+ u.cache.Start(time.Second * 10)
}
-func (u *userDB) getUser(ctx context.Context, cacheGet func() (*gtsmodel.User, bool), dbQuery func(*gtsmodel.User) error) (*gtsmodel.User, db.Error) {
- // Attempt to fetch cached user
- user, cached := cacheGet()
+func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) {
+ return u.cache.Load("ID", func() (*gtsmodel.User, error) {
+ var user gtsmodel.User
- if !cached {
- user = &gtsmodel.User{}
+ q := u.conn.
+ NewSelect().
+ Model(&user).
+ Relation("Account").
+ Where("? = ?", bun.Ident("user.id"), id)
- // Not cached! Perform database query
- err := dbQuery(user)
- if err != nil {
+ if err := q.Scan(ctx); err != nil {
return nil, u.conn.ProcessError(err)
}
- // Place in the cache
- u.cache.Put(user)
- }
-
- return user, nil
-}
-
-func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) {
- return u.getUser(
- ctx,
- func() (*gtsmodel.User, bool) {
- return u.cache.GetByID(id)
- },
- func(user *gtsmodel.User) error {
- return u.newUserQ(user).Where("? = ?", bun.Ident("user.id"), id).Scan(ctx)
- },
- )
+ return &user, nil
+ }, id)
}
func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, db.Error) {
- return u.getUser(
- ctx,
- func() (*gtsmodel.User, bool) {
- return u.cache.GetByAccountID(accountID)
- },
- func(user *gtsmodel.User) error {
- return u.newUserQ(user).Where("? = ?", bun.Ident("user.account_id"), accountID).Scan(ctx)
- },
- )
+ return u.cache.Load("AccountID", func() (*gtsmodel.User, error) {
+ var user gtsmodel.User
+
+ q := u.conn.
+ NewSelect().
+ Model(&user).
+ Relation("Account").
+ Where("? = ?", bun.Ident("user.account_id"), accountID)
+
+ if err := q.Scan(ctx); err != nil {
+ return nil, u.conn.ProcessError(err)
+ }
+
+ return &user, nil
+ }, accountID)
}
func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, db.Error) {
- return u.getUser(
- ctx,
- func() (*gtsmodel.User, bool) {
- return u.cache.GetByEmail(emailAddress)
- },
- func(user *gtsmodel.User) error {
- return u.newUserQ(user).Where("? = ?", bun.Ident("user.email"), emailAddress).Scan(ctx)
- },
- )
+ return u.cache.Load("Email", func() (*gtsmodel.User, error) {
+ var user gtsmodel.User
+
+ q := u.conn.
+ NewSelect().
+ Model(&user).
+ Relation("Account").
+ Where("? = ?", bun.Ident("user.email"), emailAddress)
+
+ if err := q.Scan(ctx); err != nil {
+ return nil, u.conn.ProcessError(err)
+ }
+
+ return &user, nil
+ }, emailAddress)
}
func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, db.Error) {
- return u.getUser(
- ctx,
- func() (*gtsmodel.User, bool) {
- return u.cache.GetByConfirmationToken(confirmationToken)
- },
- func(user *gtsmodel.User) error {
- return u.newUserQ(user).Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken).Scan(ctx)
- },
- )
-}
+ return u.cache.Load("ConfirmationToken", func() (*gtsmodel.User, error) {
+ var user gtsmodel.User
-func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) (*gtsmodel.User, db.Error) {
- if _, err := u.conn.
- NewInsert().
- Model(user).
- Exec(ctx); err != nil {
- return nil, u.conn.ProcessError(err)
- }
+ q := u.conn.
+ NewSelect().
+ Model(&user).
+ Relation("Account").
+ Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken)
- u.cache.Put(user)
- return user, nil
+ if err := q.Scan(ctx); err != nil {
+ return nil, u.conn.ProcessError(err)
+ }
+
+ return &user, nil
+ }, confirmationToken)
}
-func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, db.Error) {
+func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) db.Error {
+ return u.cache.Store(user, func() error {
+ _, err := u.conn.
+ NewInsert().
+ Model(user).
+ Exec(ctx)
+ return u.conn.ProcessError(err)
+ })
+}
+
+func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User) db.Error {
// Update the user's last-updated
user.UpdatedAt = time.Now()
- if _, err := u.conn.
- NewUpdate().
- Model(user).
- Where("? = ?", bun.Ident("user.id"), user.ID).
- Column(columns...).
- Exec(ctx); err != nil {
- return nil, u.conn.ProcessError(err)
- }
-
- u.cache.Invalidate(user.ID)
- return user, nil
+ return u.cache.Store(user, func() error {
+ _, err := u.conn.
+ NewUpdate().
+ Model(user).
+ Where("? = ?", bun.Ident("user.id"), user.ID).
+ Exec(ctx)
+ return u.conn.ProcessError(err)
+ })
}
func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error {
@@ -146,6 +156,7 @@ func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error {
return u.conn.ProcessError(err)
}
- u.cache.Invalidate(userID)
+ // Invalidate user from cache
+ u.cache.Invalidate("ID", userID)
return nil
}
diff --git a/internal/db/bundb/user_test.go b/internal/db/bundb/user_test.go
index 6ad59fc8e..18f67dde5 100644
--- a/internal/db/bundb/user_test.go
+++ b/internal/db/bundb/user_test.go
@@ -50,21 +50,20 @@ func (suite *UserTestSuite) TestGetUserByAccountID() {
func (suite *UserTestSuite) TestUpdateUserSelectedColumns() {
testUser := suite.testUsers["local_account_1"]
- user := &gtsmodel.User{
- ID: testUser.ID,
- Email: "whatever",
- Locale: "es",
- }
- user, err := suite.db.UpdateUser(context.Background(), user, "email", "locale")
+ updateUser := new(gtsmodel.User)
+ *updateUser = *testUser
+ updateUser.Email = "whatever"
+ updateUser.Locale = "es"
+
+ err := suite.db.UpdateUser(context.Background(), updateUser)
suite.NoError(err)
- suite.NotNil(user)
dbUser, err := suite.db.GetUserByID(context.Background(), testUser.ID)
suite.NoError(err)
suite.NotNil(dbUser)
- suite.Equal("whatever", dbUser.Email)
- suite.Equal("es", dbUser.Locale)
+ suite.Equal(updateUser.Email, dbUser.Email)
+ suite.Equal(updateUser.Locale, dbUser.Locale)
suite.Equal(testUser.AccountID, dbUser.AccountID)
}
diff --git a/internal/db/status.go b/internal/db/status.go
index 55cec5beb..d0983122b 100644
--- a/internal/db/status.go
+++ b/internal/db/status.go
@@ -39,7 +39,7 @@ type Status interface {
PutStatus(ctx context.Context, status *gtsmodel.Status) Error
// UpdateStatus updates one status in the database and returns it to the caller.
- UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*gtsmodel.Status, Error)
+ UpdateStatus(ctx context.Context, status *gtsmodel.Status) Error
// DeleteStatusByID deletes one status from the database.
DeleteStatusByID(ctx context.Context, id string) Error
diff --git a/internal/db/user.go b/internal/db/user.go
index a4d48db56..d01a8862a 100644
--- a/internal/db/user.go
+++ b/internal/db/user.go
@@ -34,9 +34,10 @@ type User interface {
GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, Error)
// GetUserByConfirmationToken returns one user by its confirmation token, or an error if something goes wrong.
GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, Error)
- // UpdateUser updates one user by its primary key. If columns is set, only given columns
- // will be updated. If not set, all columns will be updated.
- UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, Error)
+ // PutUser will attempt to place user in the database
+ PutUser(ctx context.Context, user *gtsmodel.User) Error
+ // UpdateUser updates one user by its primary key.
+ UpdateUser(ctx context.Context, user *gtsmodel.User) Error
// DeleteUserByID deletes one user by its ID.
DeleteUserByID(ctx context.Context, userID string) Error
}