diff options
author | 2022-11-15 18:45:15 +0000 | |
---|---|---|
committer | 2022-11-15 18:45:15 +0000 | |
commit | 8598dea98b872647393117704659878d9b38d4fc (patch) | |
tree | 1940168912dc7f54af723439dbc9f6e0a42f30ae /internal | |
parent | [docs] Both HTTP proxies and NAT can cause rate limiting issues (#1053) (diff) | |
download | gotosocial-8598dea98b872647393117704659878d9b38d4fc.tar.xz |
[chore] update database caching library (#1040)
* convert most of the caches to use result.Cache{}
* add caching of emojis
* fix issues causing failing tests
* update go-cache/v2 instances with v3
* fix getnotification
* add a note about the left-in StatusCreate comment
* update EmojiCategory db access to use new result.Cache{}
* fix possible panic in getstatusparents
* further proof that kim is not stinky
Diffstat (limited to 'internal')
41 files changed, 623 insertions, 1607 deletions
diff --git a/internal/api/client/auth/authorize_test.go b/internal/api/client/auth/authorize_test.go index fcc4b8caa..e3e4ce9ee 100644 --- a/internal/api/client/auth/authorize_test.go +++ b/internal/api/client/auth/authorize_test.go @@ -20,7 +20,7 @@ type AuthAuthorizeTestSuite struct { type authorizeHandlerTestCase struct { description string - mutateUserAccount func(*gtsmodel.User, *gtsmodel.Account) []string + mutateUserAccount func(*gtsmodel.User, *gtsmodel.Account) expectedStatusCode int expectedLocationHeader string } @@ -29,44 +29,40 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() { tests := []authorizeHandlerTestCase{ { description: "user has their email unconfirmed", - mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) []string { + mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) { // nothing to do, weed_lord420 already has their email unconfirmed - return nil }, expectedStatusCode: http.StatusSeeOther, expectedLocationHeader: auth.CheckYourEmailPath, }, { description: "user has their email confirmed but is not approved", - mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) []string { + mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) { user.ConfirmedAt = time.Now() user.Email = user.UnconfirmedEmail - return []string{"confirmed_at", "email"} }, expectedStatusCode: http.StatusSeeOther, expectedLocationHeader: auth.WaitForApprovalPath, }, { description: "user has their email confirmed and is approved, but User entity has been disabled", - mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) []string { + mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) { user.ConfirmedAt = time.Now() user.Email = user.UnconfirmedEmail user.Approved = testrig.TrueBool() user.Disabled = testrig.TrueBool() - return []string{"confirmed_at", "email", "approved", "disabled"} }, expectedStatusCode: http.StatusSeeOther, expectedLocationHeader: auth.AccountDisabledPath, }, { description: "user has their email confirmed and is approved, but Account entity has been suspended", - mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) []string { + mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) { user.ConfirmedAt = time.Now() user.Email = user.UnconfirmedEmail user.Approved = testrig.TrueBool() user.Disabled = testrig.FalseBool() account.SuspendedAt = time.Now() - return []string{"confirmed_at", "email", "approved", "disabled"} }, expectedStatusCode: http.StatusSeeOther, expectedLocationHeader: auth.AccountDisabledPath, @@ -81,6 +77,7 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() { *user = *suite.testUsers["unconfirmed_account"] *account = *suite.testAccounts["unconfirmed_account"] + user.SignInCount++ // cannot be 0 or fails NULL constraint testSession := sessions.Default(ctx) testSession.Set(sessionUserID, user.ID) @@ -89,14 +86,13 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() { panic(fmt.Errorf("failed on case %s: %w", testCase.description, err)) } - updatingColumns := testCase.mutateUserAccount(user, account) + testCase.mutateUserAccount(user, account) testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, *user.Disabled, account.SuspendedAt) - updatingColumns = append(updatingColumns, "updated_at") - _, err := suite.db.UpdateUser(context.Background(), user, updatingColumns...) + err := suite.db.UpdateUser(context.Background(), user) suite.NoError(err) - _, err = suite.db.UpdateAccount(context.Background(), account) + err = suite.db.UpdateAccount(context.Background(), account) suite.NoError(err) // call the handler diff --git a/internal/api/client/status/statuscreate.go b/internal/api/client/status/statuscreate.go index 3b2ee1e05..c1427411d 100644 --- a/internal/api/client/status/statuscreate.go +++ b/internal/api/client/status/statuscreate.go @@ -90,6 +90,15 @@ func (m *Module) StatusCreatePOSTHandler(c *gin.Context) { return } + // DO NOT COMMIT THIS UNCOMMENTED, IT WILL CAUSE MASS CHAOS. + // this is being left in as an ode to kim's shitposting. + // + // user := authed.Account.DisplayName + // if user == "" { + // user = authed.Account.Username + // } + // form.Status += "\n\nsent from " + user + "'s iphone\n" + if err := validateCreateStatus(form); err != nil { api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) return diff --git a/internal/api/client/status/statuscreate_test.go b/internal/api/client/status/statuscreate_test.go index 78d025be1..9b570ba18 100644 --- a/internal/api/client/status/statuscreate_test.go +++ b/internal/api/client/status/statuscreate_test.go @@ -106,8 +106,9 @@ func (suite *StatusCreateTestSuite) TestPostNewStatusMarkdown() { // set default post language of account 1 to markdown testAccount := suite.testAccounts["local_account_1"] testAccount.StatusFormat = "markdown" + a := testAccount - a, err := suite.db.UpdateAccount(context.Background(), testAccount) + err := suite.db.UpdateAccount(context.Background(), a) if err != nil { suite.FailNow(err.Error()) } @@ -149,9 +150,8 @@ func (suite *StatusCreateTestSuite) TestPostNewStatusMarkdown() { func (suite *StatusCreateTestSuite) TestMentionUnknownAccount() { // first remove remote account 1 from the database so it gets looked up again remoteAccount := suite.testAccounts["remote_account_1"] - if err := suite.db.DeleteByID(context.Background(), remoteAccount.ID, >smodel.Account{}); err != nil { - panic(err) - } + err := suite.db.DeleteAccount(context.Background(), remoteAccount.ID) + suite.NoError(err) t := suite.testTokens["local_account_1"] oauthToken := oauth.DBTokenToToken(t) diff --git a/internal/cache/account.go b/internal/cache/account.go deleted file mode 100644 index c25db42ce..000000000 --- a/internal/cache/account.go +++ /dev/null @@ -1,171 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package cache - -import ( - "time" - - "codeberg.org/gruf/go-cache/v2" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -// AccountCache is a cache wrapper to provide URL and URI lookups for gtsmodel.Account -type AccountCache struct { - cache cache.LookupCache[string, string, *gtsmodel.Account] -} - -// NewAccountCache returns a new instantiated AccountCache object -func NewAccountCache() *AccountCache { - c := &AccountCache{} - c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.Account]{ - RegisterLookups: func(lm *cache.LookupMap[string, string]) { - lm.RegisterLookup("uri") - lm.RegisterLookup("url") - lm.RegisterLookup("pubkeyid") - lm.RegisterLookup("usernamedomain") - }, - - AddLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) { - if uri := acc.URI; uri != "" { - lm.Set("uri", uri, acc.ID) - } - if url := acc.URL; url != "" { - lm.Set("url", url, acc.ID) - } - lm.Set("pubkeyid", acc.PublicKeyURI, acc.ID) - lm.Set("usernamedomain", usernameDomainKey(acc.Username, acc.Domain), acc.ID) - }, - - DeleteLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) { - if uri := acc.URI; uri != "" { - lm.Delete("uri", uri) - } - if url := acc.URL; url != "" { - lm.Delete("url", url) - } - lm.Delete("pubkeyid", acc.PublicKeyURI) - lm.Delete("usernamedomain", usernameDomainKey(acc.Username, acc.Domain)) - }, - }) - c.cache.SetTTL(time.Minute*5, false) - c.cache.Start(time.Second * 10) - return c -} - -// GetByID attempts to fetch a account from the cache by its ID, you will receive a copy for thread-safety -func (c *AccountCache) GetByID(id string) (*gtsmodel.Account, bool) { - return c.cache.Get(id) -} - -// GetByURL attempts to fetch a account from the cache by its URL, you will receive a copy for thread-safety -func (c *AccountCache) GetByURL(url string) (*gtsmodel.Account, bool) { - return c.cache.GetBy("url", url) -} - -// GetByURI attempts to fetch a account from the cache by its URI, you will receive a copy for thread-safety -func (c *AccountCache) GetByURI(uri string) (*gtsmodel.Account, bool) { - return c.cache.GetBy("uri", uri) -} - -// GettByUsernameDomain attempts to fetch an account from the cache by its username@domain combo (or just username), you will receive a copy for thread-safety. -func (c *AccountCache) GetByUsernameDomain(username string, domain string) (*gtsmodel.Account, bool) { - return c.cache.GetBy("usernamedomain", usernameDomainKey(username, domain)) -} - -// GetByPubkeyID attempts to fetch an account from the cache by its public key URI (ID), you will receive a copy for thread-safety. -func (c *AccountCache) GetByPubkeyID(id string) (*gtsmodel.Account, bool) { - return c.cache.GetBy("pubkeyid", id) -} - -// Put places a account in the cache, ensuring that the object place is a copy for thread-safety -func (c *AccountCache) Put(account *gtsmodel.Account) { - if account == nil || account.ID == "" { - panic("invalid account") - } - c.cache.Set(account.ID, copyAccount(account)) -} - -// Invalidate removes (invalidates) one account from the cache by its ID. -func (c *AccountCache) Invalidate(id string) { - c.cache.Invalidate(id) -} - -// copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects. -// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr) -// this should be a relatively cheap process -func copyAccount(account *gtsmodel.Account) *gtsmodel.Account { - return >smodel.Account{ - ID: account.ID, - Username: account.Username, - Domain: account.Domain, - AvatarMediaAttachmentID: account.AvatarMediaAttachmentID, - AvatarMediaAttachment: nil, - AvatarRemoteURL: account.AvatarRemoteURL, - HeaderMediaAttachmentID: account.HeaderMediaAttachmentID, - HeaderMediaAttachment: nil, - HeaderRemoteURL: account.HeaderRemoteURL, - DisplayName: account.DisplayName, - EmojiIDs: account.EmojiIDs, - Emojis: nil, - Fields: account.Fields, - Note: account.Note, - NoteRaw: account.NoteRaw, - Memorial: copyBoolPtr(account.Memorial), - MovedToAccountID: account.MovedToAccountID, - Bot: copyBoolPtr(account.Bot), - CreatedAt: account.CreatedAt, - UpdatedAt: account.UpdatedAt, - Reason: account.Reason, - Locked: copyBoolPtr(account.Locked), - Discoverable: copyBoolPtr(account.Discoverable), - Privacy: account.Privacy, - Sensitive: copyBoolPtr(account.Sensitive), - Language: account.Language, - StatusFormat: account.StatusFormat, - CustomCSS: account.CustomCSS, - URI: account.URI, - URL: account.URL, - LastWebfingeredAt: account.LastWebfingeredAt, - InboxURI: account.InboxURI, - SharedInboxURI: account.SharedInboxURI, - OutboxURI: account.OutboxURI, - FollowingURI: account.FollowingURI, - FollowersURI: account.FollowersURI, - FeaturedCollectionURI: account.FeaturedCollectionURI, - ActorType: account.ActorType, - AlsoKnownAs: account.AlsoKnownAs, - PrivateKey: account.PrivateKey, - PublicKey: account.PublicKey, - PublicKeyURI: account.PublicKeyURI, - SensitizedAt: account.SensitizedAt, - SilencedAt: account.SilencedAt, - SuspendedAt: account.SuspendedAt, - HideCollections: copyBoolPtr(account.HideCollections), - SuspensionOrigin: account.SuspensionOrigin, - EnableRSS: copyBoolPtr(account.EnableRSS), - } -} - -func usernameDomainKey(username string, domain string) string { - u := "@" + username - if domain != "" { - return u + "@" + domain - } - return u -} diff --git a/internal/cache/account_test.go b/internal/cache/account_test.go deleted file mode 100644 index d373e5f1d..000000000 --- a/internal/cache/account_test.go +++ /dev/null @@ -1,96 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package cache_test - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/suite" - "github.com/superseriousbusiness/gotosocial/internal/cache" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/testrig" -) - -type AccountCacheTestSuite struct { - suite.Suite - data map[string]*gtsmodel.Account - cache *cache.AccountCache -} - -func (suite *AccountCacheTestSuite) SetupSuite() { - suite.data = testrig.NewTestAccounts() -} - -func (suite *AccountCacheTestSuite) SetupTest() { - suite.cache = cache.NewAccountCache() -} - -func (suite *AccountCacheTestSuite) TearDownTest() { - suite.data = nil - suite.cache = nil -} - -func (suite *AccountCacheTestSuite) TestAccountCache() { - for _, account := range suite.data { - // Place in the cache - suite.cache.Put(account) - } - - for _, account := range suite.data { - var ok bool - var check *gtsmodel.Account - - // Check we can retrieve - check, ok = suite.cache.GetByID(account.ID) - if !ok && !accountIs(account, check) { - suite.Fail(fmt.Sprintf("Failed to fetch expected account with ID: %s", account.ID)) - } - check, ok = suite.cache.GetByURI(account.URI) - if account.URI != "" && !ok && !accountIs(account, check) { - suite.Fail(fmt.Sprintf("Failed to fetch expected account with URI: %s", account.URI)) - } - check, ok = suite.cache.GetByURL(account.URL) - if account.URL != "" && !ok && !accountIs(account, check) { - suite.Fail(fmt.Sprintf("Failed to fetch expected account with URL: %s", account.URL)) - } - check, ok = suite.cache.GetByPubkeyID(account.PublicKeyURI) - if account.PublicKeyURI != "" && !ok && !accountIs(account, check) { - suite.Fail(fmt.Sprintf("Failed to fetch expected account with public key URI: %s", account.PublicKeyURI)) - } - check, ok = suite.cache.GetByUsernameDomain(account.Username, account.Domain) - if !ok && !accountIs(account, check) { - suite.Fail(fmt.Sprintf("Failed to fetch expected account with username/domain: %s/%s", account.Username, account.Domain)) - } - } -} - -func TestAccountCache(t *testing.T) { - suite.Run(t, &AccountCacheTestSuite{}) -} - -func accountIs(account1, account2 *gtsmodel.Account) bool { - if account1 == nil || account2 == nil { - return account1 == account2 - } - return account1.ID == account2.ID && - account1.URI == account2.URI && - account1.URL == account2.URL && - account1.PublicKeyURI == account2.PublicKeyURI -} diff --git a/internal/cache/domain.go b/internal/cache/domain.go deleted file mode 100644 index 7b5a93d39..000000000 --- a/internal/cache/domain.go +++ /dev/null @@ -1,106 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package cache - -import ( - "time" - - "codeberg.org/gruf/go-cache/v2" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -// DomainCache is a cache wrapper to provide URL and URI lookups for gtsmodel.Status -type DomainBlockCache struct { - cache cache.LookupCache[string, string, *gtsmodel.DomainBlock] -} - -// NewStatusCache returns a new instantiated statusCache object -func NewDomainBlockCache() *DomainBlockCache { - c := &DomainBlockCache{} - c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.DomainBlock]{ - RegisterLookups: func(lm *cache.LookupMap[string, string]) { - lm.RegisterLookup("id") - }, - - AddLookups: func(lm *cache.LookupMap[string, string], block *gtsmodel.DomainBlock) { - // Block can be equal to nil when sentinel - if block != nil && block.ID != "" { - lm.Set("id", block.ID, block.Domain) - } - }, - - DeleteLookups: func(lm *cache.LookupMap[string, string], block *gtsmodel.DomainBlock) { - // Block can be equal to nil when sentinel - if block != nil && block.ID != "" { - lm.Delete("id", block.ID) - } - }, - }) - c.cache.SetTTL(time.Minute*5, false) - c.cache.Start(time.Second * 10) - return c -} - -// GetByID attempts to fetch a status from the cache by its ID, you will receive a copy for thread-safety -func (c *DomainBlockCache) GetByID(id string) (*gtsmodel.DomainBlock, bool) { - return c.cache.GetBy("id", id) -} - -// GetByURL attempts to fetch a status from the cache by its URL, you will receive a copy for thread-safety -func (c *DomainBlockCache) GetByDomain(domain string) (*gtsmodel.DomainBlock, bool) { - return c.cache.Get(domain) -} - -// Put places a status in the cache, ensuring that the object place is a copy for thread-safety -func (c *DomainBlockCache) Put(domain string, block *gtsmodel.DomainBlock) { - if domain == "" { - panic("invalid domain") - } - - if block == nil { - // This is a sentinel value for (no block) - c.cache.Set(domain, nil) - } else { - // This is a valid domain block - c.cache.Set(domain, copyDomainBlock(block)) - } -} - -// InvalidateByDomain will invalidate a domain block from the cache by domain name. -func (c *DomainBlockCache) InvalidateByDomain(domain string) { - c.cache.Invalidate(domain) -} - -// copyStatus performs a surface-level copy of status, only keeping attached IDs intact, not the objects. -// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr) -// this should be a relatively cheap process -func copyDomainBlock(block *gtsmodel.DomainBlock) *gtsmodel.DomainBlock { - return >smodel.DomainBlock{ - ID: block.ID, - CreatedAt: block.CreatedAt, - UpdatedAt: block.UpdatedAt, - Domain: block.Domain, - CreatedByAccountID: block.CreatedByAccountID, - CreatedByAccount: nil, - PrivateComment: block.PrivateComment, - PublicComment: block.PublicComment, - Obfuscate: block.Obfuscate, - SubscriptionID: block.SubscriptionID, - } -} diff --git a/internal/cache/emoji.go b/internal/cache/emoji.go deleted file mode 100644 index 117f5475e..000000000 --- a/internal/cache/emoji.go +++ /dev/null @@ -1,131 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package cache - -import ( - "time" - - "codeberg.org/gruf/go-cache/v2" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -// EmojiCache is a cache wrapper to provide ID and URI lookups for gtsmodel.Emoji -type EmojiCache struct { - cache cache.LookupCache[string, string, *gtsmodel.Emoji] -} - -// NewEmojiCache returns a new instantiated EmojiCache object -func NewEmojiCache() *EmojiCache { - c := &EmojiCache{} - c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.Emoji]{ - RegisterLookups: func(lm *cache.LookupMap[string, string]) { - lm.RegisterLookup("uri") - lm.RegisterLookup("shortcodedomain") - lm.RegisterLookup("imagestaticurl") - }, - - AddLookups: func(lm *cache.LookupMap[string, string], emoji *gtsmodel.Emoji) { - lm.Set("shortcodedomain", shortcodeDomainKey(emoji.Shortcode, emoji.Domain), emoji.ID) - if uri := emoji.URI; uri != "" { - lm.Set("uri", uri, emoji.ID) - } - if imageStaticURL := emoji.ImageStaticURL; imageStaticURL != "" { - lm.Set("imagestaticurl", imageStaticURL, emoji.ID) - } - }, - - DeleteLookups: func(lm *cache.LookupMap[string, string], emoji *gtsmodel.Emoji) { - lm.Delete("shortcodedomain", shortcodeDomainKey(emoji.Shortcode, emoji.Domain)) - if uri := emoji.URI; uri != "" { - lm.Delete("uri", uri) - } - if imageStaticURL := emoji.ImageStaticURL; imageStaticURL != "" { - lm.Delete("imagestaticurl", imageStaticURL) - } - }, - }) - c.cache.SetTTL(time.Minute*5, false) - c.cache.Start(time.Second * 10) - return c -} - -// GetByID attempts to fetch an emoji from the cache by its ID, you will receive a copy for thread-safety -func (c *EmojiCache) GetByID(id string) (*gtsmodel.Emoji, bool) { - return c.cache.Get(id) -} - -// GetByURI attempts to fetch an emoji from the cache by its URI, you will receive a copy for thread-safety -func (c *EmojiCache) GetByURI(uri string) (*gtsmodel.Emoji, bool) { - return c.cache.GetBy("uri", uri) -} - -func (c *EmojiCache) GetByShortcodeDomain(shortcode string, domain string) (*gtsmodel.Emoji, bool) { - return c.cache.GetBy("shortcodedomain", shortcodeDomainKey(shortcode, domain)) -} - -func (c *EmojiCache) GetByImageStaticURL(imageStaticURL string) (*gtsmodel.Emoji, bool) { - return c.cache.GetBy("imagestaticurl", imageStaticURL) -} - -// Put places an emoji in the cache, ensuring that the object place is a copy for thread-safety -func (c *EmojiCache) Put(emoji *gtsmodel.Emoji) { - if emoji == nil || emoji.ID == "" { - panic("invalid emoji") - } - c.cache.Set(emoji.ID, copyEmoji(emoji)) -} - -func (c *EmojiCache) Invalidate(emojiID string) { - c.cache.Invalidate(emojiID) -} - -// copyEmoji performs a surface-level copy of emoji, only keeping attached IDs intact, not the objects. -// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr) -// this should be a relatively cheap process -func copyEmoji(emoji *gtsmodel.Emoji) *gtsmodel.Emoji { - return >smodel.Emoji{ - ID: emoji.ID, - CreatedAt: emoji.CreatedAt, - UpdatedAt: emoji.UpdatedAt, - Shortcode: emoji.Shortcode, - Domain: emoji.Domain, - ImageRemoteURL: emoji.ImageRemoteURL, - ImageStaticRemoteURL: emoji.ImageStaticRemoteURL, - ImageURL: emoji.ImageURL, - ImageStaticURL: emoji.ImageStaticURL, - ImagePath: emoji.ImagePath, - ImageStaticPath: emoji.ImageStaticPath, - ImageContentType: emoji.ImageContentType, - ImageStaticContentType: emoji.ImageStaticContentType, - ImageFileSize: emoji.ImageFileSize, - ImageStaticFileSize: emoji.ImageStaticFileSize, - ImageUpdatedAt: emoji.ImageUpdatedAt, - Disabled: copyBoolPtr(emoji.Disabled), - URI: emoji.URI, - VisibleInPicker: copyBoolPtr(emoji.VisibleInPicker), - CategoryID: emoji.CategoryID, - } -} - -func shortcodeDomainKey(shortcode string, domain string) string { - if domain != "" { - return shortcode + "@" + domain - } - return shortcode -} diff --git a/internal/cache/emojicategory.go b/internal/cache/emojicategory.go deleted file mode 100644 index 17df5591a..000000000 --- a/internal/cache/emojicategory.go +++ /dev/null @@ -1,84 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package cache - -import ( - "strings" - "time" - - "codeberg.org/gruf/go-cache/v2" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -// EmojiCategoryCache is a cache wrapper to provide ID lookups for gtsmodel.EmojiCategory -type EmojiCategoryCache struct { - cache cache.LookupCache[string, string, *gtsmodel.EmojiCategory] -} - -// NewEmojiCategoryCache returns a new instantiated EmojiCategoryCache object -func NewEmojiCategoryCache() *EmojiCategoryCache { - c := &EmojiCategoryCache{} - c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.EmojiCategory]{ - RegisterLookups: func(lm *cache.LookupMap[string, string]) { - lm.RegisterLookup("name") - }, - - AddLookups: func(lm *cache.LookupMap[string, string], emojiCategory *gtsmodel.EmojiCategory) { - lm.Set(("name"), strings.ToLower(emojiCategory.Name), emojiCategory.ID) - }, - - DeleteLookups: func(lm *cache.LookupMap[string, string], emojiCategory *gtsmodel.EmojiCategory) { - lm.Delete("name", strings.ToLower(emojiCategory.Name)) - }, - }) - c.cache.SetTTL(time.Minute*5, false) - c.cache.Start(time.Second * 10) - return c -} - -// GetByID attempts to fetch an emojiCategory from the cache by its ID, you will receive a copy for thread-safety -func (c *EmojiCategoryCache) GetByID(id string) (*gtsmodel.EmojiCategory, bool) { - return c.cache.Get(id) -} - -// GetByName attempts to fetch an emojiCategory from the cache by its name, you will receive a copy for thread-safety -func (c *EmojiCategoryCache) GetByName(name string) (*gtsmodel.EmojiCategory, bool) { - return c.cache.GetBy("name", strings.ToLower(name)) -} - -// Put places an emojiCategory in the cache, ensuring that the object place is a copy for thread-safety -func (c *EmojiCategoryCache) Put(emoji *gtsmodel.EmojiCategory) { - if emoji == nil || emoji.ID == "" { - panic("invalid emoji") - } - c.cache.Set(emoji.ID, copyEmojiCategory(emoji)) -} - -func (c *EmojiCategoryCache) Invalidate(emojiID string) { - c.cache.Invalidate(emojiID) -} - -func copyEmojiCategory(emojiCategory *gtsmodel.EmojiCategory) *gtsmodel.EmojiCategory { - return >smodel.EmojiCategory{ - ID: emojiCategory.ID, - CreatedAt: emojiCategory.CreatedAt, - UpdatedAt: emojiCategory.UpdatedAt, - Name: emojiCategory.Name, - } -} diff --git a/internal/cache/status.go b/internal/cache/status.go deleted file mode 100644 index 898b50846..000000000 --- a/internal/cache/status.go +++ /dev/null @@ -1,138 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package cache - -import ( - "time" - - "codeberg.org/gruf/go-cache/v2" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -// StatusCache is a cache wrapper to provide URL and URI lookups for gtsmodel.Status -type StatusCache struct { - cache cache.LookupCache[string, string, *gtsmodel.Status] -} - -// NewStatusCache returns a new instantiated statusCache object -func NewStatusCache() *StatusCache { - c := &StatusCache{} - c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.Status]{ - RegisterLookups: func(lm *cache.LookupMap[string, string]) { - lm.RegisterLookup("uri") - lm.RegisterLookup("url") - }, - - AddLookups: func(lm *cache.LookupMap[string, string], status *gtsmodel.Status) { - if uri := status.URI; uri != "" { - lm.Set("uri", uri, status.ID) - } - if url := status.URL; url != "" { - lm.Set("url", url, status.ID) - } - }, - - DeleteLookups: func(lm *cache.LookupMap[string, string], status *gtsmodel.Status) { - if uri := status.URI; uri != "" { - lm.Delete("uri", uri) - } - if url := status.URL; url != "" { - lm.Delete("url", url) - } - }, - }) - c.cache.SetTTL(time.Minute*5, false) - c.cache.Start(time.Second * 10) - return c -} - -// GetByID attempts to fetch a status from the cache by its ID, you will receive a copy for thread-safety -func (c *StatusCache) GetByID(id string) (*gtsmodel.Status, bool) { - return c.cache.Get(id) -} - -// GetByURL attempts to fetch a status from the cache by its URL, you will receive a copy for thread-safety -func (c *StatusCache) GetByURL(url string) (*gtsmodel.Status, bool) { - return c.cache.GetBy("url", url) -} - -// GetByURI attempts to fetch a status from the cache by its URI, you will receive a copy for thread-safety -func (c *StatusCache) GetByURI(uri string) (*gtsmodel.Status, bool) { - return c.cache.GetBy("uri", uri) -} - -// Put places a status in the cache, ensuring that the object place is a copy for thread-safety -func (c *StatusCache) Put(status *gtsmodel.Status) { - if status == nil || status.ID == "" { - panic("invalid status") - } - c.cache.Set(status.ID, copyStatus(status)) -} - -// Invalidate invalidates one status from the cache using the ID of the status as key. -func (c *StatusCache) Invalidate(statusID string) { - c.cache.Invalidate(statusID) -} - -// copyStatus performs a surface-level copy of status, only keeping attached IDs intact, not the objects. -// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr) -// this should be a relatively cheap process -func copyStatus(status *gtsmodel.Status) *gtsmodel.Status { - return >smodel.Status{ - ID: status.ID, - URI: status.URI, - URL: status.URL, - Content: status.Content, - AttachmentIDs: status.AttachmentIDs, - Attachments: nil, - TagIDs: status.TagIDs, - Tags: nil, - MentionIDs: status.MentionIDs, - Mentions: nil, - EmojiIDs: status.EmojiIDs, - Emojis: nil, - Local: copyBoolPtr(status.Local), - CreatedAt: status.CreatedAt, - UpdatedAt: status.UpdatedAt, - AccountID: status.AccountID, - Account: nil, - AccountURI: status.AccountURI, - InReplyToID: status.InReplyToID, - InReplyTo: nil, - InReplyToURI: status.InReplyToURI, - InReplyToAccountID: status.InReplyToAccountID, - InReplyToAccount: nil, - BoostOfID: status.BoostOfID, - BoostOf: nil, - BoostOfAccountID: status.BoostOfAccountID, - BoostOfAccount: nil, - ContentWarning: status.ContentWarning, - Visibility: status.Visibility, - Sensitive: copyBoolPtr(status.Sensitive), - Language: status.Language, - CreatedWithApplicationID: status.CreatedWithApplicationID, - ActivityStreamsType: status.ActivityStreamsType, - Text: status.Text, - Pinned: copyBoolPtr(status.Pinned), - Federated: copyBoolPtr(status.Federated), - Boostable: copyBoolPtr(status.Boostable), - Replyable: copyBoolPtr(status.Replyable), - Likeable: copyBoolPtr(status.Likeable), - } -} diff --git a/internal/cache/status_test.go b/internal/cache/status_test.go deleted file mode 100644 index c1c4173fb..000000000 --- a/internal/cache/status_test.go +++ /dev/null @@ -1,113 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package cache_test - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/suite" - "github.com/superseriousbusiness/gotosocial/internal/cache" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/testrig" -) - -type StatusCacheTestSuite struct { - suite.Suite - data map[string]*gtsmodel.Status - cache *cache.StatusCache -} - -func (suite *StatusCacheTestSuite) SetupSuite() { - suite.data = testrig.NewTestStatuses() -} - -func (suite *StatusCacheTestSuite) SetupTest() { - suite.cache = cache.NewStatusCache() -} - -func (suite *StatusCacheTestSuite) TearDownTest() { - suite.data = nil - suite.cache = nil -} - -func (suite *StatusCacheTestSuite) TestStatusCache() { - for _, status := range suite.data { - // Place in the cache - suite.cache.Put(status) - } - - for _, status := range suite.data { - var ok bool - var check *gtsmodel.Status - - // Check we can retrieve - check, ok = suite.cache.GetByID(status.ID) - if !ok && !statusIs(status, check) { - suite.Fail(fmt.Sprintf("Failed to fetch expected account with ID: %s", status.ID)) - } - check, ok = suite.cache.GetByURI(status.URI) - if status.URI != "" && !ok && !statusIs(status, check) { - suite.Fail(fmt.Sprintf("Failed to fetch expected account with URI: %s", status.URI)) - } - check, ok = suite.cache.GetByURL(status.URL) - if status.URL != "" && !ok && !statusIs(status, check) { - suite.Fail(fmt.Sprintf("Failed to fetch expected account with URL: %s", status.URL)) - } - } -} - -func (suite *StatusCacheTestSuite) TestBoolPointerCopying() { - originalStatus := suite.data["local_account_1_status_1"] - - // mark the status as pinned + cache it - pinned := true - originalStatus.Pinned = &pinned - suite.cache.Put(originalStatus) - - // retrieve it - cachedStatus, ok := suite.cache.GetByID(originalStatus.ID) - if !ok { - suite.FailNow("status wasn't retrievable from cache") - } - - // we should be able to change the original status values + cached - // values independently since they use different pointers - suite.True(*cachedStatus.Pinned) - *originalStatus.Pinned = false - suite.False(*originalStatus.Pinned) - suite.True(*cachedStatus.Pinned) - *originalStatus.Pinned = true - *cachedStatus.Pinned = false - suite.True(*originalStatus.Pinned) - suite.False(*cachedStatus.Pinned) -} - -func TestStatusCache(t *testing.T) { - suite.Run(t, &StatusCacheTestSuite{}) -} - -func statusIs(status1, status2 *gtsmodel.Status) bool { - if status1 == nil || status2 == nil { - return status1 == status2 - } - return status1.ID == status2.ID && - status1.URI == status2.URI && - status1.URL == status2.URL -} diff --git a/internal/cache/user.go b/internal/cache/user.go deleted file mode 100644 index 23bf0b7e9..000000000 --- a/internal/cache/user.go +++ /dev/null @@ -1,141 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package cache - -import ( - "time" - - "codeberg.org/gruf/go-cache/v2" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -// UserCache is a cache wrapper to provide lookups for gtsmodel.User -type UserCache struct { - cache cache.LookupCache[string, string, *gtsmodel.User] -} - -// NewUserCache returns a new instantiated UserCache object -func NewUserCache() *UserCache { - c := &UserCache{} - c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.User]{ - RegisterLookups: func(lm *cache.LookupMap[string, string]) { - lm.RegisterLookup("accountid") - lm.RegisterLookup("email") - lm.RegisterLookup("unconfirmedemail") - lm.RegisterLookup("confirmationtoken") - }, - - AddLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) { - lm.Set("accountid", user.AccountID, user.ID) - if email := user.Email; email != "" { - lm.Set("email", email, user.ID) - } - if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" { - lm.Set("unconfirmedemail", unconfirmedEmail, user.ID) - } - if confirmationToken := user.ConfirmationToken; confirmationToken != "" { - lm.Set("confirmationtoken", confirmationToken, user.ID) - } - }, - - DeleteLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) { - lm.Delete("accountid", user.AccountID) - if email := user.Email; email != "" { - lm.Delete("email", email) - } - if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" { - lm.Delete("unconfirmedemail", unconfirmedEmail) - } - if confirmationToken := user.ConfirmationToken; confirmationToken != "" { - lm.Delete("confirmationtoken", confirmationToken) - } - }, - }) - c.cache.SetTTL(time.Minute*5, false) - c.cache.Start(time.Second * 10) - return c -} - -// GetByID attempts to fetch a user from the cache by its ID, you will receive a copy for thread-safety -func (c *UserCache) GetByID(id string) (*gtsmodel.User, bool) { - return c.cache.Get(id) -} - -// GetByAccountID attempts to fetch a user from the cache by its account ID, you will receive a copy for thread-safety -func (c *UserCache) GetByAccountID(accountID string) (*gtsmodel.User, bool) { - return c.cache.GetBy("accountid", accountID) -} - -// GetByEmail attempts to fetch a user from the cache by its email address, you will receive a copy for thread-safety -func (c *UserCache) GetByEmail(email string) (*gtsmodel.User, bool) { - return c.cache.GetBy("email", email) -} - -// GetByUnconfirmedEmail attempts to fetch a user from the cache by its confirmation token, you will receive a copy for thread-safety -func (c *UserCache) GetByConfirmationToken(token string) (*gtsmodel.User, bool) { - return c.cache.GetBy("confirmationtoken", token) -} - -// Put places a user in the cache, ensuring that the object place is a copy for thread-safety -func (c *UserCache) Put(user *gtsmodel.User) { - if user == nil || user.ID == "" { - panic("invalid user") - } - c.cache.Set(user.ID, copyUser(user)) -} - -// Invalidate invalidates one user from the cache using the ID of the user as key. -func (c *UserCache) Invalidate(userID string) { - c.cache.Invalidate(userID) -} - -func copyUser(user *gtsmodel.User) *gtsmodel.User { - return >smodel.User{ - ID: user.ID, - CreatedAt: user.CreatedAt, - UpdatedAt: user.UpdatedAt, - Email: user.Email, - AccountID: user.AccountID, - Account: nil, - EncryptedPassword: user.EncryptedPassword, - SignUpIP: user.SignUpIP, - CurrentSignInAt: user.CurrentSignInAt, - CurrentSignInIP: user.CurrentSignInIP, - LastSignInAt: user.LastSignInAt, - LastSignInIP: user.LastSignInIP, - SignInCount: user.SignInCount, - InviteID: user.InviteID, - ChosenLanguages: user.ChosenLanguages, - FilteredLanguages: user.FilteredLanguages, - Locale: user.Locale, - CreatedByApplicationID: user.CreatedByApplicationID, - CreatedByApplication: nil, - LastEmailedAt: user.LastEmailedAt, - ConfirmationToken: user.ConfirmationToken, - ConfirmationSentAt: user.ConfirmationSentAt, - ConfirmedAt: user.ConfirmedAt, - UnconfirmedEmail: user.UnconfirmedEmail, - Moderator: copyBoolPtr(user.Moderator), - Admin: copyBoolPtr(user.Admin), - Disabled: copyBoolPtr(user.Disabled), - Approved: copyBoolPtr(user.Approved), - ResetPasswordToken: user.ResetPasswordToken, - ResetPasswordSentAt: user.ResetPasswordSentAt, - } -} diff --git a/internal/cache/util.go b/internal/cache/util.go deleted file mode 100644 index 48204b259..000000000 --- a/internal/cache/util.go +++ /dev/null @@ -1,31 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package cache - -// copyBoolPtr returns a bool pointer with the same value as the pointer passed into it. -// -// Useful when copying things from the cache to a caller. -func copyBoolPtr(in *bool) *bool { - if in == nil { - return nil - } - b := new(bool) - *b = *in - return b -} diff --git a/internal/db/account.go b/internal/db/account.go index a58aa9dd3..7e7d1de43 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -43,10 +43,10 @@ type Account interface { GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, Error) // PutAccount puts one account in the database. - PutAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) + PutAccount(ctx context.Context, account *gtsmodel.Account) Error // UpdateAccount updates one account by ID. - UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) + UpdateAccount(ctx context.Context, account *gtsmodel.Account) Error // DeleteAccount deletes one account from the database by its ID. // DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 4813f4e17..1e9c390d8 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -24,7 +24,7 @@ import ( "strings" "time" - "github.com/superseriousbusiness/gotosocial/internal/cache" + "codeberg.org/gruf/go-cache/v3/result" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -35,10 +35,29 @@ import ( type accountDB struct { conn *DBConn - cache *cache.AccountCache + cache *result.Cache[*gtsmodel.Account] status *statusDB } +func (a *accountDB) init() { + // Initialize account result cache + a.cache = result.NewSized([]result.Lookup{ + {Name: "ID"}, + {Name: "URI"}, + {Name: "URL"}, + {Name: "Username.Domain"}, + {Name: "PublicKeyURI"}, + }, func(a1 *gtsmodel.Account) *gtsmodel.Account { + a2 := new(gtsmodel.Account) + *a2 = *a1 + return a2 + }, 1000) + + // Set cache TTL and start sweep routine + a.cache.SetTTL(time.Minute*5, false) + a.cache.Start(time.Second * 10) +} + func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { return a.conn. NewSelect(). @@ -51,45 +70,41 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { return a.getAccount( ctx, - func() (*gtsmodel.Account, bool) { - return a.cache.GetByID(id) - }, + "ID", func(account *gtsmodel.Account) error { return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx) }, + id, ) } func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { return a.getAccount( ctx, - func() (*gtsmodel.Account, bool) { - return a.cache.GetByURI(uri) - }, + "URI", func(account *gtsmodel.Account) error { return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx) }, + uri, ) } func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) { return a.getAccount( ctx, - func() (*gtsmodel.Account, bool) { - return a.cache.GetByURL(url) - }, + "URL", func(account *gtsmodel.Account) error { return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx) }, + url, ) } func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) { + username = strings.ToLower(username) return a.getAccount( ctx, - func() (*gtsmodel.Account, bool) { - return a.cache.GetByUsernameDomain(username, domain) - }, + "Username.Domain", func(account *gtsmodel.Account) error { q := a.newAccountQ(account) @@ -97,113 +112,117 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str q = q.Where("? = ?", bun.Ident("account.username"), username) q = q.Where("? = ?", bun.Ident("account.domain"), domain) } else { - q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username)) + q = q.Where("? = ?", bun.Ident("account.username"), username) q = q.Where("? IS NULL", bun.Ident("account.domain")) } return q.Scan(ctx) }, + username, + domain, ) } func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { return a.getAccount( ctx, - func() (*gtsmodel.Account, bool) { - return a.cache.GetByPubkeyID(id) - }, + "PublicKeyURI", func(account *gtsmodel.Account) error { return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx) }, + id, ) } -func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) { - // Attempt to fetch cached account - account, cached := cacheGet() +func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { + var username string - if !cached { - account = >smodel.Account{} + if domain == "" { + // I.e. our local instance account + username = config.GetHost() + } else { + // A remote instance account + username = domain + } + + return a.GetAccountByUsernameDomain(ctx, username, domain) +} + +func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, db.Error) { + return a.cache.Load(lookup, func() (*gtsmodel.Account, error) { + var account gtsmodel.Account // Not cached! Perform database query - err := dbQuery(account) - if err != nil { + if err := dbQuery(&account); err != nil { return nil, a.conn.ProcessError(err) } - // Place in the cache - a.cache.Put(account) - } - - return account, nil + return &account, nil + }, keyParts...) } -func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { - if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { - // create links between this account and any emojis it uses - for _, i := range account.EmojiIDs { - if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ - AccountID: account.ID, - EmojiID: i, - }).Exec(ctx); err != nil { - return err +func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error { + return a.cache.Store(account, func() error { + // It is safe to run this database transaction within cache.Store + // as the cache does not attempt a mutex lock until AFTER hook. + // + return a.conn.RunInTx(ctx, func(tx bun.Tx) error { + // create links between this account and any emojis it uses + for _, i := range account.EmojiIDs { + if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ + AccountID: account.ID, + EmojiID: i, + }).Exec(ctx); err != nil { + return err + } } - } - // insert the account - _, err := tx.NewInsert().Model(account).Exec(ctx) - return err - }); err != nil { - return nil, a.conn.ProcessError(err) - } - - a.cache.Put(account) - return account, nil + // insert the account + _, err := tx.NewInsert().Model(account).Exec(ctx) + return err + }) + }) } -func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { +func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) db.Error { // Update the account's last-updated account.UpdatedAt = time.Now() - if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { - // create links between this account and any emojis it uses - // first clear out any old emoji links - if _, err := tx. - NewDelete(). - TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). - Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID). - Exec(ctx); err != nil { - return err - } - - // now populate new emoji links - for _, i := range account.EmojiIDs { + return a.cache.Store(account, func() error { + // It is safe to run this database transaction within cache.Store + // as the cache does not attempt a mutex lock until AFTER hook. + // + return a.conn.RunInTx(ctx, func(tx bun.Tx) error { + // create links between this account and any emojis it uses + // first clear out any old emoji links if _, err := tx. - NewInsert(). - Model(>smodel.AccountToEmoji{ - AccountID: account.ID, - EmojiID: i, - }).Exec(ctx); err != nil { + NewDelete(). + TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). + Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID). + Exec(ctx); err != nil { return err } - } - // update the account - if _, err := tx. - NewUpdate(). - Model(account). - Where("? = ?", bun.Ident("account.id"), account.ID). - Exec(ctx); err != nil { - return err - } - - return nil - }); err != nil { - return nil, a.conn.ProcessError(err) - } + // now populate new emoji links + for _, i := range account.EmojiIDs { + if _, err := tx. + NewInsert(). + Model(>smodel.AccountToEmoji{ + AccountID: account.ID, + EmojiID: i, + }).Exec(ctx); err != nil { + return err + } + } - a.cache.Put(account) - return account, nil + // update the account + _, err := tx.NewUpdate(). + Model(account). + Where("? = ?", bun.Ident("account.id"), account.ID). + Exec(ctx) + return err + }) + }) } func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { @@ -219,40 +238,19 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { // delete the account _, err := tx. - NewUpdate(). + NewDelete(). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). Where("? = ?", bun.Ident("account.id"), id). Exec(ctx) return err }); err != nil { - return a.conn.ProcessError(err) + return err } - a.cache.Invalidate(id) + a.cache.Invalidate("ID", id) return nil } -func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { - account := new(gtsmodel.Account) - - q := a.newAccountQ(account) - - if domain != "" { - q = q. - Where("? = ?", bun.Ident("account.username"), domain). - Where("? = ?", bun.Ident("account.domain"), domain) - } else { - q = q. - Where("? = ?", bun.Ident("account.username"), config.GetHost()). - WhereGroup(" AND ", whereEmptyOrNull("domain")) - } - - if err := q.Scan(ctx); err != nil { - return nil, a.conn.ProcessError(err) - } - return account, nil -} - func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, db.Error) { createdAt := time.Time{} diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index 29594a740..50603623f 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -92,7 +92,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() { testAccount.DisplayName = "new display name!" testAccount.EmojiIDs = []string{"01GD36ZKWTKY3T1JJ24JR7KY1Q", "01GD36ZV904SHBHNAYV6DX5QEF"} - _, err := suite.db.UpdateAccount(ctx, testAccount) + err := suite.db.UpdateAccount(ctx, testAccount) suite.NoError(err) updated, err := suite.db.GetAccountByID(ctx, testAccount.ID) @@ -127,7 +127,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() { // update again to remove emoji associations testAccount.EmojiIDs = []string{} - _, err = suite.db.UpdateAccount(ctx, testAccount) + err = suite.db.UpdateAccount(ctx, testAccount) suite.NoError(err) updated, err = suite.db.GetAccountByID(ctx, testAccount.ID) diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index 44861a4bb..4d750581c 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -29,7 +29,6 @@ import ( "time" "github.com/superseriousbusiness/gotosocial/internal/ap" - "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -44,9 +43,9 @@ import ( const rsaKeyBits = 2048 type adminDB struct { - conn *DBConn - userCache *cache.UserCache - accountCache *cache.AccountCache + conn *DBConn + accounts *accountDB + users *userDB } func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { @@ -140,13 +139,9 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, } // insert the new account! - if _, err = a.conn. - NewInsert(). - Model(acct). - Exec(ctx); err != nil { - return nil, a.conn.ProcessError(err) + if err := a.accounts.PutAccount(ctx, acct); err != nil { + return nil, err } - a.accountCache.Put(acct) } // we either created or already had an account by now, @@ -190,13 +185,9 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, } // insert the user! - if _, err = a.conn. - NewInsert(). - Model(u). - Exec(ctx); err != nil { - return nil, a.conn.ProcessError(err) + if err := a.users.PutUser(ctx, u); err != nil { + return nil, err } - a.userCache.Put(u) return u, nil } @@ -249,15 +240,11 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { FeaturedCollectionURI: newAccountURIs.CollectionURI, } - insertQ := a.conn. - NewInsert(). - Model(acct) - - if _, err := insertQ.Exec(ctx); err != nil { - return a.conn.ProcessError(err) + // insert the new account! + if err := a.accounts.PutAccount(ctx, acct); err != nil { + return err } - a.accountCache.Put(acct) log.Infof("instance account %s CREATED with id %s", username, acct.ID) return nil } diff --git a/internal/db/bundb/admin_test.go b/internal/db/bundb/admin_test.go index f0a869a9b..18e1f67e2 100644 --- a/internal/db/bundb/admin_test.go +++ b/internal/db/bundb/admin_test.go @@ -70,6 +70,8 @@ func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() { } func (suite *AdminTestSuite) TestCreateInstanceAccount() { + // reinitialize test DB to clear caches + suite.db = testrig.NewTestDB() // we need to take an empty db for this... testrig.StandardDBTeardown(suite.db) // ...with tables created but no data diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index cf6643f6b..de6749ca4 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -34,7 +34,6 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/stdlib" - "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations" @@ -46,7 +45,6 @@ import ( "github.com/uptrace/bun/dialect/sqlitedialect" "github.com/uptrace/bun/migrate" - grufcache "codeberg.org/gruf/go-cache/v2" "modernc.org/sqlite" ) @@ -160,79 +158,63 @@ func NewBunDBService(ctx context.Context) (db.DB, error) { return nil, fmt.Errorf("db migration error: %s", err) } - // Prepare caches required by more than one struct - userCache := cache.NewUserCache() - accountCache := cache.NewAccountCache() - - // Prepare other caches - // Prepare mentions cache - // TODO: move into internal/cache - mentionCache := grufcache.New[string, *gtsmodel.Mention]() - mentionCache.SetTTL(time.Minute*5, false) - mentionCache.Start(time.Second * 10) - - // Prepare notifications cache - // TODO: move into internal/cache - notifCache := grufcache.New[string, *gtsmodel.Notification]() - notifCache.SetTTL(time.Minute*5, false) - notifCache.Start(time.Second * 10) - // Create DB structs that require ptrs to each other - accounts := &accountDB{conn: conn, cache: accountCache} - status := &statusDB{conn: conn, cache: cache.NewStatusCache()} - emoji := &emojiDB{conn: conn, emojiCache: cache.NewEmojiCache(), categoryCache: cache.NewEmojiCategoryCache()} + account := &accountDB{conn: conn} + admin := &adminDB{conn: conn} + domain := &domainDB{conn: conn} + mention := &mentionDB{conn: conn} + notif := ¬ificationDB{conn: conn} + status := &statusDB{conn: conn} + emoji := &emojiDB{conn: conn} timeline := &timelineDB{conn: conn} tombstone := &tombstoneDB{conn: conn} + user := &userDB{conn: conn} // Setup DB cross-referencing - accounts.status = status - status.accounts = accounts + account.status = status + admin.users = user + status.accounts = account timeline.status = status // Initialize db structs + account.init() + domain.init() + emoji.init() + mention.init() + notif.init() + status.init() tombstone.init() + user.init() ps := &DBService{ - Account: accounts, + Account: account, Admin: &adminDB{ - conn: conn, - userCache: userCache, - accountCache: accountCache, + conn: conn, + accounts: account, + users: user, }, Basic: &basicDB{ conn: conn, }, - Domain: &domainDB{ - conn: conn, - cache: cache.NewDomainBlockCache(), - }, - Emoji: emoji, + Domain: domain, + Emoji: emoji, Instance: &instanceDB{ conn: conn, }, Media: &mediaDB{ conn: conn, }, - Mention: &mentionDB{ - conn: conn, - cache: mentionCache, - }, - Notification: ¬ificationDB{ - conn: conn, - cache: notifCache, - }, + Mention: mention, + Notification: notif, Relationship: &relationshipDB{ conn: conn, }, Session: &sessionDB{ conn: conn, }, - Status: status, - Timeline: timeline, - User: &userDB{ - conn: conn, - cache: userCache, - }, + Status: status, + Timeline: timeline, + User: user, Tombstone: tombstone, conn: conn, } diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index 0a752d3f3..3fca8501b 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -20,11 +20,11 @@ package bundb import ( "context" - "database/sql" "net/url" "strings" + "time" - "github.com/superseriousbusiness/gotosocial/internal/cache" + "codeberg.org/gruf/go-cache/v3/result" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -34,7 +34,22 @@ import ( type domainDB struct { conn *DBConn - cache *cache.DomainBlockCache + cache *result.Cache[*gtsmodel.DomainBlock] +} + +func (d *domainDB) init() { + // Initialize domain block result cache + d.cache = result.NewSized([]result.Lookup{ + {Name: "Domain"}, + }, func(d1 *gtsmodel.DomainBlock) *gtsmodel.DomainBlock { + d2 := new(gtsmodel.DomainBlock) + *d2 = *d1 + return d2 + }, 1000) + + // Set cache TTL and start sweep routine + d.cache.SetTTL(time.Minute*5, false) + d.cache.Start(time.Second * 10) } // normalizeDomain converts the given domain to lowercase @@ -49,76 +64,53 @@ func normalizeDomain(domain string) (out string, err error) { } func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) db.Error { - domain, err := normalizeDomain(block.Domain) + var err error + + block.Domain, err = normalizeDomain(block.Domain) if err != nil { return err } - block.Domain = domain - // Attempt to insert new domain block - if _, err := d.conn.NewInsert(). - Model(block). - Exec(ctx); err != nil { + return d.cache.Store(block, func() error { + _, err := d.conn.NewInsert(). + Model(block). + Exec(ctx) return d.conn.ProcessError(err) - } - - // Cache this domain block - d.cache.Put(block.Domain, block) - - return nil + }) } func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) { var err error + domain, err = normalizeDomain(domain) if err != nil { return nil, err } - // Check for easy case, domain referencing *us* - if domain == "" || domain == config.GetAccountDomain() { - return nil, db.ErrNoEntries - } - - // Check for already cached rblock - if block, ok := d.cache.GetByDomain(domain); ok { - // A 'nil' return value is a sentinel value for no block - if block == nil { + return d.cache.Load("Domain", func() (*gtsmodel.DomainBlock, error) { + // Check for easy case, domain referencing *us* + if domain == "" || domain == config.GetAccountDomain() { return nil, db.ErrNoEntries } - // Else, this block exists - return block, nil - } + var block gtsmodel.DomainBlock - block := >smodel.DomainBlock{} + q := d.conn. + NewSelect(). + Model(&block). + Where("? = ?", bun.Ident("domain_block.domain"), domain). + Limit(1) + if err := q.Scan(ctx); err != nil { + return nil, d.conn.ProcessError(err) + } - q := d.conn. - NewSelect(). - Model(block). - Where("? = ?", bun.Ident("domain_block.domain"), domain). - Limit(1) - - // Query database for domain block - switch err := q.Scan(ctx); err { - // No error, block found - case nil: - d.cache.Put(domain, block) - return block, nil - - // No error, simply not found - case sql.ErrNoRows: - d.cache.Put(domain, nil) - return nil, db.ErrNoEntries - - // Any other db error - default: - return nil, d.conn.ProcessError(err) - } + return &block, nil + }, domain) } func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error { var err error + domain, err = normalizeDomain(domain) if err != nil { return err @@ -133,7 +125,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro } // Clear domain from cache - d.cache.InvalidateByDomain(domain) + d.cache.Invalidate(domain) return nil } diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index 81374ce78..55e0ee3ff 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -23,7 +23,7 @@ import ( "strings" "time" - "github.com/superseriousbusiness/gotosocial/internal/cache" + "codeberg.org/gruf/go-cache/v3/result" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" @@ -33,8 +33,40 @@ import ( type emojiDB struct { conn *DBConn - emojiCache *cache.EmojiCache - categoryCache *cache.EmojiCategoryCache + emojiCache *result.Cache[*gtsmodel.Emoji] + categoryCache *result.Cache[*gtsmodel.EmojiCategory] +} + +func (e *emojiDB) init() { + // Initialize emoji result cache + e.emojiCache = result.NewSized([]result.Lookup{ + {Name: "ID"}, + {Name: "URI"}, + {Name: "Shortcode.Domain"}, + {Name: "ImageStaticURL"}, + }, func(e1 *gtsmodel.Emoji) *gtsmodel.Emoji { + e2 := new(gtsmodel.Emoji) + *e2 = *e1 + return e2 + }, 1000) + + // Set cache TTL and start sweep routine + e.emojiCache.SetTTL(time.Minute*5, false) + e.emojiCache.Start(time.Second * 10) + + // Initialize category result cache + e.categoryCache = result.NewSized([]result.Lookup{ + {Name: "ID"}, + {Name: "Name"}, + }, func(c1 *gtsmodel.EmojiCategory) *gtsmodel.EmojiCategory { + c2 := new(gtsmodel.EmojiCategory) + *c2 = *c1 + return c2 + }, 1000) + + // Set cache TTL and start sweep routine + e.categoryCache.SetTTL(time.Minute*5, false) + e.categoryCache.Start(time.Second * 10) } func (e *emojiDB) newEmojiQ(emoji *gtsmodel.Emoji) *bun.SelectQuery { @@ -51,12 +83,10 @@ func (e *emojiDB) newEmojiCategoryQ(emojiCategory *gtsmodel.EmojiCategory) *bun. } func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) db.Error { - if _, err := e.conn.NewInsert().Model(emoji).Exec(ctx); err != nil { + return e.emojiCache.Store(emoji, func() error { + _, err := e.conn.NewInsert().Model(emoji).Exec(ctx) return e.conn.ProcessError(err) - } - - e.emojiCache.Put(emoji) - return nil + }) } func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, columns ...string) (*gtsmodel.Emoji, db.Error) { @@ -72,7 +102,7 @@ func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, column return nil, e.conn.ProcessError(err) } - e.emojiCache.Invalidate(emoji.ID) + e.emojiCache.Invalidate("ID", emoji.ID) return emoji, nil } @@ -109,7 +139,7 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) db.Error { return err } - e.emojiCache.Invalidate(id) + e.emojiCache.Invalidate("ID", id) return nil } @@ -252,33 +282,29 @@ func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.E func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, db.Error) { return e.getEmoji( ctx, - func() (*gtsmodel.Emoji, bool) { - return e.emojiCache.GetByID(id) - }, + "ID", func(emoji *gtsmodel.Emoji) error { return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx) }, + id, ) } func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, db.Error) { return e.getEmoji( ctx, - func() (*gtsmodel.Emoji, bool) { - return e.emojiCache.GetByURI(uri) - }, + "URI", func(emoji *gtsmodel.Emoji) error { return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx) }, + uri, ) } func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, db.Error) { return e.getEmoji( ctx, - func() (*gtsmodel.Emoji, bool) { - return e.emojiCache.GetByShortcodeDomain(shortcode, domain) - }, + "Shortcode.Domain", func(emoji *gtsmodel.Emoji) error { q := e.newEmojiQ(emoji) @@ -292,31 +318,30 @@ func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode strin return q.Scan(ctx) }, + shortcode, + domain, ) } func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, db.Error) { return e.getEmoji( ctx, - func() (*gtsmodel.Emoji, bool) { - return e.emojiCache.GetByImageStaticURL(imageStaticURL) - }, + "ImageStaticURL", func(emoji *gtsmodel.Emoji) error { return e. newEmojiQ(emoji). Where("? = ?", bun.Ident("emoji.image_static_url"), imageStaticURL). Scan(ctx) }, + imageStaticURL, ) } func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) db.Error { - if _, err := e.conn.NewInsert().Model(emojiCategory).Exec(ctx); err != nil { + return e.categoryCache.Store(emojiCategory, func() error { + _, err := e.conn.NewInsert().Model(emojiCategory).Exec(ctx) return e.conn.ProcessError(err) - } - - e.categoryCache.Put(emojiCategory) - return nil + }) } func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, db.Error) { @@ -338,45 +363,36 @@ func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCate func (e *emojiDB) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, db.Error) { return e.getEmojiCategory( ctx, - func() (*gtsmodel.EmojiCategory, bool) { - return e.categoryCache.GetByID(id) - }, + "ID", func(emojiCategory *gtsmodel.EmojiCategory) error { return e.newEmojiCategoryQ(emojiCategory).Where("? = ?", bun.Ident("emoji_category.id"), id).Scan(ctx) }, + id, ) } func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, db.Error) { return e.getEmojiCategory( ctx, - func() (*gtsmodel.EmojiCategory, bool) { - return e.categoryCache.GetByName(name) - }, + "Name", func(emojiCategory *gtsmodel.EmojiCategory) error { return e.newEmojiCategoryQ(emojiCategory).Where("LOWER(?) = ?", bun.Ident("emoji_category.name"), strings.ToLower(name)).Scan(ctx) }, + name, ) } -func (e *emojiDB) getEmoji(ctx context.Context, cacheGet func() (*gtsmodel.Emoji, bool), dbQuery func(*gtsmodel.Emoji) error) (*gtsmodel.Emoji, db.Error) { - // Attempt to fetch cached emoji - emoji, cached := cacheGet() - - if !cached { - emoji = >smodel.Emoji{} +func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, db.Error) { + return e.emojiCache.Load(lookup, func() (*gtsmodel.Emoji, error) { + var emoji gtsmodel.Emoji // Not cached! Perform database query - err := dbQuery(emoji) - if err != nil { + if err := dbQuery(&emoji); err != nil { return nil, e.conn.ProcessError(err) } - // Place in the cache - e.emojiCache.Put(emoji) - } - - return emoji, nil + return &emoji, nil + }, keyParts...) } func (e *emojiDB) emojisFromIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, db.Error) { @@ -399,24 +415,17 @@ func (e *emojiDB) emojisFromIDs(ctx context.Context, emojiIDs []string) ([]*gtsm return emojis, nil } -func (e *emojiDB) getEmojiCategory(ctx context.Context, cacheGet func() (*gtsmodel.EmojiCategory, bool), dbQuery func(*gtsmodel.EmojiCategory) error) (*gtsmodel.EmojiCategory, db.Error) { - // Attempt to fetch cached emoji categories - emojiCategory, cached := cacheGet() - - if !cached { - emojiCategory = >smodel.EmojiCategory{} +func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, db.Error) { + return e.categoryCache.Load(lookup, func() (*gtsmodel.EmojiCategory, error) { + var category gtsmodel.EmojiCategory // Not cached! Perform database query - err := dbQuery(emojiCategory) - if err != nil { + if err := dbQuery(&category); err != nil { return nil, e.conn.ProcessError(err) } - // Place in the cache - e.categoryCache.Put(emojiCategory) - } - - return emojiCategory, nil + return &category, nil + }, keyParts...) } func (e *emojiDB) emojiCategoriesFromIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, db.Error) { diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go index 355078021..303e16484 100644 --- a/internal/db/bundb/mention.go +++ b/internal/db/bundb/mention.go @@ -20,8 +20,9 @@ package bundb import ( "context" + "time" - "codeberg.org/gruf/go-cache/v2" + "codeberg.org/gruf/go-cache/v3/result" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" @@ -30,7 +31,22 @@ import ( type mentionDB struct { conn *DBConn - cache cache.Cache[string, *gtsmodel.Mention] + cache *result.Cache[*gtsmodel.Mention] +} + +func (m *mentionDB) init() { + // Initialize notification result cache + m.cache = result.NewSized([]result.Lookup{ + {Name: "ID"}, + }, func(m1 *gtsmodel.Mention) *gtsmodel.Mention { + m2 := new(gtsmodel.Mention) + *m2 = *m1 + return m2 + }, 1000) + + // Set cache TTL and start sweep routine + m.cache.SetTTL(time.Minute*5, false) + m.cache.Start(time.Second * 10) } func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery { @@ -42,27 +58,19 @@ func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery { Relation("TargetAccount") } -func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { - mention := gtsmodel.Mention{} - - q := m.newMentionQ(&mention). - Where("? = ?", bun.Ident("mention.id"), id) - - if err := q.Scan(ctx); err != nil { - return nil, m.conn.ProcessError(err) - } +func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { + return m.cache.Load("ID", func() (*gtsmodel.Mention, error) { + var mention gtsmodel.Mention - copy := mention - m.cache.Set(mention.ID, ©) + q := m.newMentionQ(&mention). + Where("? = ?", bun.Ident("mention.id"), id) - return &mention, nil -} + if err := q.Scan(ctx); err != nil { + return nil, m.conn.ProcessError(err) + } -func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { - if mention, ok := m.cache.Get(id); ok { - return mention, nil - } - return m.getMentionDB(ctx, id) + return &mention, nil + }, id) } func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) { diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index 69e3cf39f..1874f81ea 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -20,8 +20,9 @@ package bundb import ( "context" + "time" - "codeberg.org/gruf/go-cache/v2" + "codeberg.org/gruf/go-cache/v3/result" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" @@ -30,31 +31,40 @@ import ( type notificationDB struct { conn *DBConn - cache cache.Cache[string, *gtsmodel.Notification] + cache *result.Cache[*gtsmodel.Notification] } -func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) { - if notification, ok := n.cache.Get(id); ok { - return notification, nil - } - - dst := gtsmodel.Notification{ID: id} - - q := n.conn.NewSelect(). - Model(&dst). - Relation("OriginAccount"). - Relation("TargetAccount"). - Relation("Status"). - Where("? = ?", bun.Ident("notification.id"), id) - - if err := q.Scan(ctx); err != nil { - return nil, n.conn.ProcessError(err) - } +func (n *notificationDB) init() { + // Initialize notification result cache + n.cache = result.NewSized([]result.Lookup{ + {Name: "ID"}, + }, func(n1 *gtsmodel.Notification) *gtsmodel.Notification { + n2 := new(gtsmodel.Notification) + *n2 = *n1 + return n2 + }, 1000) + + // Set cache TTL and start sweep routine + n.cache.SetTTL(time.Minute*5, false) + n.cache.Start(time.Second * 10) +} - copy := dst - n.cache.Set(id, ©) +func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) { + return n.cache.Load("ID", func() (*gtsmodel.Notification, error) { + var notif gtsmodel.Notification + + q := n.conn.NewSelect(). + Model(¬if). + Relation("OriginAccount"). + Relation("TargetAccount"). + Relation("Status"). + Where("? = ?", bun.Ident("notification.id"), id) + if err := q.Scan(ctx); err != nil { + return nil, n.conn.ProcessError(err) + } - return &dst, nil + return ¬if, nil + }, id) } func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) { diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index bc72c2849..b4ae40607 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -25,7 +25,7 @@ import ( "errors" "time" - "github.com/superseriousbusiness/gotosocial/internal/cache" + "codeberg.org/gruf/go-cache/v3/result" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" @@ -33,15 +33,28 @@ import ( ) type statusDB struct { - conn *DBConn - cache *cache.StatusCache - - // TODO: keep method definitions in same place but instead have receiver - // all point to one single "db" type, so they can all share methods - // and caches where necessary + conn *DBConn + cache *result.Cache[*gtsmodel.Status] accounts *accountDB } +func (s *statusDB) init() { + // Initialize status result cache + s.cache = result.NewSized([]result.Lookup{ + {Name: "ID"}, + {Name: "URI"}, + {Name: "URL"}, + }, func(s1 *gtsmodel.Status) *gtsmodel.Status { + s2 := new(gtsmodel.Status) + *s2 = *s1 + return s2 + }, 1000) + + // Set cache TTL and start sweep routine + s.cache.SetTTL(time.Minute*5, false) + s.cache.Start(time.Second * 10) +} + func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { return s.conn. NewSelect(). @@ -68,61 +81,62 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery { func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) { return s.getStatus( ctx, - func() (*gtsmodel.Status, bool) { - return s.cache.GetByID(id) - }, + "ID", func(status *gtsmodel.Status) error { return s.newStatusQ(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx) }, + id, ) } func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { return s.getStatus( ctx, - func() (*gtsmodel.Status, bool) { - return s.cache.GetByURI(uri) - }, + "URI", func(status *gtsmodel.Status) error { return s.newStatusQ(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx) }, + uri, ) } func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) { return s.getStatus( ctx, - func() (*gtsmodel.Status, bool) { - return s.cache.GetByURL(url) - }, + "URL", func(status *gtsmodel.Status) error { return s.newStatusQ(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx) }, + url, ) } -func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Status, bool), dbQuery func(*gtsmodel.Status) error) (*gtsmodel.Status, db.Error) { - // Attempt to fetch cached status - status, cached := cacheGet() - - if !cached { - status = >smodel.Status{} +func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, db.Error) { + // Fetch status from database cache with loader callback + status, err := s.cache.Load(lookup, func() (*gtsmodel.Status, error) { + var status gtsmodel.Status // Not cached! Perform database query - if err := dbQuery(status); err != nil { + if err := dbQuery(&status); err != nil { return nil, s.conn.ProcessError(err) } // If there is boosted, fetch from DB also if status.BoostOfID != "" { - boostOf, err := s.GetStatusByID(ctx, status.BoostOfID) - if err == nil { - status.BoostOf = boostOf + status.BoostOf = >smodel.Status{} + err := s.newStatusQ(status.BoostOf). + Where("? = ?", bun.Ident("status.id"), status.BoostOfID). + Scan(ctx) + if err != nil { + return nil, s.conn.ProcessError(err) } } - // Place in the cache - s.cache.Put(status) + return &status, nil + }, keyParts...) + if err != nil { + // error already processed + return nil, err } // Set the status author account @@ -137,73 +151,66 @@ func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Sta } func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { - err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { - // create links between this status and any emojis it uses - for _, i := range status.EmojiIDs { - if _, err := tx. - NewInsert(). - Model(>smodel.StatusToEmoji{ - StatusID: status.ID, - EmojiID: i, - }).Exec(ctx); err != nil { - err = s.conn.errProc(err) - if !errors.Is(err, db.ErrAlreadyExists) { - return err + return s.cache.Store(status, func() error { + // It is safe to run this database transaction within cache.Store + // as the cache does not attempt a mutex lock until AFTER hook. + // + return s.conn.RunInTx(ctx, func(tx bun.Tx) error { + // create links between this status and any emojis it uses + for _, i := range status.EmojiIDs { + if _, err := tx. + NewInsert(). + Model(>smodel.StatusToEmoji{ + StatusID: status.ID, + EmojiID: i, + }).Exec(ctx); err != nil { + err = s.conn.ProcessError(err) + if !errors.Is(err, db.ErrAlreadyExists) { + return err + } } } - } - // create links between this status and any tags it uses - for _, i := range status.TagIDs { - if _, err := tx. - NewInsert(). - Model(>smodel.StatusToTag{ - StatusID: status.ID, - TagID: i, - }).Exec(ctx); err != nil { - err = s.conn.errProc(err) - if !errors.Is(err, db.ErrAlreadyExists) { - return err + // create links between this status and any tags it uses + for _, i := range status.TagIDs { + if _, err := tx. + NewInsert(). + Model(>smodel.StatusToTag{ + StatusID: status.ID, + TagID: i, + }).Exec(ctx); err != nil { + err = s.conn.ProcessError(err) + if !errors.Is(err, db.ErrAlreadyExists) { + return err + } } } - } - // change the status ID of the media attachments to the new status - for _, a := range status.Attachments { - a.StatusID = status.ID - a.UpdatedAt = time.Now() - if _, err := tx. - NewUpdate(). - Model(a). - Where("? = ?", bun.Ident("media_attachment.id"), a.ID). - Exec(ctx); err != nil { - err = s.conn.errProc(err) - if !errors.Is(err, db.ErrAlreadyExists) { - return err + // change the status ID of the media attachments to the new status + for _, a := range status.Attachments { + a.StatusID = status.ID + a.UpdatedAt = time.Now() + if _, err := tx. + NewUpdate(). + Model(a). + Where("? = ?", bun.Ident("media_attachment.id"), a.ID). + Exec(ctx); err != nil { + err = s.conn.ProcessError(err) + if !errors.Is(err, db.ErrAlreadyExists) { + return err + } } } - } - // Finally, insert the status - if _, err := tx. - NewInsert(). - Model(status). - Exec(ctx); err != nil { + // Finally, insert the status + _, err := tx.NewInsert().Model(status).Exec(ctx) return err - } - - return nil + }) }) - if err != nil { - return s.conn.ProcessError(err) - } - - s.cache.Put(status) - return nil } -func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*gtsmodel.Status, db.Error) { - err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { +func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) db.Error { + if err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { // create links between this status and any emojis it uses for _, i := range status.EmojiIDs { if _, err := tx. @@ -212,7 +219,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (* StatusID: status.ID, EmojiID: i, }).Exec(ctx); err != nil { - err = s.conn.errProc(err) + err = s.conn.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } @@ -227,14 +234,14 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (* StatusID: status.ID, TagID: i, }).Exec(ctx); err != nil { - err = s.conn.errProc(err) + err = s.conn.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } } } - // change the status ID of the media attachments to this status + // change the status ID of the media attachments to the new status for _, a := range status.Attachments { a.StatusID = status.ID a.UpdatedAt = time.Now() @@ -243,31 +250,31 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (* Model(a). Where("? = ?", bun.Ident("media_attachment.id"), a.ID). Exec(ctx); err != nil { - return err + err = s.conn.ProcessError(err) + if !errors.Is(err, db.ErrAlreadyExists) { + return err + } } } - // Finally, update the status itself - if _, err := tx. + // Finally, insert the status + _, err := tx. NewUpdate(). Model(status). Where("? = ?", bun.Ident("status.id"), status.ID). - Exec(ctx); err != nil { - return err - } - - return nil - }) - if err != nil { - return nil, s.conn.ProcessError(err) + Exec(ctx) + return err + }); err != nil { + return err } - s.cache.Put(status) - return status, nil + // Drop any old value from cache by this ID + s.cache.Invalidate("ID", status.ID) + return nil } func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { - err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { + if err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { // delete links between this status and any emojis it uses if _, err := tx. NewDelete(). @@ -296,36 +303,41 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { } return nil - }) - if err != nil { - return s.conn.ProcessError(err) + }); err != nil { + return err } - s.cache.Invalidate(id) + // Drop any old value from cache by this ID + s.cache.Invalidate("ID", id) return nil } func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { - parents := []*gtsmodel.Status{} - s.statusParent(ctx, status, &parents, onlyDirect) - return parents, nil -} - -func (s *statusDB) statusParent(ctx context.Context, status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) { - if status.InReplyToID == "" { - return + if onlyDirect { + // Only want the direct parent, no further than first level + parent, err := s.GetStatusByID(ctx, status.InReplyToID) + if err != nil { + return nil, err + } + return []*gtsmodel.Status{parent}, nil } - parentStatus, err := s.GetStatusByID(ctx, status.InReplyToID) - if err == nil { - *foundStatuses = append(*foundStatuses, parentStatus) - } + var parents []*gtsmodel.Status - if onlyDirect { - return + for id := status.InReplyToID; id != ""; { + parent, err := s.GetStatusByID(ctx, id) + if err != nil { + return nil, err + } + + // Append parent to slice + parents = append(parents, parent) + + // Set the next parent ID + id = parent.InReplyToID } - s.statusParent(ctx, parentStatus, foundStatuses, false) + return parents, nil } func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) { @@ -350,7 +362,7 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu } func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { - childIDs := []string{} + var childIDs []string q := s.conn. NewSelect(). @@ -471,6 +483,7 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) if err := q.Scan(ctx); err != nil { return nil, s.conn.ProcessError(err) } + return faves, nil } diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go index 9b6365621..066f55234 100644 --- a/internal/db/bundb/timeline_test.go +++ b/internal/db/bundb/timeline_test.go @@ -35,44 +35,52 @@ type TimelineTestSuite struct { } func (suite *TimelineTestSuite) TestGetPublicTimeline() { - s, err := suite.db.GetPublicTimeline(context.Background(), "", "", "", 20, false) + ctx := context.Background() + + s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false) suite.NoError(err) suite.Len(s, 6) } func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() { + ctx := context.Background() + futureStatus := getFutureStatus() - if err := suite.db.Put(context.Background(), futureStatus); err != nil { - suite.FailNow(err.Error()) - } + err := suite.db.PutStatus(ctx, futureStatus) + suite.NoError(err) - s, err := suite.db.GetPublicTimeline(context.Background(), "", "", "", 20, false) + s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false) suite.NoError(err) + suite.NotContains(s, futureStatus) suite.Len(s, 6) } func (suite *TimelineTestSuite) TestGetHomeTimeline() { + ctx := context.Background() + viewingAccount := suite.testAccounts["local_account_1"] - s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false) + s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false) suite.NoError(err) suite.Len(s, 16) } func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() { + ctx := context.Background() + viewingAccount := suite.testAccounts["local_account_1"] futureStatus := getFutureStatus() - if err := suite.db.Put(context.Background(), futureStatus); err != nil { - suite.FailNow(err.Error()) - } + err := suite.db.PutStatus(ctx, futureStatus) + suite.NoError(err) s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false) suite.NoError(err) + suite.NotContains(s, futureStatus) suite.Len(s, 16) } diff --git a/internal/db/bundb/tombstone.go b/internal/db/bundb/tombstone.go index 7ce3327a7..309a39fd3 100644 --- a/internal/db/bundb/tombstone.go +++ b/internal/db/bundb/tombstone.go @@ -43,7 +43,7 @@ func (t *tombstoneDB) init() { t2 := new(gtsmodel.Tombstone) *t2 = *t1 return t2 - }, 1000) + }, 100) // Set cache TTL and start sweep routine t.cache.SetTTL(time.Minute*5, false) diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go index aa2f4c2c8..d9b281a6f 100644 --- a/internal/db/bundb/user.go +++ b/internal/db/bundb/user.go @@ -22,7 +22,7 @@ import ( "context" "time" - "github.com/superseriousbusiness/gotosocial/internal/cache" + "codeberg.org/gruf/go-cache/v3/result" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/uptrace/bun" @@ -30,111 +30,121 @@ import ( type userDB struct { conn *DBConn - cache *cache.UserCache + cache *result.Cache[*gtsmodel.User] } -func (u *userDB) newUserQ(user *gtsmodel.User) *bun.SelectQuery { - return u.conn. - NewSelect(). - Model(user). - Relation("Account") +func (u *userDB) init() { + // Initialize user result cache + u.cache = result.NewSized([]result.Lookup{ + {Name: "ID"}, + {Name: "AccountID"}, + {Name: "Email"}, + {Name: "ConfirmationToken"}, + }, func(u1 *gtsmodel.User) *gtsmodel.User { + u2 := new(gtsmodel.User) + *u2 = *u1 + return u2 + }, 1000) + + // Set cache TTL and start sweep routine + u.cache.SetTTL(time.Minute*5, false) + u.cache.Start(time.Second * 10) } -func (u *userDB) getUser(ctx context.Context, cacheGet func() (*gtsmodel.User, bool), dbQuery func(*gtsmodel.User) error) (*gtsmodel.User, db.Error) { - // Attempt to fetch cached user - user, cached := cacheGet() +func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) { + return u.cache.Load("ID", func() (*gtsmodel.User, error) { + var user gtsmodel.User - if !cached { - user = >smodel.User{} + q := u.conn. + NewSelect(). + Model(&user). + Relation("Account"). + Where("? = ?", bun.Ident("user.id"), id) - // Not cached! Perform database query - err := dbQuery(user) - if err != nil { + if err := q.Scan(ctx); err != nil { return nil, u.conn.ProcessError(err) } - // Place in the cache - u.cache.Put(user) - } - - return user, nil -} - -func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) { - return u.getUser( - ctx, - func() (*gtsmodel.User, bool) { - return u.cache.GetByID(id) - }, - func(user *gtsmodel.User) error { - return u.newUserQ(user).Where("? = ?", bun.Ident("user.id"), id).Scan(ctx) - }, - ) + return &user, nil + }, id) } func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, db.Error) { - return u.getUser( - ctx, - func() (*gtsmodel.User, bool) { - return u.cache.GetByAccountID(accountID) - }, - func(user *gtsmodel.User) error { - return u.newUserQ(user).Where("? = ?", bun.Ident("user.account_id"), accountID).Scan(ctx) - }, - ) + return u.cache.Load("AccountID", func() (*gtsmodel.User, error) { + var user gtsmodel.User + + q := u.conn. + NewSelect(). + Model(&user). + Relation("Account"). + Where("? = ?", bun.Ident("user.account_id"), accountID) + + if err := q.Scan(ctx); err != nil { + return nil, u.conn.ProcessError(err) + } + + return &user, nil + }, accountID) } func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, db.Error) { - return u.getUser( - ctx, - func() (*gtsmodel.User, bool) { - return u.cache.GetByEmail(emailAddress) - }, - func(user *gtsmodel.User) error { - return u.newUserQ(user).Where("? = ?", bun.Ident("user.email"), emailAddress).Scan(ctx) - }, - ) + return u.cache.Load("Email", func() (*gtsmodel.User, error) { + var user gtsmodel.User + + q := u.conn. + NewSelect(). + Model(&user). + Relation("Account"). + Where("? = ?", bun.Ident("user.email"), emailAddress) + + if err := q.Scan(ctx); err != nil { + return nil, u.conn.ProcessError(err) + } + + return &user, nil + }, emailAddress) } func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, db.Error) { - return u.getUser( - ctx, - func() (*gtsmodel.User, bool) { - return u.cache.GetByConfirmationToken(confirmationToken) - }, - func(user *gtsmodel.User) error { - return u.newUserQ(user).Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken).Scan(ctx) - }, - ) -} + return u.cache.Load("ConfirmationToken", func() (*gtsmodel.User, error) { + var user gtsmodel.User -func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) (*gtsmodel.User, db.Error) { - if _, err := u.conn. - NewInsert(). - Model(user). - Exec(ctx); err != nil { - return nil, u.conn.ProcessError(err) - } + q := u.conn. + NewSelect(). + Model(&user). + Relation("Account"). + Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken) - u.cache.Put(user) - return user, nil + if err := q.Scan(ctx); err != nil { + return nil, u.conn.ProcessError(err) + } + + return &user, nil + }, confirmationToken) } -func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, db.Error) { +func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) db.Error { + return u.cache.Store(user, func() error { + _, err := u.conn. + NewInsert(). + Model(user). + Exec(ctx) + return u.conn.ProcessError(err) + }) +} + +func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User) db.Error { // Update the user's last-updated user.UpdatedAt = time.Now() - if _, err := u.conn. - NewUpdate(). - Model(user). - Where("? = ?", bun.Ident("user.id"), user.ID). - Column(columns...). - Exec(ctx); err != nil { - return nil, u.conn.ProcessError(err) - } - - u.cache.Invalidate(user.ID) - return user, nil + return u.cache.Store(user, func() error { + _, err := u.conn. + NewUpdate(). + Model(user). + Where("? = ?", bun.Ident("user.id"), user.ID). + Exec(ctx) + return u.conn.ProcessError(err) + }) } func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error { @@ -146,6 +156,7 @@ func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error { return u.conn.ProcessError(err) } - u.cache.Invalidate(userID) + // Invalidate user from cache + u.cache.Invalidate("ID", userID) return nil } diff --git a/internal/db/bundb/user_test.go b/internal/db/bundb/user_test.go index 6ad59fc8e..18f67dde5 100644 --- a/internal/db/bundb/user_test.go +++ b/internal/db/bundb/user_test.go @@ -50,21 +50,20 @@ func (suite *UserTestSuite) TestGetUserByAccountID() { func (suite *UserTestSuite) TestUpdateUserSelectedColumns() { testUser := suite.testUsers["local_account_1"] - user := >smodel.User{ - ID: testUser.ID, - Email: "whatever", - Locale: "es", - } - user, err := suite.db.UpdateUser(context.Background(), user, "email", "locale") + updateUser := new(gtsmodel.User) + *updateUser = *testUser + updateUser.Email = "whatever" + updateUser.Locale = "es" + + err := suite.db.UpdateUser(context.Background(), updateUser) suite.NoError(err) - suite.NotNil(user) dbUser, err := suite.db.GetUserByID(context.Background(), testUser.ID) suite.NoError(err) suite.NotNil(dbUser) - suite.Equal("whatever", dbUser.Email) - suite.Equal("es", dbUser.Locale) + suite.Equal(updateUser.Email, dbUser.Email) + suite.Equal(updateUser.Locale, dbUser.Locale) suite.Equal(testUser.AccountID, dbUser.AccountID) } diff --git a/internal/db/status.go b/internal/db/status.go index 55cec5beb..d0983122b 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -39,7 +39,7 @@ type Status interface { PutStatus(ctx context.Context, status *gtsmodel.Status) Error // UpdateStatus updates one status in the database and returns it to the caller. - UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*gtsmodel.Status, Error) + UpdateStatus(ctx context.Context, status *gtsmodel.Status) Error // DeleteStatusByID deletes one status from the database. DeleteStatusByID(ctx context.Context, id string) Error diff --git a/internal/db/user.go b/internal/db/user.go index a4d48db56..d01a8862a 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -34,9 +34,10 @@ type User interface { GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, Error) // GetUserByConfirmationToken returns one user by its confirmation token, or an error if something goes wrong. GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, Error) - // UpdateUser updates one user by its primary key. If columns is set, only given columns - // will be updated. If not set, all columns will be updated. - UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, Error) + // PutUser will attempt to place user in the database + PutUser(ctx context.Context, user *gtsmodel.User) Error + // UpdateUser updates one user by its primary key. + UpdateUser(ctx context.Context, user *gtsmodel.User) Error // DeleteUserByID deletes one user by its ID. DeleteUserByID(ctx context.Context, userID string) Error } diff --git a/internal/federation/dereferencing/account.go b/internal/federation/dereferencing/account.go index 0e7bc1cc9..6e107a11d 100644 --- a/internal/federation/dereferencing/account.go +++ b/internal/federation/dereferencing/account.go @@ -276,7 +276,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar foundAccount.LastWebfingeredAt = fingered foundAccount.UpdatedAt = time.Now() - foundAccount, err = d.db.PutAccount(ctx, foundAccount) + err = d.db.PutAccount(ctx, foundAccount) if err != nil { err = fmt.Errorf("GetRemoteAccount: error putting new account: %s", err) return @@ -338,7 +338,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar } if accountDomainChanged || sharedInboxChanged || fieldsChanged || fingeredChanged { - foundAccount, err = d.db.UpdateAccount(ctx, foundAccount) + err = d.db.UpdateAccount(ctx, foundAccount) if err != nil { return nil, fmt.Errorf("GetRemoteAccount: error updating remoteAccount: %s", err) } diff --git a/internal/federation/dereferencing/account_test.go b/internal/federation/dereferencing/account_test.go index ddd9456e8..38dc615d5 100644 --- a/internal/federation/dereferencing/account_test.go +++ b/internal/federation/dereferencing/account_test.go @@ -107,7 +107,7 @@ func (suite *AccountTestSuite) TestDereferenceLocalAccountAsRemoteURLNoSharedInb targetAccount := suite.testAccounts["local_account_2"] targetAccount.SharedInboxURI = nil - if _, err := suite.db.UpdateAccount(context.Background(), targetAccount); err != nil { + if err := suite.db.UpdateAccount(context.Background(), targetAccount); err != nil { suite.FailNow(err.Error()) } diff --git a/internal/federation/dereferencing/status.go b/internal/federation/dereferencing/status.go index bfbc790d8..001fe53f4 100644 --- a/internal/federation/dereferencing/status.go +++ b/internal/federation/dereferencing/status.go @@ -45,8 +45,10 @@ func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status if err := d.populateStatusFields(ctx, status, username, includeParent); err != nil { return nil, err } - - return d.db.UpdateStatus(ctx, status) + if err := d.db.UpdateStatus(ctx, status); err != nil { + return nil, err + } + return status, nil } // GetRemoteStatus completely dereferences a remote status, converts it to a GtS model status, diff --git a/internal/federation/federatingdb/inbox_test.go b/internal/federation/federatingdb/inbox_test.go index dbf9d3c53..f0ce38af6 100644 --- a/internal/federation/federatingdb/inbox_test.go +++ b/internal/federation/federatingdb/inbox_test.go @@ -68,7 +68,7 @@ func (suite *InboxTestSuite) TestInboxesForAccountIRIWithSharedInbox() { testAccount := suite.testAccounts["local_account_1"] sharedInbox := "http://some-inbox-iri/weeeeeeeeeeeee" testAccount.SharedInboxURI = &sharedInbox - if _, err := suite.db.UpdateAccount(ctx, testAccount); err != nil { + if err := suite.db.UpdateAccount(ctx, testAccount); err != nil { suite.FailNow("error updating account") } diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go index 3758a4000..3cc3bb143 100644 --- a/internal/processing/account/delete.go +++ b/internal/processing/account/delete.go @@ -273,7 +273,7 @@ selectStatusesLoop: account.SuspendedAt = time.Now() account.SuspensionOrigin = origin - account, err := p.db.UpdateAccount(ctx, account) + err := p.db.UpdateAccount(ctx, account) if err != nil { return gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/account/update.go b/internal/processing/account/update.go index bce82d6ca..bc4570c76 100644 --- a/internal/processing/account/update.go +++ b/internal/processing/account/update.go @@ -164,7 +164,7 @@ func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, form account.EnableRSS = form.EnableRSS } - updatedAccount, err := p.db.UpdateAccount(ctx, account) + err := p.db.UpdateAccount(ctx, account) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("could not update account %s: %s", account.ID, err)) } @@ -172,11 +172,11 @@ func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, form p.clientWorker.Queue(messages.FromClientAPI{ APObjectType: ap.ObjectProfile, APActivityType: ap.ActivityUpdate, - GTSModel: updatedAccount, - OriginAccount: updatedAccount, + GTSModel: account, + OriginAccount: account, }) - acctSensitive, err := p.tc.AccountToAPIAccountSensitive(ctx, updatedAccount) + acctSensitive, err := p.tc.AccountToAPIAccountSensitive(ctx, account) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("could not convert account into apisensitive account: %s", err)) } diff --git a/internal/processing/fromclientapi_test.go b/internal/processing/fromclientapi_test.go index 0e620c9e9..c4e06ea62 100644 --- a/internal/processing/fromclientapi_test.go +++ b/internal/processing/fromclientapi_test.go @@ -129,7 +129,7 @@ func (suite *FromClientAPITestSuite) TestProcessStatusDelete() { suite.NoError(errWithCode) // delete the status from the db first, to mimic what would have already happened earlier up the flow - err := suite.db.DeleteByID(ctx, deletedStatus.ID, >smodel.Status{}) + err := suite.db.DeleteStatusByID(ctx, deletedStatus.ID) suite.NoError(err) // process the status delete diff --git a/internal/processing/instance.go b/internal/processing/instance.go index 1d7bdb377..14ff1de5a 100644 --- a/internal/processing/instance.go +++ b/internal/processing/instance.go @@ -235,7 +235,7 @@ func (p *processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe if updateInstanceAccount { // if either avatar or header is updated, we need // to update the instance account that stores them - if _, err := p.db.UpdateAccount(ctx, ia); err != nil { + if err := p.db.UpdateAccount(ctx, ia); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance account: %s", err)) } } diff --git a/internal/transport/controller.go b/internal/transport/controller.go index 4db8ee00e..e7a07016f 100644 --- a/internal/transport/controller.go +++ b/internal/transport/controller.go @@ -28,7 +28,7 @@ import ( "time" "codeberg.org/gruf/go-byteutil" - "codeberg.org/gruf/go-cache/v2" + "codeberg.org/gruf/go-cache/v3" "github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/gotosocial/internal/config" @@ -67,8 +67,8 @@ func NewController(db db.DB, federatingDB federatingdb.DB, clock pub.Clock, clie fedDB: federatingDB, clock: clock, client: client, - trspCache: cache.New[string, *transport](), - badHosts: cache.New[string, struct{}](), + trspCache: cache.New[string, *transport](0, 100, 0), + badHosts: cache.New[string, struct{}](0, 1000, 0), userAgent: fmt.Sprintf("%s; %s (gofed/activity gotosocial-%s)", applicationName, host, version), } @@ -110,7 +110,7 @@ func (c *controller) NewTransport(pubKeyID string, privkey *rsa.PrivateKey) (Tra } // Cache this transport under pubkey - if !c.trspCache.Put(pubStr, transp) { + if !c.trspCache.Add(pubStr, transp) { var cached *transport cached, ok = c.trspCache.Get(pubStr) diff --git a/internal/web/etag.go b/internal/web/etag.go index 37c1cb423..4fe3f7cac 100644 --- a/internal/web/etag.go +++ b/internal/web/etag.go @@ -27,11 +27,11 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/log" - "codeberg.org/gruf/go-cache/v2" + "codeberg.org/gruf/go-cache/v3" ) func newETagCache() cache.Cache[string, eTagCacheEntry] { - eTagCache := cache.New[string, eTagCacheEntry]() + eTagCache := cache.New[string, eTagCacheEntry](0, 1000, 0) eTagCache.SetTTL(time.Hour, false) if !eTagCache.Start(time.Minute) { log.Panic("could not start eTagCache") diff --git a/internal/web/rss.go b/internal/web/rss.go index 64be7685c..827a19e87 100644 --- a/internal/web/rss.go +++ b/internal/web/rss.go @@ -123,7 +123,7 @@ func (m *Module) rssFeedGETHandler(c *gin.Context) { cacheEntry.lastModified = accountLastPostedPublic cacheEntry.eTag = eTag - m.eTagCache.Put(cacheKey, cacheEntry) + m.eTagCache.Set(cacheKey, cacheEntry) } c.Header(eTagHeader, cacheEntry.eTag) diff --git a/internal/web/web.go b/internal/web/web.go index cdcf7422f..16be8a71d 100644 --- a/internal/web/web.go +++ b/internal/web/web.go @@ -22,7 +22,7 @@ import ( "errors" "net/http" - "codeberg.org/gruf/go-cache/v2" + "codeberg.org/gruf/go-cache/v3" "github.com/gin-gonic/gin" "github.com/superseriousbusiness/gotosocial/internal/api" "github.com/superseriousbusiness/gotosocial/internal/gtserror" |