diff options
| author | 2022-11-15 18:45:15 +0000 | |
|---|---|---|
| committer | 2022-11-15 18:45:15 +0000 | |
| commit | 8598dea98b872647393117704659878d9b38d4fc (patch) | |
| tree | 1940168912dc7f54af723439dbc9f6e0a42f30ae /internal/db/bundb | |
| parent | [docs] Both HTTP proxies and NAT can cause rate limiting issues (#1053) (diff) | |
| download | gotosocial-8598dea98b872647393117704659878d9b38d4fc.tar.xz | |
[chore] update database caching library (#1040)
* convert most of the caches to use result.Cache{}
* add caching of emojis
* fix issues causing failing tests
* update go-cache/v2 instances with v3
* fix getnotification
* add a note about the left-in StatusCreate comment
* update EmojiCategory db access to use new result.Cache{}
* fix possible panic in getstatusparents
* further proof that kim is not stinky
Diffstat (limited to 'internal/db/bundb')
| -rw-r--r-- | internal/db/bundb/account.go | 214 | ||||
| -rw-r--r-- | internal/db/bundb/account_test.go | 4 | ||||
| -rw-r--r-- | internal/db/bundb/admin.go | 33 | ||||
| -rw-r--r-- | internal/db/bundb/admin_test.go | 2 | ||||
| -rw-r--r-- | internal/db/bundb/bundb.go | 76 | ||||
| -rw-r--r-- | internal/db/bundb/domain.go | 94 | ||||
| -rw-r--r-- | internal/db/bundb/emoji.go | 127 | ||||
| -rw-r--r-- | internal/db/bundb/mention.go | 48 | ||||
| -rw-r--r-- | internal/db/bundb/notification.go | 54 | ||||
| -rw-r--r-- | internal/db/bundb/status.go | 251 | ||||
| -rw-r--r-- | internal/db/bundb/timeline_test.go | 26 | ||||
| -rw-r--r-- | internal/db/bundb/tombstone.go | 2 | ||||
| -rw-r--r-- | internal/db/bundb/user.go | 175 | ||||
| -rw-r--r-- | internal/db/bundb/user_test.go | 17 | 
14 files changed, 571 insertions, 552 deletions
| diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 4813f4e17..1e9c390d8 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -24,7 +24,7 @@ import (  	"strings"  	"time" -	"github.com/superseriousbusiness/gotosocial/internal/cache" +	"codeberg.org/gruf/go-cache/v3/result"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -35,10 +35,29 @@ import (  type accountDB struct {  	conn   *DBConn -	cache  *cache.AccountCache +	cache  *result.Cache[*gtsmodel.Account]  	status *statusDB  } +func (a *accountDB) init() { +	// Initialize account result cache +	a.cache = result.NewSized([]result.Lookup{ +		{Name: "ID"}, +		{Name: "URI"}, +		{Name: "URL"}, +		{Name: "Username.Domain"}, +		{Name: "PublicKeyURI"}, +	}, func(a1 *gtsmodel.Account) *gtsmodel.Account { +		a2 := new(gtsmodel.Account) +		*a2 = *a1 +		return a2 +	}, 1000) + +	// Set cache TTL and start sweep routine +	a.cache.SetTTL(time.Minute*5, false) +	a.cache.Start(time.Second * 10) +} +  func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {  	return a.conn.  		NewSelect(). @@ -51,45 +70,41 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {  func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {  	return a.getAccount(  		ctx, -		func() (*gtsmodel.Account, bool) { -			return a.cache.GetByID(id) -		}, +		"ID",  		func(account *gtsmodel.Account) error {  			return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx)  		}, +		id,  	)  }  func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {  	return a.getAccount(  		ctx, -		func() (*gtsmodel.Account, bool) { -			return a.cache.GetByURI(uri) -		}, +		"URI",  		func(account *gtsmodel.Account) error {  			return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx)  		}, +		uri,  	)  }  func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) {  	return a.getAccount(  		ctx, -		func() (*gtsmodel.Account, bool) { -			return a.cache.GetByURL(url) -		}, +		"URL",  		func(account *gtsmodel.Account) error {  			return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx)  		}, +		url,  	)  }  func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) { +	username = strings.ToLower(username)  	return a.getAccount(  		ctx, -		func() (*gtsmodel.Account, bool) { -			return a.cache.GetByUsernameDomain(username, domain) -		}, +		"Username.Domain",  		func(account *gtsmodel.Account) error {  			q := a.newAccountQ(account) @@ -97,113 +112,117 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str  				q = q.Where("? = ?", bun.Ident("account.username"), username)  				q = q.Where("? = ?", bun.Ident("account.domain"), domain)  			} else { -				q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username)) +				q = q.Where("? = ?", bun.Ident("account.username"), username)  				q = q.Where("? IS NULL", bun.Ident("account.domain"))  			}  			return q.Scan(ctx)  		}, +		username, +		domain,  	)  }  func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {  	return a.getAccount(  		ctx, -		func() (*gtsmodel.Account, bool) { -			return a.cache.GetByPubkeyID(id) -		}, +		"PublicKeyURI",  		func(account *gtsmodel.Account) error {  			return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx)  		}, +		id,  	)  } -func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) { -	// Attempt to fetch cached account -	account, cached := cacheGet() +func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { +	var username string -	if !cached { -		account = >smodel.Account{} +	if domain == "" { +		// I.e. our local instance account +		username = config.GetHost() +	} else { +		// A remote instance account +		username = domain +	} + +	return a.GetAccountByUsernameDomain(ctx, username, domain) +} + +func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, db.Error) { +	return a.cache.Load(lookup, func() (*gtsmodel.Account, error) { +		var account gtsmodel.Account  		// Not cached! Perform database query -		err := dbQuery(account) -		if err != nil { +		if err := dbQuery(&account); err != nil {  			return nil, a.conn.ProcessError(err)  		} -		// Place in the cache -		a.cache.Put(account) -	} - -	return account, nil +		return &account, nil +	}, keyParts...)  } -func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { -	if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { -		// create links between this account and any emojis it uses -		for _, i := range account.EmojiIDs { -			if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ -				AccountID: account.ID, -				EmojiID:   i, -			}).Exec(ctx); err != nil { -				return err +func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error { +	return a.cache.Store(account, func() error { +		// It is safe to run this database transaction within cache.Store +		// as the cache does not attempt a mutex lock until AFTER hook. +		// +		return a.conn.RunInTx(ctx, func(tx bun.Tx) error { +			// create links between this account and any emojis it uses +			for _, i := range account.EmojiIDs { +				if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ +					AccountID: account.ID, +					EmojiID:   i, +				}).Exec(ctx); err != nil { +					return err +				}  			} -		} -		// insert the account -		_, err := tx.NewInsert().Model(account).Exec(ctx) -		return err -	}); err != nil { -		return nil, a.conn.ProcessError(err) -	} - -	a.cache.Put(account) -	return account, nil +			// insert the account +			_, err := tx.NewInsert().Model(account).Exec(ctx) +			return err +		}) +	})  } -func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { +func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) db.Error {  	// Update the account's last-updated  	account.UpdatedAt = time.Now() -	if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { -		// create links between this account and any emojis it uses -		// first clear out any old emoji links -		if _, err := tx. -			NewDelete(). -			TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). -			Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID). -			Exec(ctx); err != nil { -			return err -		} - -		// now populate new emoji links -		for _, i := range account.EmojiIDs { +	return a.cache.Store(account, func() error { +		// It is safe to run this database transaction within cache.Store +		// as the cache does not attempt a mutex lock until AFTER hook. +		// +		return a.conn.RunInTx(ctx, func(tx bun.Tx) error { +			// create links between this account and any emojis it uses +			// first clear out any old emoji links  			if _, err := tx. -				NewInsert(). -				Model(>smodel.AccountToEmoji{ -					AccountID: account.ID, -					EmojiID:   i, -				}).Exec(ctx); err != nil { +				NewDelete(). +				TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). +				Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID). +				Exec(ctx); err != nil {  				return err  			} -		} -		// update the account -		if _, err := tx. -			NewUpdate(). -			Model(account). -			Where("? = ?", bun.Ident("account.id"), account.ID). -			Exec(ctx); err != nil { -			return err -		} - -		return nil -	}); err != nil { -		return nil, a.conn.ProcessError(err) -	} +			// now populate new emoji links +			for _, i := range account.EmojiIDs { +				if _, err := tx. +					NewInsert(). +					Model(>smodel.AccountToEmoji{ +						AccountID: account.ID, +						EmojiID:   i, +					}).Exec(ctx); err != nil { +					return err +				} +			} -	a.cache.Put(account) -	return account, nil +			// update the account +			_, err := tx.NewUpdate(). +				Model(account). +				Where("? = ?", bun.Ident("account.id"), account.ID). +				Exec(ctx) +			return err +		}) +	})  }  func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { @@ -219,40 +238,19 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {  		// delete the account  		_, err := tx. -			NewUpdate(). +			NewDelete().  			TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).  			Where("? = ?", bun.Ident("account.id"), id).  			Exec(ctx)  		return err  	}); err != nil { -		return a.conn.ProcessError(err) +		return err  	} -	a.cache.Invalidate(id) +	a.cache.Invalidate("ID", id)  	return nil  } -func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { -	account := new(gtsmodel.Account) - -	q := a.newAccountQ(account) - -	if domain != "" { -		q = q. -			Where("? = ?", bun.Ident("account.username"), domain). -			Where("? = ?", bun.Ident("account.domain"), domain) -	} else { -		q = q. -			Where("? = ?", bun.Ident("account.username"), config.GetHost()). -			WhereGroup(" AND ", whereEmptyOrNull("domain")) -	} - -	if err := q.Scan(ctx); err != nil { -		return nil, a.conn.ProcessError(err) -	} -	return account, nil -} -  func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, db.Error) {  	createdAt := time.Time{} diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index 29594a740..50603623f 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -92,7 +92,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {  	testAccount.DisplayName = "new display name!"  	testAccount.EmojiIDs = []string{"01GD36ZKWTKY3T1JJ24JR7KY1Q", "01GD36ZV904SHBHNAYV6DX5QEF"} -	_, err := suite.db.UpdateAccount(ctx, testAccount) +	err := suite.db.UpdateAccount(ctx, testAccount)  	suite.NoError(err)  	updated, err := suite.db.GetAccountByID(ctx, testAccount.ID) @@ -127,7 +127,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {  	// update again to remove emoji associations  	testAccount.EmojiIDs = []string{} -	_, err = suite.db.UpdateAccount(ctx, testAccount) +	err = suite.db.UpdateAccount(ctx, testAccount)  	suite.NoError(err)  	updated, err = suite.db.GetAccountByID(ctx, testAccount.ID) diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index 44861a4bb..4d750581c 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -29,7 +29,6 @@ import (  	"time"  	"github.com/superseriousbusiness/gotosocial/internal/ap" -	"github.com/superseriousbusiness/gotosocial/internal/cache"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -44,9 +43,9 @@ import (  const rsaKeyBits = 2048  type adminDB struct { -	conn         *DBConn -	userCache    *cache.UserCache -	accountCache *cache.AccountCache +	conn     *DBConn +	accounts *accountDB +	users    *userDB  }  func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { @@ -140,13 +139,9 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,  		}  		// insert the new account! -		if _, err = a.conn. -			NewInsert(). -			Model(acct). -			Exec(ctx); err != nil { -			return nil, a.conn.ProcessError(err) +		if err := a.accounts.PutAccount(ctx, acct); err != nil { +			return nil, err  		} -		a.accountCache.Put(acct)  	}  	// we either created or already had an account by now, @@ -190,13 +185,9 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,  	}  	// insert the user! -	if _, err = a.conn. -		NewInsert(). -		Model(u). -		Exec(ctx); err != nil { -		return nil, a.conn.ProcessError(err) +	if err := a.users.PutUser(ctx, u); err != nil { +		return nil, err  	} -	a.userCache.Put(u)  	return u, nil  } @@ -249,15 +240,11 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {  		FeaturedCollectionURI: newAccountURIs.CollectionURI,  	} -	insertQ := a.conn. -		NewInsert(). -		Model(acct) - -	if _, err := insertQ.Exec(ctx); err != nil { -		return a.conn.ProcessError(err) +	// insert the new account! +	if err := a.accounts.PutAccount(ctx, acct); err != nil { +		return err  	} -	a.accountCache.Put(acct)  	log.Infof("instance account %s CREATED with id %s", username, acct.ID)  	return nil  } diff --git a/internal/db/bundb/admin_test.go b/internal/db/bundb/admin_test.go index f0a869a9b..18e1f67e2 100644 --- a/internal/db/bundb/admin_test.go +++ b/internal/db/bundb/admin_test.go @@ -70,6 +70,8 @@ func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() {  }  func (suite *AdminTestSuite) TestCreateInstanceAccount() { +	// reinitialize test DB to clear caches +	suite.db = testrig.NewTestDB()  	// we need to take an empty db for this...  	testrig.StandardDBTeardown(suite.db)  	// ...with tables created but no data diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index cf6643f6b..de6749ca4 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -34,7 +34,6 @@ import (  	"github.com/google/uuid"  	"github.com/jackc/pgx/v4"  	"github.com/jackc/pgx/v4/stdlib" -	"github.com/superseriousbusiness/gotosocial/internal/cache"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations" @@ -46,7 +45,6 @@ import (  	"github.com/uptrace/bun/dialect/sqlitedialect"  	"github.com/uptrace/bun/migrate" -	grufcache "codeberg.org/gruf/go-cache/v2"  	"modernc.org/sqlite"  ) @@ -160,79 +158,63 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {  		return nil, fmt.Errorf("db migration error: %s", err)  	} -	// Prepare caches required by more than one struct -	userCache := cache.NewUserCache() -	accountCache := cache.NewAccountCache() - -	// Prepare other caches -	// Prepare mentions cache -	// TODO: move into internal/cache -	mentionCache := grufcache.New[string, *gtsmodel.Mention]() -	mentionCache.SetTTL(time.Minute*5, false) -	mentionCache.Start(time.Second * 10) - -	// Prepare notifications cache -	// TODO: move into internal/cache -	notifCache := grufcache.New[string, *gtsmodel.Notification]() -	notifCache.SetTTL(time.Minute*5, false) -	notifCache.Start(time.Second * 10) -  	// Create DB structs that require ptrs to each other -	accounts := &accountDB{conn: conn, cache: accountCache} -	status := &statusDB{conn: conn, cache: cache.NewStatusCache()} -	emoji := &emojiDB{conn: conn, emojiCache: cache.NewEmojiCache(), categoryCache: cache.NewEmojiCategoryCache()} +	account := &accountDB{conn: conn} +	admin := &adminDB{conn: conn} +	domain := &domainDB{conn: conn} +	mention := &mentionDB{conn: conn} +	notif := ¬ificationDB{conn: conn} +	status := &statusDB{conn: conn} +	emoji := &emojiDB{conn: conn}  	timeline := &timelineDB{conn: conn}  	tombstone := &tombstoneDB{conn: conn} +	user := &userDB{conn: conn}  	// Setup DB cross-referencing -	accounts.status = status -	status.accounts = accounts +	account.status = status +	admin.users = user +	status.accounts = account  	timeline.status = status  	// Initialize db structs +	account.init() +	domain.init() +	emoji.init() +	mention.init() +	notif.init() +	status.init()  	tombstone.init() +	user.init()  	ps := &DBService{ -		Account: accounts, +		Account: account,  		Admin: &adminDB{ -			conn:         conn, -			userCache:    userCache, -			accountCache: accountCache, +			conn:     conn, +			accounts: account, +			users:    user,  		},  		Basic: &basicDB{  			conn: conn,  		}, -		Domain: &domainDB{ -			conn:  conn, -			cache: cache.NewDomainBlockCache(), -		}, -		Emoji: emoji, +		Domain: domain, +		Emoji:  emoji,  		Instance: &instanceDB{  			conn: conn,  		},  		Media: &mediaDB{  			conn: conn,  		}, -		Mention: &mentionDB{ -			conn:  conn, -			cache: mentionCache, -		}, -		Notification: ¬ificationDB{ -			conn:  conn, -			cache: notifCache, -		}, +		Mention:      mention, +		Notification: notif,  		Relationship: &relationshipDB{  			conn: conn,  		},  		Session: &sessionDB{  			conn: conn,  		}, -		Status:   status, -		Timeline: timeline, -		User: &userDB{ -			conn:  conn, -			cache: userCache, -		}, +		Status:    status, +		Timeline:  timeline, +		User:      user,  		Tombstone: tombstone,  		conn:      conn,  	} diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index 0a752d3f3..3fca8501b 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -20,11 +20,11 @@ package bundb  import (  	"context" -	"database/sql"  	"net/url"  	"strings" +	"time" -	"github.com/superseriousbusiness/gotosocial/internal/cache" +	"codeberg.org/gruf/go-cache/v3/result"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -34,7 +34,22 @@ import (  type domainDB struct {  	conn  *DBConn -	cache *cache.DomainBlockCache +	cache *result.Cache[*gtsmodel.DomainBlock] +} + +func (d *domainDB) init() { +	// Initialize domain block result cache +	d.cache = result.NewSized([]result.Lookup{ +		{Name: "Domain"}, +	}, func(d1 *gtsmodel.DomainBlock) *gtsmodel.DomainBlock { +		d2 := new(gtsmodel.DomainBlock) +		*d2 = *d1 +		return d2 +	}, 1000) + +	// Set cache TTL and start sweep routine +	d.cache.SetTTL(time.Minute*5, false) +	d.cache.Start(time.Second * 10)  }  // normalizeDomain converts the given domain to lowercase @@ -49,76 +64,53 @@ func normalizeDomain(domain string) (out string, err error) {  }  func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) db.Error { -	domain, err := normalizeDomain(block.Domain) +	var err error + +	block.Domain, err = normalizeDomain(block.Domain)  	if err != nil {  		return err  	} -	block.Domain = domain -	// Attempt to insert new domain block -	if _, err := d.conn.NewInsert(). -		Model(block). -		Exec(ctx); err != nil { +	return d.cache.Store(block, func() error { +		_, err := d.conn.NewInsert(). +			Model(block). +			Exec(ctx)  		return d.conn.ProcessError(err) -	} - -	// Cache this domain block -	d.cache.Put(block.Domain, block) - -	return nil +	})  }  func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) {  	var err error +  	domain, err = normalizeDomain(domain)  	if err != nil {  		return nil, err  	} -	// Check for easy case, domain referencing *us* -	if domain == "" || domain == config.GetAccountDomain() { -		return nil, db.ErrNoEntries -	} - -	// Check for already cached rblock -	if block, ok := d.cache.GetByDomain(domain); ok { -		// A 'nil' return value is a sentinel value for no block -		if block == nil { +	return d.cache.Load("Domain", func() (*gtsmodel.DomainBlock, error) { +		// Check for easy case, domain referencing *us* +		if domain == "" || domain == config.GetAccountDomain() {  			return nil, db.ErrNoEntries  		} -		// Else, this block exists -		return block, nil -	} +		var block gtsmodel.DomainBlock -	block := >smodel.DomainBlock{} +		q := d.conn. +			NewSelect(). +			Model(&block). +			Where("? = ?", bun.Ident("domain_block.domain"), domain). +			Limit(1) +		if err := q.Scan(ctx); err != nil { +			return nil, d.conn.ProcessError(err) +		} -	q := d.conn. -		NewSelect(). -		Model(block). -		Where("? = ?", bun.Ident("domain_block.domain"), domain). -		Limit(1) - -	// Query database for domain block -	switch err := q.Scan(ctx); err { -	// No error, block found -	case nil: -		d.cache.Put(domain, block) -		return block, nil - -	// No error, simply not found -	case sql.ErrNoRows: -		d.cache.Put(domain, nil) -		return nil, db.ErrNoEntries - -	// Any other db error -	default: -		return nil, d.conn.ProcessError(err) -	} +		return &block, nil +	}, domain)  }  func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error {  	var err error +  	domain, err = normalizeDomain(domain)  	if err != nil {  		return err @@ -133,7 +125,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro  	}  	// Clear domain from cache -	d.cache.InvalidateByDomain(domain) +	d.cache.Invalidate(domain)  	return nil  } diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index 81374ce78..55e0ee3ff 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -23,7 +23,7 @@ import (  	"strings"  	"time" -	"github.com/superseriousbusiness/gotosocial/internal/cache" +	"codeberg.org/gruf/go-cache/v3/result"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log" @@ -33,8 +33,40 @@ import (  type emojiDB struct {  	conn          *DBConn -	emojiCache    *cache.EmojiCache -	categoryCache *cache.EmojiCategoryCache +	emojiCache    *result.Cache[*gtsmodel.Emoji] +	categoryCache *result.Cache[*gtsmodel.EmojiCategory] +} + +func (e *emojiDB) init() { +	// Initialize emoji result cache +	e.emojiCache = result.NewSized([]result.Lookup{ +		{Name: "ID"}, +		{Name: "URI"}, +		{Name: "Shortcode.Domain"}, +		{Name: "ImageStaticURL"}, +	}, func(e1 *gtsmodel.Emoji) *gtsmodel.Emoji { +		e2 := new(gtsmodel.Emoji) +		*e2 = *e1 +		return e2 +	}, 1000) + +	// Set cache TTL and start sweep routine +	e.emojiCache.SetTTL(time.Minute*5, false) +	e.emojiCache.Start(time.Second * 10) + +	// Initialize category result cache +	e.categoryCache = result.NewSized([]result.Lookup{ +		{Name: "ID"}, +		{Name: "Name"}, +	}, func(c1 *gtsmodel.EmojiCategory) *gtsmodel.EmojiCategory { +		c2 := new(gtsmodel.EmojiCategory) +		*c2 = *c1 +		return c2 +	}, 1000) + +	// Set cache TTL and start sweep routine +	e.categoryCache.SetTTL(time.Minute*5, false) +	e.categoryCache.Start(time.Second * 10)  }  func (e *emojiDB) newEmojiQ(emoji *gtsmodel.Emoji) *bun.SelectQuery { @@ -51,12 +83,10 @@ func (e *emojiDB) newEmojiCategoryQ(emojiCategory *gtsmodel.EmojiCategory) *bun.  }  func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) db.Error { -	if _, err := e.conn.NewInsert().Model(emoji).Exec(ctx); err != nil { +	return e.emojiCache.Store(emoji, func() error { +		_, err := e.conn.NewInsert().Model(emoji).Exec(ctx)  		return e.conn.ProcessError(err) -	} - -	e.emojiCache.Put(emoji) -	return nil +	})  }  func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, columns ...string) (*gtsmodel.Emoji, db.Error) { @@ -72,7 +102,7 @@ func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, column  		return nil, e.conn.ProcessError(err)  	} -	e.emojiCache.Invalidate(emoji.ID) +	e.emojiCache.Invalidate("ID", emoji.ID)  	return emoji, nil  } @@ -109,7 +139,7 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) db.Error {  		return err  	} -	e.emojiCache.Invalidate(id) +	e.emojiCache.Invalidate("ID", id)  	return nil  } @@ -252,33 +282,29 @@ func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.E  func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, db.Error) {  	return e.getEmoji(  		ctx, -		func() (*gtsmodel.Emoji, bool) { -			return e.emojiCache.GetByID(id) -		}, +		"ID",  		func(emoji *gtsmodel.Emoji) error {  			return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx)  		}, +		id,  	)  }  func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, db.Error) {  	return e.getEmoji(  		ctx, -		func() (*gtsmodel.Emoji, bool) { -			return e.emojiCache.GetByURI(uri) -		}, +		"URI",  		func(emoji *gtsmodel.Emoji) error {  			return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx)  		}, +		uri,  	)  }  func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, db.Error) {  	return e.getEmoji(  		ctx, -		func() (*gtsmodel.Emoji, bool) { -			return e.emojiCache.GetByShortcodeDomain(shortcode, domain) -		}, +		"Shortcode.Domain",  		func(emoji *gtsmodel.Emoji) error {  			q := e.newEmojiQ(emoji) @@ -292,31 +318,30 @@ func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode strin  			return q.Scan(ctx)  		}, +		shortcode, +		domain,  	)  }  func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, db.Error) {  	return e.getEmoji(  		ctx, -		func() (*gtsmodel.Emoji, bool) { -			return e.emojiCache.GetByImageStaticURL(imageStaticURL) -		}, +		"ImageStaticURL",  		func(emoji *gtsmodel.Emoji) error {  			return e.  				newEmojiQ(emoji).  				Where("? = ?", bun.Ident("emoji.image_static_url"), imageStaticURL).  				Scan(ctx)  		}, +		imageStaticURL,  	)  }  func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) db.Error { -	if _, err := e.conn.NewInsert().Model(emojiCategory).Exec(ctx); err != nil { +	return e.categoryCache.Store(emojiCategory, func() error { +		_, err := e.conn.NewInsert().Model(emojiCategory).Exec(ctx)  		return e.conn.ProcessError(err) -	} - -	e.categoryCache.Put(emojiCategory) -	return nil +	})  }  func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, db.Error) { @@ -338,45 +363,36 @@ func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCate  func (e *emojiDB) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, db.Error) {  	return e.getEmojiCategory(  		ctx, -		func() (*gtsmodel.EmojiCategory, bool) { -			return e.categoryCache.GetByID(id) -		}, +		"ID",  		func(emojiCategory *gtsmodel.EmojiCategory) error {  			return e.newEmojiCategoryQ(emojiCategory).Where("? = ?", bun.Ident("emoji_category.id"), id).Scan(ctx)  		}, +		id,  	)  }  func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, db.Error) {  	return e.getEmojiCategory(  		ctx, -		func() (*gtsmodel.EmojiCategory, bool) { -			return e.categoryCache.GetByName(name) -		}, +		"Name",  		func(emojiCategory *gtsmodel.EmojiCategory) error {  			return e.newEmojiCategoryQ(emojiCategory).Where("LOWER(?) = ?", bun.Ident("emoji_category.name"), strings.ToLower(name)).Scan(ctx)  		}, +		name,  	)  } -func (e *emojiDB) getEmoji(ctx context.Context, cacheGet func() (*gtsmodel.Emoji, bool), dbQuery func(*gtsmodel.Emoji) error) (*gtsmodel.Emoji, db.Error) { -	// Attempt to fetch cached emoji -	emoji, cached := cacheGet() - -	if !cached { -		emoji = >smodel.Emoji{} +func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, db.Error) { +	return e.emojiCache.Load(lookup, func() (*gtsmodel.Emoji, error) { +		var emoji gtsmodel.Emoji  		// Not cached! Perform database query -		err := dbQuery(emoji) -		if err != nil { +		if err := dbQuery(&emoji); err != nil {  			return nil, e.conn.ProcessError(err)  		} -		// Place in the cache -		e.emojiCache.Put(emoji) -	} - -	return emoji, nil +		return &emoji, nil +	}, keyParts...)  }  func (e *emojiDB) emojisFromIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, db.Error) { @@ -399,24 +415,17 @@ func (e *emojiDB) emojisFromIDs(ctx context.Context, emojiIDs []string) ([]*gtsm  	return emojis, nil  } -func (e *emojiDB) getEmojiCategory(ctx context.Context, cacheGet func() (*gtsmodel.EmojiCategory, bool), dbQuery func(*gtsmodel.EmojiCategory) error) (*gtsmodel.EmojiCategory, db.Error) { -	// Attempt to fetch cached emoji categories -	emojiCategory, cached := cacheGet() - -	if !cached { -		emojiCategory = >smodel.EmojiCategory{} +func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, db.Error) { +	return e.categoryCache.Load(lookup, func() (*gtsmodel.EmojiCategory, error) { +		var category gtsmodel.EmojiCategory  		// Not cached! Perform database query -		err := dbQuery(emojiCategory) -		if err != nil { +		if err := dbQuery(&category); err != nil {  			return nil, e.conn.ProcessError(err)  		} -		// Place in the cache -		e.categoryCache.Put(emojiCategory) -	} - -	return emojiCategory, nil +		return &category, nil +	}, keyParts...)  }  func (e *emojiDB) emojiCategoriesFromIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, db.Error) { diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go index 355078021..303e16484 100644 --- a/internal/db/bundb/mention.go +++ b/internal/db/bundb/mention.go @@ -20,8 +20,9 @@ package bundb  import (  	"context" +	"time" -	"codeberg.org/gruf/go-cache/v2" +	"codeberg.org/gruf/go-cache/v3/result"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log" @@ -30,7 +31,22 @@ import (  type mentionDB struct {  	conn  *DBConn -	cache cache.Cache[string, *gtsmodel.Mention] +	cache *result.Cache[*gtsmodel.Mention] +} + +func (m *mentionDB) init() { +	// Initialize notification result cache +	m.cache = result.NewSized([]result.Lookup{ +		{Name: "ID"}, +	}, func(m1 *gtsmodel.Mention) *gtsmodel.Mention { +		m2 := new(gtsmodel.Mention) +		*m2 = *m1 +		return m2 +	}, 1000) + +	// Set cache TTL and start sweep routine +	m.cache.SetTTL(time.Minute*5, false) +	m.cache.Start(time.Second * 10)  }  func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery { @@ -42,27 +58,19 @@ func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {  		Relation("TargetAccount")  } -func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { -	mention := gtsmodel.Mention{} - -	q := m.newMentionQ(&mention). -		Where("? = ?", bun.Ident("mention.id"), id) - -	if err := q.Scan(ctx); err != nil { -		return nil, m.conn.ProcessError(err) -	} +func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { +	return m.cache.Load("ID", func() (*gtsmodel.Mention, error) { +		var mention gtsmodel.Mention -	copy := mention -	m.cache.Set(mention.ID, ©) +		q := m.newMentionQ(&mention). +			Where("? = ?", bun.Ident("mention.id"), id) -	return &mention, nil -} +		if err := q.Scan(ctx); err != nil { +			return nil, m.conn.ProcessError(err) +		} -func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { -	if mention, ok := m.cache.Get(id); ok { -		return mention, nil -	} -	return m.getMentionDB(ctx, id) +		return &mention, nil +	}, id)  }  func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) { diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index 69e3cf39f..1874f81ea 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -20,8 +20,9 @@ package bundb  import (  	"context" +	"time" -	"codeberg.org/gruf/go-cache/v2" +	"codeberg.org/gruf/go-cache/v3/result"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log" @@ -30,31 +31,40 @@ import (  type notificationDB struct {  	conn  *DBConn -	cache cache.Cache[string, *gtsmodel.Notification] +	cache *result.Cache[*gtsmodel.Notification]  } -func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) { -	if notification, ok := n.cache.Get(id); ok { -		return notification, nil -	} - -	dst := gtsmodel.Notification{ID: id} - -	q := n.conn.NewSelect(). -		Model(&dst). -		Relation("OriginAccount"). -		Relation("TargetAccount"). -		Relation("Status"). -		Where("? = ?", bun.Ident("notification.id"), id) - -	if err := q.Scan(ctx); err != nil { -		return nil, n.conn.ProcessError(err) -	} +func (n *notificationDB) init() { +	// Initialize notification result cache +	n.cache = result.NewSized([]result.Lookup{ +		{Name: "ID"}, +	}, func(n1 *gtsmodel.Notification) *gtsmodel.Notification { +		n2 := new(gtsmodel.Notification) +		*n2 = *n1 +		return n2 +	}, 1000) + +	// Set cache TTL and start sweep routine +	n.cache.SetTTL(time.Minute*5, false) +	n.cache.Start(time.Second * 10) +} -	copy := dst -	n.cache.Set(id, ©) +func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) { +	return n.cache.Load("ID", func() (*gtsmodel.Notification, error) { +		var notif gtsmodel.Notification + +		q := n.conn.NewSelect(). +			Model(¬if). +			Relation("OriginAccount"). +			Relation("TargetAccount"). +			Relation("Status"). +			Where("? = ?", bun.Ident("notification.id"), id) +		if err := q.Scan(ctx); err != nil { +			return nil, n.conn.ProcessError(err) +		} -	return &dst, nil +		return ¬if, nil +	}, id)  }  func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) { diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index bc72c2849..b4ae40607 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -25,7 +25,7 @@ import (  	"errors"  	"time" -	"github.com/superseriousbusiness/gotosocial/internal/cache" +	"codeberg.org/gruf/go-cache/v3/result"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log" @@ -33,15 +33,28 @@ import (  )  type statusDB struct { -	conn  *DBConn -	cache *cache.StatusCache - -	// TODO: keep method definitions in same place but instead have receiver -	//       all point to one single "db" type, so they can all share methods -	//       and caches where necessary +	conn     *DBConn +	cache    *result.Cache[*gtsmodel.Status]  	accounts *accountDB  } +func (s *statusDB) init() { +	// Initialize status result cache +	s.cache = result.NewSized([]result.Lookup{ +		{Name: "ID"}, +		{Name: "URI"}, +		{Name: "URL"}, +	}, func(s1 *gtsmodel.Status) *gtsmodel.Status { +		s2 := new(gtsmodel.Status) +		*s2 = *s1 +		return s2 +	}, 1000) + +	// Set cache TTL and start sweep routine +	s.cache.SetTTL(time.Minute*5, false) +	s.cache.Start(time.Second * 10) +} +  func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {  	return s.conn.  		NewSelect(). @@ -68,61 +81,62 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {  func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {  	return s.getStatus(  		ctx, -		func() (*gtsmodel.Status, bool) { -			return s.cache.GetByID(id) -		}, +		"ID",  		func(status *gtsmodel.Status) error {  			return s.newStatusQ(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx)  		}, +		id,  	)  }  func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {  	return s.getStatus(  		ctx, -		func() (*gtsmodel.Status, bool) { -			return s.cache.GetByURI(uri) -		}, +		"URI",  		func(status *gtsmodel.Status) error {  			return s.newStatusQ(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx)  		}, +		uri,  	)  }  func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) {  	return s.getStatus(  		ctx, -		func() (*gtsmodel.Status, bool) { -			return s.cache.GetByURL(url) -		}, +		"URL",  		func(status *gtsmodel.Status) error {  			return s.newStatusQ(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx)  		}, +		url,  	)  } -func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Status, bool), dbQuery func(*gtsmodel.Status) error) (*gtsmodel.Status, db.Error) { -	// Attempt to fetch cached status -	status, cached := cacheGet() - -	if !cached { -		status = >smodel.Status{} +func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, db.Error) { +	// Fetch status from database cache with loader callback +	status, err := s.cache.Load(lookup, func() (*gtsmodel.Status, error) { +		var status gtsmodel.Status  		// Not cached! Perform database query -		if err := dbQuery(status); err != nil { +		if err := dbQuery(&status); err != nil {  			return nil, s.conn.ProcessError(err)  		}  		// If there is boosted, fetch from DB also  		if status.BoostOfID != "" { -			boostOf, err := s.GetStatusByID(ctx, status.BoostOfID) -			if err == nil { -				status.BoostOf = boostOf +			status.BoostOf = >smodel.Status{} +			err := s.newStatusQ(status.BoostOf). +				Where("? = ?", bun.Ident("status.id"), status.BoostOfID). +				Scan(ctx) +			if err != nil { +				return nil, s.conn.ProcessError(err)  			}  		} -		// Place in the cache -		s.cache.Put(status) +		return &status, nil +	}, keyParts...) +	if err != nil { +		// error already processed +		return nil, err  	}  	// Set the status author account @@ -137,73 +151,66 @@ func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Sta  }  func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { -	err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { -		// create links between this status and any emojis it uses -		for _, i := range status.EmojiIDs { -			if _, err := tx. -				NewInsert(). -				Model(>smodel.StatusToEmoji{ -					StatusID: status.ID, -					EmojiID:  i, -				}).Exec(ctx); err != nil { -				err = s.conn.errProc(err) -				if !errors.Is(err, db.ErrAlreadyExists) { -					return err +	return s.cache.Store(status, func() error { +		// It is safe to run this database transaction within cache.Store +		// as the cache does not attempt a mutex lock until AFTER hook. +		// +		return s.conn.RunInTx(ctx, func(tx bun.Tx) error { +			// create links between this status and any emojis it uses +			for _, i := range status.EmojiIDs { +				if _, err := tx. +					NewInsert(). +					Model(>smodel.StatusToEmoji{ +						StatusID: status.ID, +						EmojiID:  i, +					}).Exec(ctx); err != nil { +					err = s.conn.ProcessError(err) +					if !errors.Is(err, db.ErrAlreadyExists) { +						return err +					}  				}  			} -		} -		// create links between this status and any tags it uses -		for _, i := range status.TagIDs { -			if _, err := tx. -				NewInsert(). -				Model(>smodel.StatusToTag{ -					StatusID: status.ID, -					TagID:    i, -				}).Exec(ctx); err != nil { -				err = s.conn.errProc(err) -				if !errors.Is(err, db.ErrAlreadyExists) { -					return err +			// create links between this status and any tags it uses +			for _, i := range status.TagIDs { +				if _, err := tx. +					NewInsert(). +					Model(>smodel.StatusToTag{ +						StatusID: status.ID, +						TagID:    i, +					}).Exec(ctx); err != nil { +					err = s.conn.ProcessError(err) +					if !errors.Is(err, db.ErrAlreadyExists) { +						return err +					}  				}  			} -		} -		// change the status ID of the media attachments to the new status -		for _, a := range status.Attachments { -			a.StatusID = status.ID -			a.UpdatedAt = time.Now() -			if _, err := tx. -				NewUpdate(). -				Model(a). -				Where("? = ?", bun.Ident("media_attachment.id"), a.ID). -				Exec(ctx); err != nil { -				err = s.conn.errProc(err) -				if !errors.Is(err, db.ErrAlreadyExists) { -					return err +			// change the status ID of the media attachments to the new status +			for _, a := range status.Attachments { +				a.StatusID = status.ID +				a.UpdatedAt = time.Now() +				if _, err := tx. +					NewUpdate(). +					Model(a). +					Where("? = ?", bun.Ident("media_attachment.id"), a.ID). +					Exec(ctx); err != nil { +					err = s.conn.ProcessError(err) +					if !errors.Is(err, db.ErrAlreadyExists) { +						return err +					}  				}  			} -		} -		// Finally, insert the status -		if _, err := tx. -			NewInsert(). -			Model(status). -			Exec(ctx); err != nil { +			// Finally, insert the status +			_, err := tx.NewInsert().Model(status).Exec(ctx)  			return err -		} - -		return nil +		})  	}) -	if err != nil { -		return s.conn.ProcessError(err) -	} - -	s.cache.Put(status) -	return nil  } -func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*gtsmodel.Status, db.Error) { -	err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { +func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) db.Error { +	if err := s.conn.RunInTx(ctx, func(tx bun.Tx) error {  		// create links between this status and any emojis it uses  		for _, i := range status.EmojiIDs {  			if _, err := tx. @@ -212,7 +219,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*  					StatusID: status.ID,  					EmojiID:  i,  				}).Exec(ctx); err != nil { -				err = s.conn.errProc(err) +				err = s.conn.ProcessError(err)  				if !errors.Is(err, db.ErrAlreadyExists) {  					return err  				} @@ -227,14 +234,14 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*  					StatusID: status.ID,  					TagID:    i,  				}).Exec(ctx); err != nil { -				err = s.conn.errProc(err) +				err = s.conn.ProcessError(err)  				if !errors.Is(err, db.ErrAlreadyExists) {  					return err  				}  			}  		} -		// change the status ID of the media attachments to this status +		// change the status ID of the media attachments to the new status  		for _, a := range status.Attachments {  			a.StatusID = status.ID  			a.UpdatedAt = time.Now() @@ -243,31 +250,31 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*  				Model(a).  				Where("? = ?", bun.Ident("media_attachment.id"), a.ID).  				Exec(ctx); err != nil { -				return err +				err = s.conn.ProcessError(err) +				if !errors.Is(err, db.ErrAlreadyExists) { +					return err +				}  			}  		} -		// Finally, update the status itself -		if _, err := tx. +		// Finally, insert the status +		_, err := tx.  			NewUpdate().  			Model(status).  			Where("? = ?", bun.Ident("status.id"), status.ID). -			Exec(ctx); err != nil { -			return err -		} - -		return nil -	}) -	if err != nil { -		return nil, s.conn.ProcessError(err) +			Exec(ctx) +		return err +	}); err != nil { +		return err  	} -	s.cache.Put(status) -	return status, nil +	// Drop any old value from cache by this ID +	s.cache.Invalidate("ID", status.ID) +	return nil  }  func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { -	err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { +	if err := s.conn.RunInTx(ctx, func(tx bun.Tx) error {  		// delete links between this status and any emojis it uses  		if _, err := tx.  			NewDelete(). @@ -296,36 +303,41 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {  		}  		return nil -	}) -	if err != nil { -		return s.conn.ProcessError(err) +	}); err != nil { +		return err  	} -	s.cache.Invalidate(id) +	// Drop any old value from cache by this ID +	s.cache.Invalidate("ID", id)  	return nil  }  func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { -	parents := []*gtsmodel.Status{} -	s.statusParent(ctx, status, &parents, onlyDirect) -	return parents, nil -} - -func (s *statusDB) statusParent(ctx context.Context, status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) { -	if status.InReplyToID == "" { -		return +	if onlyDirect { +		// Only want the direct parent, no further than first level +		parent, err := s.GetStatusByID(ctx, status.InReplyToID) +		if err != nil { +			return nil, err +		} +		return []*gtsmodel.Status{parent}, nil  	} -	parentStatus, err := s.GetStatusByID(ctx, status.InReplyToID) -	if err == nil { -		*foundStatuses = append(*foundStatuses, parentStatus) -	} +	var parents []*gtsmodel.Status -	if onlyDirect { -		return +	for id := status.InReplyToID; id != ""; { +		parent, err := s.GetStatusByID(ctx, id) +		if err != nil { +			return nil, err +		} + +		// Append parent to slice +		parents = append(parents, parent) + +		// Set the next parent ID +		id = parent.InReplyToID  	} -	s.statusParent(ctx, parentStatus, foundStatuses, false) +	return parents, nil  }  func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) { @@ -350,7 +362,7 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu  }  func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { -	childIDs := []string{} +	var childIDs []string  	q := s.conn.  		NewSelect(). @@ -471,6 +483,7 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status)  	if err := q.Scan(ctx); err != nil {  		return nil, s.conn.ProcessError(err)  	} +  	return faves, nil  } diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go index 9b6365621..066f55234 100644 --- a/internal/db/bundb/timeline_test.go +++ b/internal/db/bundb/timeline_test.go @@ -35,44 +35,52 @@ type TimelineTestSuite struct {  }  func (suite *TimelineTestSuite) TestGetPublicTimeline() { -	s, err := suite.db.GetPublicTimeline(context.Background(), "", "", "", 20, false) +	ctx := context.Background() + +	s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)  	suite.NoError(err)  	suite.Len(s, 6)  }  func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() { +	ctx := context.Background() +  	futureStatus := getFutureStatus() -	if err := suite.db.Put(context.Background(), futureStatus); err != nil { -		suite.FailNow(err.Error()) -	} +	err := suite.db.PutStatus(ctx, futureStatus) +	suite.NoError(err) -	s, err := suite.db.GetPublicTimeline(context.Background(), "", "", "", 20, false) +	s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)  	suite.NoError(err) +	suite.NotContains(s, futureStatus)  	suite.Len(s, 6)  }  func (suite *TimelineTestSuite) TestGetHomeTimeline() { +	ctx := context.Background() +  	viewingAccount := suite.testAccounts["local_account_1"] -	s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false) +	s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false)  	suite.NoError(err)  	suite.Len(s, 16)  }  func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() { +	ctx := context.Background() +  	viewingAccount := suite.testAccounts["local_account_1"]  	futureStatus := getFutureStatus() -	if err := suite.db.Put(context.Background(), futureStatus); err != nil { -		suite.FailNow(err.Error()) -	} +	err := suite.db.PutStatus(ctx, futureStatus) +	suite.NoError(err)  	s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false)  	suite.NoError(err) +	suite.NotContains(s, futureStatus)  	suite.Len(s, 16)  } diff --git a/internal/db/bundb/tombstone.go b/internal/db/bundb/tombstone.go index 7ce3327a7..309a39fd3 100644 --- a/internal/db/bundb/tombstone.go +++ b/internal/db/bundb/tombstone.go @@ -43,7 +43,7 @@ func (t *tombstoneDB) init() {  		t2 := new(gtsmodel.Tombstone)  		*t2 = *t1  		return t2 -	}, 1000) +	}, 100)  	// Set cache TTL and start sweep routine  	t.cache.SetTTL(time.Minute*5, false) diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go index aa2f4c2c8..d9b281a6f 100644 --- a/internal/db/bundb/user.go +++ b/internal/db/bundb/user.go @@ -22,7 +22,7 @@ import (  	"context"  	"time" -	"github.com/superseriousbusiness/gotosocial/internal/cache" +	"codeberg.org/gruf/go-cache/v3/result"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/uptrace/bun" @@ -30,111 +30,121 @@ import (  type userDB struct {  	conn  *DBConn -	cache *cache.UserCache +	cache *result.Cache[*gtsmodel.User]  } -func (u *userDB) newUserQ(user *gtsmodel.User) *bun.SelectQuery { -	return u.conn. -		NewSelect(). -		Model(user). -		Relation("Account") +func (u *userDB) init() { +	// Initialize user result cache +	u.cache = result.NewSized([]result.Lookup{ +		{Name: "ID"}, +		{Name: "AccountID"}, +		{Name: "Email"}, +		{Name: "ConfirmationToken"}, +	}, func(u1 *gtsmodel.User) *gtsmodel.User { +		u2 := new(gtsmodel.User) +		*u2 = *u1 +		return u2 +	}, 1000) + +	// Set cache TTL and start sweep routine +	u.cache.SetTTL(time.Minute*5, false) +	u.cache.Start(time.Second * 10)  } -func (u *userDB) getUser(ctx context.Context, cacheGet func() (*gtsmodel.User, bool), dbQuery func(*gtsmodel.User) error) (*gtsmodel.User, db.Error) { -	// Attempt to fetch cached user -	user, cached := cacheGet() +func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) { +	return u.cache.Load("ID", func() (*gtsmodel.User, error) { +		var user gtsmodel.User -	if !cached { -		user = >smodel.User{} +		q := u.conn. +			NewSelect(). +			Model(&user). +			Relation("Account"). +			Where("? = ?", bun.Ident("user.id"), id) -		// Not cached! Perform database query -		err := dbQuery(user) -		if err != nil { +		if err := q.Scan(ctx); err != nil {  			return nil, u.conn.ProcessError(err)  		} -		// Place in the cache -		u.cache.Put(user) -	} - -	return user, nil -} - -func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) { -	return u.getUser( -		ctx, -		func() (*gtsmodel.User, bool) { -			return u.cache.GetByID(id) -		}, -		func(user *gtsmodel.User) error { -			return u.newUserQ(user).Where("? = ?", bun.Ident("user.id"), id).Scan(ctx) -		}, -	) +		return &user, nil +	}, id)  }  func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, db.Error) { -	return u.getUser( -		ctx, -		func() (*gtsmodel.User, bool) { -			return u.cache.GetByAccountID(accountID) -		}, -		func(user *gtsmodel.User) error { -			return u.newUserQ(user).Where("? = ?", bun.Ident("user.account_id"), accountID).Scan(ctx) -		}, -	) +	return u.cache.Load("AccountID", func() (*gtsmodel.User, error) { +		var user gtsmodel.User + +		q := u.conn. +			NewSelect(). +			Model(&user). +			Relation("Account"). +			Where("? = ?", bun.Ident("user.account_id"), accountID) + +		if err := q.Scan(ctx); err != nil { +			return nil, u.conn.ProcessError(err) +		} + +		return &user, nil +	}, accountID)  }  func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, db.Error) { -	return u.getUser( -		ctx, -		func() (*gtsmodel.User, bool) { -			return u.cache.GetByEmail(emailAddress) -		}, -		func(user *gtsmodel.User) error { -			return u.newUserQ(user).Where("? = ?", bun.Ident("user.email"), emailAddress).Scan(ctx) -		}, -	) +	return u.cache.Load("Email", func() (*gtsmodel.User, error) { +		var user gtsmodel.User + +		q := u.conn. +			NewSelect(). +			Model(&user). +			Relation("Account"). +			Where("? = ?", bun.Ident("user.email"), emailAddress) + +		if err := q.Scan(ctx); err != nil { +			return nil, u.conn.ProcessError(err) +		} + +		return &user, nil +	}, emailAddress)  }  func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, db.Error) { -	return u.getUser( -		ctx, -		func() (*gtsmodel.User, bool) { -			return u.cache.GetByConfirmationToken(confirmationToken) -		}, -		func(user *gtsmodel.User) error { -			return u.newUserQ(user).Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken).Scan(ctx) -		}, -	) -} +	return u.cache.Load("ConfirmationToken", func() (*gtsmodel.User, error) { +		var user gtsmodel.User -func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) (*gtsmodel.User, db.Error) { -	if _, err := u.conn. -		NewInsert(). -		Model(user). -		Exec(ctx); err != nil { -		return nil, u.conn.ProcessError(err) -	} +		q := u.conn. +			NewSelect(). +			Model(&user). +			Relation("Account"). +			Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken) -	u.cache.Put(user) -	return user, nil +		if err := q.Scan(ctx); err != nil { +			return nil, u.conn.ProcessError(err) +		} + +		return &user, nil +	}, confirmationToken)  } -func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, db.Error) { +func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) db.Error { +	return u.cache.Store(user, func() error { +		_, err := u.conn. +			NewInsert(). +			Model(user). +			Exec(ctx) +		return u.conn.ProcessError(err) +	}) +} + +func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User) db.Error {  	// Update the user's last-updated  	user.UpdatedAt = time.Now() -	if _, err := u.conn. -		NewUpdate(). -		Model(user). -		Where("? = ?", bun.Ident("user.id"), user.ID). -		Column(columns...). -		Exec(ctx); err != nil { -		return nil, u.conn.ProcessError(err) -	} - -	u.cache.Invalidate(user.ID) -	return user, nil +	return u.cache.Store(user, func() error { +		_, err := u.conn. +			NewUpdate(). +			Model(user). +			Where("? = ?", bun.Ident("user.id"), user.ID). +			Exec(ctx) +		return u.conn.ProcessError(err) +	})  }  func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error { @@ -146,6 +156,7 @@ func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error {  		return u.conn.ProcessError(err)  	} -	u.cache.Invalidate(userID) +	// Invalidate user from cache +	u.cache.Invalidate("ID", userID)  	return nil  } diff --git a/internal/db/bundb/user_test.go b/internal/db/bundb/user_test.go index 6ad59fc8e..18f67dde5 100644 --- a/internal/db/bundb/user_test.go +++ b/internal/db/bundb/user_test.go @@ -50,21 +50,20 @@ func (suite *UserTestSuite) TestGetUserByAccountID() {  func (suite *UserTestSuite) TestUpdateUserSelectedColumns() {  	testUser := suite.testUsers["local_account_1"] -	user := >smodel.User{ -		ID:     testUser.ID, -		Email:  "whatever", -		Locale: "es", -	} -	user, err := suite.db.UpdateUser(context.Background(), user, "email", "locale") +	updateUser := new(gtsmodel.User) +	*updateUser = *testUser +	updateUser.Email = "whatever" +	updateUser.Locale = "es" + +	err := suite.db.UpdateUser(context.Background(), updateUser)  	suite.NoError(err) -	suite.NotNil(user)  	dbUser, err := suite.db.GetUserByID(context.Background(), testUser.ID)  	suite.NoError(err)  	suite.NotNil(dbUser) -	suite.Equal("whatever", dbUser.Email) -	suite.Equal("es", dbUser.Locale) +	suite.Equal(updateUser.Email, dbUser.Email) +	suite.Equal(updateUser.Locale, dbUser.Locale)  	suite.Equal(testUser.AccountID, dbUser.AccountID)  } | 
