diff options
author | 2024-01-19 12:57:29 +0000 | |
---|---|---|
committer | 2024-01-19 12:57:29 +0000 | |
commit | 7ec1e1332e7d04e74451acef18b41f389722b698 (patch) | |
tree | 9c69eca7fc664ab5564279a2e065dfd5c2ddd17b /internal | |
parent | [chore] chore rationalise http return codes for activitypub handlers (#2540) (diff) | |
download | gotosocial-7ec1e1332e7d04e74451acef18b41f389722b698.tar.xz |
[performance] overhaul struct (+ result) caching library for simplicity, performance and multiple-result lookups (#2535)
* rewrite cache library as codeberg.org/gruf/go-structr, implement in gotosocial
* use actual go-structr release version (not just commit hash)
* revert go toolchain changes (damn you go for auto changing this)
* fix go mod woes
* ensure %w is used in calls to errs.Appendf()
* fix error checking
* fix possible panic
* remove unnecessary start/stop functions, move to main Cache{} struct, add note regarding which caches require start/stop
* fix copy-paste artifact... :innocent:
* fix all comment copy-paste artifacts
* remove dropID() function, now we can just use slices.DeleteFunc()
* use util.Deduplicate() instead of collate(), move collate to util
* move orderByIDs() to util package and "generify"
* add a util.DeleteIf() function, use this to delete entries on failed population
* use slices.DeleteFunc() instead of util.DeleteIf() (i had the logic mixed up in my head somehow lol)
* add note about how collate differs from deduplicate
Diffstat (limited to 'internal')
49 files changed, 2496 insertions, 1961 deletions
diff --git a/internal/cache/ap.go b/internal/cache/ap.go deleted file mode 100644 index 6498d7991..000000000 --- a/internal/cache/ap.go +++ /dev/null @@ -1,30 +0,0 @@ -// GoToSocial -// Copyright (C) GoToSocial Authors admin@gotosocial.org -// SPDX-License-Identifier: AGPL-3.0-or-later -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see <http://www.gnu.org/licenses/>. - -package cache - -type APCaches struct{} - -// Init will initialize all the ActivityPub caches in this collection. -// NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe. -func (c *APCaches) Init() {} - -// Start will attempt to start all of the ActivityPub caches, or panic. -func (c *APCaches) Start() {} - -// Stop will attempt to stop all of the ActivityPub caches, or panic. -func (c *APCaches) Stop() {} diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 73e3ad6f0..a278336ae 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -18,8 +18,9 @@ package cache import ( + "time" + "github.com/superseriousbusiness/gotosocial/internal/cache/headerfilter" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" ) @@ -49,198 +50,59 @@ type Caches struct { func (c *Caches) Init() { log.Infof(nil, "init: %p", c) - c.GTS.Init() - c.Visibility.Init() - - // Setup cache invalidate hooks. - // !! READ THE METHOD COMMENT - c.setuphooks() + c.initAccount() + c.initAccountNote() + c.initApplication() + c.initBlock() + c.initBlockIDs() + c.initBoostOfIDs() + c.initDomainAllow() + c.initDomainBlock() + c.initEmoji() + c.initEmojiCategory() + c.initFollow() + c.initFollowIDs() + c.initFollowRequest() + c.initFollowRequestIDs() + c.initInReplyToIDs() + c.initInstance() + c.initList() + c.initListEntry() + c.initMarker() + c.initMedia() + c.initMention() + c.initNotification() + c.initPoll() + c.initPollVote() + c.initPollVoteIDs() + c.initReport() + c.initStatus() + c.initStatusFave() + c.initTag() + c.initThreadMute() + c.initStatusFaveIDs() + c.initTombstone() + c.initUser() + c.initWebfinger() + c.initVisibility() } -// Start will start both the GTS and AP cache collections. +// Start will start any caches that require a background +// routine, which usually means any kind of TTL caches. func (c *Caches) Start() { log.Infof(nil, "start: %p", c) - c.GTS.Start() - c.Visibility.Start() + tryUntil("starting *gtsmodel.Webfinger cache", 5, func() bool { + return c.GTS.Webfinger.Start(5 * time.Minute) + }) } -// Stop will stop both the GTS and AP cache collections. +// Stop will stop any caches that require a background +// routine, which usually means any kind of TTL caches. func (c *Caches) Stop() { log.Infof(nil, "stop: %p", c) - c.GTS.Stop() - c.Visibility.Stop() -} - -// setuphooks sets necessary cache invalidation hooks between caches, -// as an invalidation indicates a database INSERT / UPDATE / DELETE. -// NOTE THEY ARE ONLY CALLED WHEN THE ITEM IS IN THE CACHE, SO FOR -// HOOKS TO BE CALLED ON DELETE YOU MUST FIRST POPULATE IT IN THE CACHE. -func (c *Caches) setuphooks() { - c.GTS.Account().SetInvalidateCallback(func(account *gtsmodel.Account) { - // Invalidate account ID cached visibility. - c.Visibility.Invalidate("ItemID", account.ID) - c.Visibility.Invalidate("RequesterID", account.ID) - - // Invalidate this account's - // following / follower lists. - // (see FollowIDs() comment for details). - c.GTS.FollowIDs().InvalidateAll( - ">"+account.ID, - "l>"+account.ID, - "<"+account.ID, - "l<"+account.ID, - ) - - // Invalidate this account's - // follow requesting / request lists. - // (see FollowRequestIDs() comment for details). - c.GTS.FollowRequestIDs().InvalidateAll( - ">"+account.ID, - "<"+account.ID, - ) - - // Invalidate this account's block lists. - c.GTS.BlockIDs().Invalidate(account.ID) - }) - - c.GTS.Block().SetInvalidateCallback(func(block *gtsmodel.Block) { - // Invalidate block origin account ID cached visibility. - c.Visibility.Invalidate("ItemID", block.AccountID) - c.Visibility.Invalidate("RequesterID", block.AccountID) - - // Invalidate block target account ID cached visibility. - c.Visibility.Invalidate("ItemID", block.TargetAccountID) - c.Visibility.Invalidate("RequesterID", block.TargetAccountID) - - // Invalidate source account's block lists. - c.GTS.BlockIDs().Invalidate(block.AccountID) - }) - - c.GTS.EmojiCategory().SetInvalidateCallback(func(category *gtsmodel.EmojiCategory) { - // Invalidate any emoji in this category. - c.GTS.Emoji().Invalidate("CategoryID", category.ID) - }) - - c.GTS.Follow().SetInvalidateCallback(func(follow *gtsmodel.Follow) { - // Invalidate follow request with this same ID. - c.GTS.FollowRequest().Invalidate("ID", follow.ID) - - // Invalidate any related list entries. - c.GTS.ListEntry().Invalidate("FollowID", follow.ID) - - // Invalidate follow origin account ID cached visibility. - c.Visibility.Invalidate("ItemID", follow.AccountID) - c.Visibility.Invalidate("RequesterID", follow.AccountID) - - // Invalidate follow target account ID cached visibility. - c.Visibility.Invalidate("ItemID", follow.TargetAccountID) - c.Visibility.Invalidate("RequesterID", follow.TargetAccountID) - - // Invalidate source account's following - // lists, and destination's follwer lists. - // (see FollowIDs() comment for details). - c.GTS.FollowIDs().InvalidateAll( - ">"+follow.AccountID, - "l>"+follow.AccountID, - "<"+follow.AccountID, - "l<"+follow.AccountID, - "<"+follow.TargetAccountID, - "l<"+follow.TargetAccountID, - ">"+follow.TargetAccountID, - "l>"+follow.TargetAccountID, - ) - }) - - c.GTS.FollowRequest().SetInvalidateCallback(func(followReq *gtsmodel.FollowRequest) { - // Invalidate follow with this same ID. - c.GTS.Follow().Invalidate("ID", followReq.ID) - - // Invalidate source account's followreq - // lists, and destinations follow req lists. - // (see FollowRequestIDs() comment for details). - c.GTS.FollowRequestIDs().InvalidateAll( - ">"+followReq.AccountID, - "<"+followReq.AccountID, - ">"+followReq.TargetAccountID, - "<"+followReq.TargetAccountID, - ) - }) - - c.GTS.List().SetInvalidateCallback(func(list *gtsmodel.List) { - // Invalidate all cached entries of this list. - c.GTS.ListEntry().Invalidate("ListID", list.ID) - }) - - c.GTS.Media().SetInvalidateCallback(func(media *gtsmodel.MediaAttachment) { - if *media.Avatar || *media.Header { - // Invalidate cache of attaching account. - c.GTS.Account().Invalidate("ID", media.AccountID) - } - - if media.StatusID != "" { - // Invalidate cache of attaching status. - c.GTS.Status().Invalidate("ID", media.StatusID) - } - }) - - c.GTS.Poll().SetInvalidateCallback(func(poll *gtsmodel.Poll) { - // Invalidate all cached votes of this poll. - c.GTS.PollVote().Invalidate("PollID", poll.ID) - - // Invalidate cache of poll vote IDs. - c.GTS.PollVoteIDs().Invalidate(poll.ID) - }) - - c.GTS.PollVote().SetInvalidateCallback(func(vote *gtsmodel.PollVote) { - // Invalidate cached poll (contains no. votes). - c.GTS.Poll().Invalidate("ID", vote.PollID) - - // Invalidate cache of poll vote IDs. - c.GTS.PollVoteIDs().Invalidate(vote.PollID) - }) - - c.GTS.Status().SetInvalidateCallback(func(status *gtsmodel.Status) { - // Invalidate status ID cached visibility. - c.Visibility.Invalidate("ItemID", status.ID) - - for _, id := range status.AttachmentIDs { - // Invalidate each media by the IDs we're aware of. - // This must be done as the status table is aware of - // the media IDs in use before the media table is - // aware of the status ID they are linked to. - // - // c.GTS.Media().Invalidate("StatusID") will not work. - c.GTS.Media().Invalidate("ID", id) - } - - if status.BoostOfID != "" { - // Invalidate boost ID list of the original status. - c.GTS.BoostOfIDs().Invalidate(status.BoostOfID) - } - - if status.InReplyToID != "" { - // Invalidate in reply to ID list of original status. - c.GTS.InReplyToIDs().Invalidate(status.InReplyToID) - } - - if status.PollID != "" { - // Invalidate cache of attached poll ID. - c.GTS.Poll().Invalidate("ID", status.PollID) - } - }) - - c.GTS.StatusFave().SetInvalidateCallback(func(fave *gtsmodel.StatusFave) { - // Invalidate status fave ID list for this status. - c.GTS.StatusFaveIDs().Invalidate(fave.StatusID) - }) - - c.GTS.User().SetInvalidateCallback(func(user *gtsmodel.User) { - // Invalidate local account ID cached visibility. - c.Visibility.Invalidate("ItemID", user.AccountID) - c.Visibility.Invalidate("RequesterID", user.AccountID) - }) + tryUntil("stopping *gtsmodel.Webfinger cache", 5, c.GTS.Webfinger.Stop) } // Sweep will sweep all the available caches to ensure none @@ -250,30 +112,30 @@ func (c *Caches) setuphooks() { // require an eviction on every single write, which adds // significant overhead to all cache writes. func (c *Caches) Sweep(threshold float64) { - c.GTS.Account().Trim(threshold) - c.GTS.AccountNote().Trim(threshold) - c.GTS.Block().Trim(threshold) - c.GTS.BlockIDs().Trim(threshold) - c.GTS.Emoji().Trim(threshold) - c.GTS.EmojiCategory().Trim(threshold) - c.GTS.Follow().Trim(threshold) - c.GTS.FollowIDs().Trim(threshold) - c.GTS.FollowRequest().Trim(threshold) - c.GTS.FollowRequestIDs().Trim(threshold) - c.GTS.Instance().Trim(threshold) - c.GTS.List().Trim(threshold) - c.GTS.ListEntry().Trim(threshold) - c.GTS.Marker().Trim(threshold) - c.GTS.Media().Trim(threshold) - c.GTS.Mention().Trim(threshold) - c.GTS.Notification().Trim(threshold) - c.GTS.Poll().Trim(threshold) - c.GTS.Report().Trim(threshold) - c.GTS.Status().Trim(threshold) - c.GTS.StatusFave().Trim(threshold) - c.GTS.Tag().Trim(threshold) - c.GTS.ThreadMute().Trim(threshold) - c.GTS.Tombstone().Trim(threshold) - c.GTS.User().Trim(threshold) + c.GTS.Account.Trim(threshold) + c.GTS.AccountNote.Trim(threshold) + c.GTS.Block.Trim(threshold) + c.GTS.BlockIDs.Trim(threshold) + c.GTS.Emoji.Trim(threshold) + c.GTS.EmojiCategory.Trim(threshold) + c.GTS.Follow.Trim(threshold) + c.GTS.FollowIDs.Trim(threshold) + c.GTS.FollowRequest.Trim(threshold) + c.GTS.FollowRequestIDs.Trim(threshold) + c.GTS.Instance.Trim(threshold) + c.GTS.List.Trim(threshold) + c.GTS.ListEntry.Trim(threshold) + c.GTS.Marker.Trim(threshold) + c.GTS.Media.Trim(threshold) + c.GTS.Mention.Trim(threshold) + c.GTS.Notification.Trim(threshold) + c.GTS.Poll.Trim(threshold) + c.GTS.Report.Trim(threshold) + c.GTS.Status.Trim(threshold) + c.GTS.StatusFave.Trim(threshold) + c.GTS.Tag.Trim(threshold) + c.GTS.ThreadMute.Trim(threshold) + c.GTS.Tombstone.Trim(threshold) + c.GTS.User.Trim(threshold) c.Visibility.Trim(threshold) } diff --git a/internal/cache/db.go b/internal/cache/db.go new file mode 100644 index 000000000..894d74109 --- /dev/null +++ b/internal/cache/db.go @@ -0,0 +1,1071 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package cache + +import ( + "time" + + "codeberg.org/gruf/go-cache/v3/simple" + "codeberg.org/gruf/go-cache/v3/ttl" + "codeberg.org/gruf/go-structr" + "github.com/superseriousbusiness/gotosocial/internal/cache/domain" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" +) + +type GTSCaches struct { + // Account provides access to the gtsmodel Account database cache. + Account structr.Cache[*gtsmodel.Account] + + // AccountNote provides access to the gtsmodel Note database cache. + AccountNote structr.Cache[*gtsmodel.AccountNote] + + // Application provides access to the gtsmodel Application database cache. + Application structr.Cache[*gtsmodel.Application] + + // Block provides access to the gtsmodel Block (account) database cache. + Block structr.Cache[*gtsmodel.Block] + + // FollowIDs provides access to the block IDs database cache. + BlockIDs *SliceCache[string] + + // BoostOfIDs provides access to the boost of IDs list database cache. + BoostOfIDs *SliceCache[string] + + // DomainAllow provides access to the domain allow database cache. + DomainAllow *domain.Cache + + // DomainBlock provides access to the domain block database cache. + DomainBlock *domain.Cache + + // Emoji provides access to the gtsmodel Emoji database cache. + Emoji structr.Cache[*gtsmodel.Emoji] + + // EmojiCategory provides access to the gtsmodel EmojiCategory database cache. + EmojiCategory structr.Cache[*gtsmodel.EmojiCategory] + + // Follow provides access to the gtsmodel Follow database cache. + Follow structr.Cache[*gtsmodel.Follow] + + // FollowIDs provides access to the follower / following IDs database cache. + // THIS CACHE IS KEYED AS THE FOLLOWING {prefix}{accountID} WHERE PREFIX IS: + // - '>' for following IDs + // - 'l>' for local following IDs + // - '<' for follower IDs + // - 'l<' for local follower IDs + FollowIDs *SliceCache[string] + + // FollowRequest provides access to the gtsmodel FollowRequest database cache. + FollowRequest structr.Cache[*gtsmodel.FollowRequest] + + // FollowRequestIDs provides access to the follow requester / requesting IDs database + // cache. THIS CACHE IS KEYED AS THE FOLLOWING {prefix}{accountID} WHERE PREFIX IS: + // - '>' for following IDs + // - '<' for follower IDs + FollowRequestIDs *SliceCache[string] + + // Instance provides access to the gtsmodel Instance database cache. + Instance structr.Cache[*gtsmodel.Instance] + + // InReplyToIDs provides access to the status in reply to IDs list database cache. + InReplyToIDs *SliceCache[string] + + // List provides access to the gtsmodel List database cache. + List structr.Cache[*gtsmodel.List] + + // ListEntry provides access to the gtsmodel ListEntry database cache. + ListEntry structr.Cache[*gtsmodel.ListEntry] + + // Marker provides access to the gtsmodel Marker database cache. + Marker structr.Cache[*gtsmodel.Marker] + + // Media provides access to the gtsmodel Media database cache. + Media structr.Cache[*gtsmodel.MediaAttachment] + + // Mention provides access to the gtsmodel Mention database cache. + Mention structr.Cache[*gtsmodel.Mention] + + // Notification provides access to the gtsmodel Notification database cache. + Notification structr.Cache[*gtsmodel.Notification] + + // Poll provides access to the gtsmodel Poll database cache. + Poll structr.Cache[*gtsmodel.Poll] + + // PollVote provides access to the gtsmodel PollVote database cache. + PollVote structr.Cache[*gtsmodel.PollVote] + + // PollVoteIDs provides access to the poll vote IDs list database cache. + PollVoteIDs *SliceCache[string] + + // Report provides access to the gtsmodel Report database cache. + Report structr.Cache[*gtsmodel.Report] + + // Status provides access to the gtsmodel Status database cache. + Status structr.Cache[*gtsmodel.Status] + + // StatusFave provides access to the gtsmodel StatusFave database cache. + StatusFave structr.Cache[*gtsmodel.StatusFave] + + // StatusFaveIDs provides access to the status fave IDs list database cache. + StatusFaveIDs *SliceCache[string] + + // Tag provides access to the gtsmodel Tag database cache. + Tag structr.Cache[*gtsmodel.Tag] + + // Tombstone provides access to the gtsmodel Tombstone database cache. + Tombstone structr.Cache[*gtsmodel.Tombstone] + + // ThreadMute provides access to the gtsmodel ThreadMute database cache. + ThreadMute structr.Cache[*gtsmodel.ThreadMute] + + // User provides access to the gtsmodel User database cache. + User structr.Cache[*gtsmodel.User] + + // Webfinger provides access to the webfinger URL cache. + // TODO: move out of GTS caches since unrelated to DB. + Webfinger *ttl.Cache[string, string] // TTL=24hr, sweep=5min +} + +// NOTE: +// all of the below init functions +// are receivers to the main cache +// struct type, not the database cache +// struct type, in order to get access +// to the full suite of caches for +// our invalidate function hooks. + +func (c *Caches) initAccount() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofAccount(), // model in-mem size. + config.GetCacheAccountMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(a1 *gtsmodel.Account) *gtsmodel.Account { + a2 := new(gtsmodel.Account) + *a2 = *a1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/account.go. + a2.AvatarMediaAttachment = nil + a2.HeaderMediaAttachment = nil + a2.Emojis = nil + + return a2 + } + + c.GTS.Account.Init(structr.Config[*gtsmodel.Account]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "URI"}, + {Fields: "URL"}, + {Fields: "Username,Domain", AllowZero: true}, + {Fields: "PublicKeyURI"}, + {Fields: "InboxURI"}, + {Fields: "OutboxURI"}, + {Fields: "FollowersURI"}, + {Fields: "FollowingURI"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + Invalidate: c.OnInvalidateAccount, + }) +} + +func (c *Caches) initAccountNote() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofAccountNote(), // model in-mem size. + config.GetCacheAccountNoteMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(n1 *gtsmodel.AccountNote) *gtsmodel.AccountNote { + n2 := new(gtsmodel.AccountNote) + *n2 = *n1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/relationship_note.go. + n2.Account = nil + n2.TargetAccount = nil + + return n2 + } + + c.GTS.AccountNote.Init(structr.Config[*gtsmodel.AccountNote]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "AccountID,TargetAccountID"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + +func (c *Caches) initApplication() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofApplication(), // model in-mem size. + config.GetCacheApplicationMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(a1 *gtsmodel.Application) *gtsmodel.Application { + a2 := new(gtsmodel.Application) + *a2 = *a1 + return a2 + } + + c.GTS.Application.Init(structr.Config[*gtsmodel.Application]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "ClientID"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + +func (c *Caches) initBlock() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofBlock(), // model in-mem size. + config.GetCacheBlockMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(b1 *gtsmodel.Block) *gtsmodel.Block { + b2 := new(gtsmodel.Block) + *b2 = *b1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/relationship_block.go. + b2.Account = nil + b2.TargetAccount = nil + + return b2 + } + + c.GTS.Block.Init(structr.Config[*gtsmodel.Block]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "URI"}, + {Fields: "AccountID,TargetAccountID"}, + {Fields: "AccountID", Multiple: true}, + {Fields: "TargetAccountID", Multiple: true}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + Invalidate: c.OnInvalidateBlock, + }) +} + +func (c *Caches) initBlockIDs() { + // Calculate maximum cache size. + cap := calculateSliceCacheMax( + config.GetCacheBlockIDsMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + c.GTS.BlockIDs = &SliceCache[string]{Cache: simple.New[string, []string]( + 0, + cap, + )} +} + +func (c *Caches) initBoostOfIDs() { + // Calculate maximum cache size. + cap := calculateSliceCacheMax( + config.GetCacheBoostOfIDsMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + c.GTS.BoostOfIDs = &SliceCache[string]{Cache: simple.New[string, []string]( + 0, + cap, + )} +} + +func (c *Caches) initDomainAllow() { + c.GTS.DomainAllow = new(domain.Cache) +} + +func (c *Caches) initDomainBlock() { + c.GTS.DomainBlock = new(domain.Cache) +} + +func (c *Caches) initEmoji() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofEmoji(), // model in-mem size. + config.GetCacheEmojiMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(e1 *gtsmodel.Emoji) *gtsmodel.Emoji { + e2 := new(gtsmodel.Emoji) + *e2 = *e1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/emoji.go. + e2.Category = nil + + return e2 + } + + c.GTS.Emoji.Init(structr.Config[*gtsmodel.Emoji]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "URI"}, + {Fields: "Shortcode,Domain", AllowZero: true}, + {Fields: "ImageStaticURL"}, + {Fields: "CategoryID", Multiple: true}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + +func (c *Caches) initEmojiCategory() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofEmojiCategory(), // model in-mem size. + config.GetCacheEmojiCategoryMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(c1 *gtsmodel.EmojiCategory) *gtsmodel.EmojiCategory { + c2 := new(gtsmodel.EmojiCategory) + *c2 = *c1 + return c2 + } + + c.GTS.EmojiCategory.Init(structr.Config[*gtsmodel.EmojiCategory]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "Name"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + Invalidate: c.OnInvalidateEmojiCategory, + }) +} + +func (c *Caches) initFollow() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofFollow(), // model in-mem size. + config.GetCacheFollowMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(f1 *gtsmodel.Follow) *gtsmodel.Follow { + f2 := new(gtsmodel.Follow) + *f2 = *f1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/relationship_follow.go. + f2.Account = nil + f2.TargetAccount = nil + + return f2 + } + + c.GTS.Follow.Init(structr.Config[*gtsmodel.Follow]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "URI"}, + {Fields: "AccountID,TargetAccountID"}, + {Fields: "AccountID", Multiple: true}, + {Fields: "TargetAccountID", Multiple: true}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + Invalidate: c.OnInvalidateFollow, + }) +} + +func (c *Caches) initFollowIDs() { + // Calculate maximum cache size. + cap := calculateSliceCacheMax( + config.GetCacheFollowIDsMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + c.GTS.FollowIDs = &SliceCache[string]{Cache: simple.New[string, []string]( + 0, + cap, + )} +} + +func (c *Caches) initFollowRequest() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofFollowRequest(), // model in-mem size. + config.GetCacheFollowRequestMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(f1 *gtsmodel.FollowRequest) *gtsmodel.FollowRequest { + f2 := new(gtsmodel.FollowRequest) + *f2 = *f1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/relationship_follow_req.go. + f2.Account = nil + f2.TargetAccount = nil + + return f2 + } + + c.GTS.FollowRequest.Init(structr.Config[*gtsmodel.FollowRequest]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "URI"}, + {Fields: "AccountID,TargetAccountID"}, + {Fields: "AccountID", Multiple: true}, + {Fields: "TargetAccountID", Multiple: true}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + Invalidate: c.OnInvalidateFollowRequest, + }) +} + +func (c *Caches) initFollowRequestIDs() { + // Calculate maximum cache size. + cap := calculateSliceCacheMax( + config.GetCacheFollowRequestIDsMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + c.GTS.FollowRequestIDs = &SliceCache[string]{Cache: simple.New[string, []string]( + 0, + cap, + )} +} + +func (c *Caches) initInReplyToIDs() { + // Calculate maximum cache size. + cap := calculateSliceCacheMax( + config.GetCacheInReplyToIDsMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + c.GTS.InReplyToIDs = &SliceCache[string]{Cache: simple.New[string, []string]( + 0, + cap, + )} +} + +func (c *Caches) initInstance() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofInstance(), // model in-mem size. + config.GetCacheInstanceMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(i1 *gtsmodel.Instance) *gtsmodel.Instance { + i2 := new(gtsmodel.Instance) + *i2 = *i1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/instance.go. + i2.DomainBlock = nil + i2.ContactAccount = nil + + return i1 + } + + c.GTS.Instance.Init(structr.Config[*gtsmodel.Instance]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "Domain"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + +func (c *Caches) initList() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofList(), // model in-mem size. + config.GetCacheListMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(l1 *gtsmodel.List) *gtsmodel.List { + l2 := new(gtsmodel.List) + *l2 = *l1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/list.go. + l2.Account = nil + l2.ListEntries = nil + + return l2 + } + + c.GTS.List.Init(structr.Config[*gtsmodel.List]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + Invalidate: c.OnInvalidateList, + }) +} + +func (c *Caches) initListEntry() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofListEntry(), // model in-mem size. + config.GetCacheListEntryMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(l1 *gtsmodel.ListEntry) *gtsmodel.ListEntry { + l2 := new(gtsmodel.ListEntry) + *l2 = *l1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/list.go. + l2.Follow = nil + + return l2 + } + + c.GTS.ListEntry.Init(structr.Config[*gtsmodel.ListEntry]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "ListID", Multiple: true}, + {Fields: "FollowID", Multiple: true}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + +func (c *Caches) initMarker() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofMarker(), // model in-mem size. + config.GetCacheMarkerMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(m1 *gtsmodel.Marker) *gtsmodel.Marker { + m2 := new(gtsmodel.Marker) + *m2 = *m1 + return m2 + } + + c.GTS.Marker.Init(structr.Config[*gtsmodel.Marker]{ + Indices: []structr.IndexConfig{ + {Fields: "AccountID,Name"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + +func (c *Caches) initMedia() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofMedia(), // model in-mem size. + config.GetCacheMediaMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(m1 *gtsmodel.MediaAttachment) *gtsmodel.MediaAttachment { + m2 := new(gtsmodel.MediaAttachment) + *m2 = *m1 + return m2 + } + + c.GTS.Media.Init(structr.Config[*gtsmodel.MediaAttachment]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + Invalidate: c.OnInvalidateMedia, + }) +} + +func (c *Caches) initMention() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofMention(), // model in-mem size. + config.GetCacheMentionMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(m1 *gtsmodel.Mention) *gtsmodel.Mention { + m2 := new(gtsmodel.Mention) + *m2 = *m1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/mention.go. + m2.Status = nil + m2.OriginAccount = nil + m2.TargetAccount = nil + + return m2 + } + + c.GTS.Mention.Init(structr.Config[*gtsmodel.Mention]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + +func (c *Caches) initNotification() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofNotification(), // model in-mem size. + config.GetCacheNotificationMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(n1 *gtsmodel.Notification) *gtsmodel.Notification { + n2 := new(gtsmodel.Notification) + *n2 = *n1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/notification.go. + n2.Status = nil + n2.OriginAccount = nil + n2.TargetAccount = nil + + return n2 + } + + c.GTS.Notification.Init(structr.Config[*gtsmodel.Notification]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "NotificationType,TargetAccountID,OriginAccountID,StatusID", AllowZero: true}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + +func (c *Caches) initPoll() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofPoll(), // model in-mem size. + config.GetCachePollMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(p1 *gtsmodel.Poll) *gtsmodel.Poll { + p2 := new(gtsmodel.Poll) + *p2 = *p1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/poll.go. + p2.Status = nil + + // Don't include ephemeral fields + // which are only expected to be + // set on ONE poll instance. + p2.Closing = false + + return p2 + } + + c.GTS.Poll.Init(structr.Config[*gtsmodel.Poll]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "StatusID"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + Invalidate: c.OnInvalidatePoll, + }) +} + +func (c *Caches) initPollVote() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofPollVote(), // model in-mem size. + config.GetCachePollVoteMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(v1 *gtsmodel.PollVote) *gtsmodel.PollVote { + v2 := new(gtsmodel.PollVote) + *v2 = *v1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/poll.go. + v2.Account = nil + v2.Poll = nil + + return v2 + } + + c.GTS.PollVote.Init(structr.Config[*gtsmodel.PollVote]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "PollID", Multiple: true}, + {Fields: "PollID,AccountID"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + Invalidate: c.OnInvalidatePollVote, + }) +} + +func (c *Caches) initPollVoteIDs() { + // Calculate maximum cache size. + cap := calculateSliceCacheMax( + config.GetCachePollVoteIDsMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + c.GTS.PollVoteIDs = &SliceCache[string]{Cache: simple.New[string, []string]( + 0, + cap, + )} +} + +func (c *Caches) initReport() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofReport(), // model in-mem size. + config.GetCacheReportMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(r1 *gtsmodel.Report) *gtsmodel.Report { + r2 := new(gtsmodel.Report) + *r2 = *r1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/report.go. + r2.Account = nil + r2.TargetAccount = nil + r2.Statuses = nil + r2.Rules = nil + r2.ActionTakenByAccount = nil + + return r2 + } + + c.GTS.Report.Init(structr.Config[*gtsmodel.Report]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + +func (c *Caches) initStatus() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofStatus(), // model in-mem size. + config.GetCacheStatusMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(s1 *gtsmodel.Status) *gtsmodel.Status { + s2 := new(gtsmodel.Status) + *s2 = *s1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/status.go. + s2.Account = nil + s2.InReplyTo = nil + s2.InReplyToAccount = nil + s2.BoostOf = nil + s2.BoostOfAccount = nil + s2.Poll = nil + s2.Attachments = nil + s2.Tags = nil + s2.Mentions = nil + s2.Emojis = nil + s2.CreatedWithApplication = nil + + return s2 + } + + c.GTS.Status.Init(structr.Config[*gtsmodel.Status]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "URI"}, + {Fields: "URL"}, + {Fields: "PollID"}, + {Fields: "BoostOfID,AccountID"}, + {Fields: "ThreadID", Multiple: true}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + Invalidate: c.OnInvalidateStatus, + }) +} + +func (c *Caches) initStatusFave() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofStatusFave(), // model in-mem size. + config.GetCacheStatusFaveMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(f1 *gtsmodel.StatusFave) *gtsmodel.StatusFave { + f2 := new(gtsmodel.StatusFave) + *f2 = *f1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/statusfave.go. + f2.Account = nil + f2.TargetAccount = nil + f2.Status = nil + + return f2 + } + + c.GTS.StatusFave.Init(structr.Config[*gtsmodel.StatusFave]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "AccountID,StatusID"}, + {Fields: "StatusID", Multiple: true}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + Invalidate: c.OnInvalidateStatusFave, + }) +} + +func (c *Caches) initStatusFaveIDs() { + // Calculate maximum cache size. + cap := calculateSliceCacheMax( + config.GetCacheStatusFaveIDsMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + c.GTS.StatusFaveIDs = &SliceCache[string]{Cache: simple.New[string, []string]( + 0, + cap, + )} +} + +func (c *Caches) initTag() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofTag(), // model in-mem size. + config.GetCacheTagMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(m1 *gtsmodel.Tag) *gtsmodel.Tag { + m2 := new(gtsmodel.Tag) + *m2 = *m1 + return m2 + } + + c.GTS.Tag.Init(structr.Config[*gtsmodel.Tag]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "Name"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + +func (c *Caches) initThreadMute() { + cap := calculateResultCacheMax( + sizeOfThreadMute(), // model in-mem size. + config.GetCacheThreadMuteMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(t1 *gtsmodel.ThreadMute) *gtsmodel.ThreadMute { + t2 := new(gtsmodel.ThreadMute) + *t2 = *t1 + return t2 + } + + c.GTS.ThreadMute.Init(structr.Config[*gtsmodel.ThreadMute]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "ThreadID", Multiple: true}, + {Fields: "AccountID", Multiple: true}, + {Fields: "ThreadID,AccountID"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + +func (c *Caches) initTombstone() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofTombstone(), // model in-mem size. + config.GetCacheTombstoneMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(t1 *gtsmodel.Tombstone) *gtsmodel.Tombstone { + t2 := new(gtsmodel.Tombstone) + *t2 = *t1 + return t2 + } + + c.GTS.Tombstone.Init(structr.Config[*gtsmodel.Tombstone]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "URI"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + +func (c *Caches) initUser() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofUser(), // model in-mem size. + config.GetCacheUserMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(u1 *gtsmodel.User) *gtsmodel.User { + u2 := new(gtsmodel.User) + *u2 = *u1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/user.go. + u2.Account = nil + + return u2 + } + + c.GTS.User.Init(structr.Config[*gtsmodel.User]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "AccountID"}, + {Fields: "Email"}, + {Fields: "ConfirmationToken"}, + {Fields: "ExternalID"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + Invalidate: c.OnInvalidateUser, + }) +} + +func (c *Caches) initWebfinger() { + // Calculate maximum cache size. + cap := calculateCacheMax( + sizeofURIStr, sizeofURIStr, + config.GetCacheWebfingerMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + c.GTS.Webfinger = new(ttl.Cache[string, string]) + c.GTS.Webfinger.Init( + 0, + cap, + 24*time.Hour, + ) +} diff --git a/internal/cache/gts.go b/internal/cache/gts.go deleted file mode 100644 index 507947305..000000000 --- a/internal/cache/gts.go +++ /dev/null @@ -1,1119 +0,0 @@ -// GoToSocial -// Copyright (C) GoToSocial Authors admin@gotosocial.org -// SPDX-License-Identifier: AGPL-3.0-or-later -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see <http://www.gnu.org/licenses/>. - -package cache - -import ( - "time" - - "codeberg.org/gruf/go-cache/v3/result" - "codeberg.org/gruf/go-cache/v3/simple" - "codeberg.org/gruf/go-cache/v3/ttl" - "github.com/superseriousbusiness/gotosocial/internal/cache/domain" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/log" -) - -type GTSCaches struct { - account *result.Cache[*gtsmodel.Account] - accountNote *result.Cache[*gtsmodel.AccountNote] - application *result.Cache[*gtsmodel.Application] - block *result.Cache[*gtsmodel.Block] - blockIDs *SliceCache[string] - boostOfIDs *SliceCache[string] - domainAllow *domain.Cache - domainBlock *domain.Cache - emoji *result.Cache[*gtsmodel.Emoji] - emojiCategory *result.Cache[*gtsmodel.EmojiCategory] - follow *result.Cache[*gtsmodel.Follow] - followIDs *SliceCache[string] - followRequest *result.Cache[*gtsmodel.FollowRequest] - followRequestIDs *SliceCache[string] - instance *result.Cache[*gtsmodel.Instance] - inReplyToIDs *SliceCache[string] - list *result.Cache[*gtsmodel.List] - listEntry *result.Cache[*gtsmodel.ListEntry] - marker *result.Cache[*gtsmodel.Marker] - media *result.Cache[*gtsmodel.MediaAttachment] - mention *result.Cache[*gtsmodel.Mention] - notification *result.Cache[*gtsmodel.Notification] - poll *result.Cache[*gtsmodel.Poll] - pollVote *result.Cache[*gtsmodel.PollVote] - pollVoteIDs *SliceCache[string] - report *result.Cache[*gtsmodel.Report] - status *result.Cache[*gtsmodel.Status] - statusFave *result.Cache[*gtsmodel.StatusFave] - statusFaveIDs *SliceCache[string] - tag *result.Cache[*gtsmodel.Tag] - threadMute *result.Cache[*gtsmodel.ThreadMute] - tombstone *result.Cache[*gtsmodel.Tombstone] - user *result.Cache[*gtsmodel.User] - - // TODO: move out of GTS caches since unrelated to DB. - webfinger *ttl.Cache[string, string] // TTL=24hr, sweep=5min -} - -// Init will initialize all the gtsmodel caches in this collection. -// NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe. -func (c *GTSCaches) Init() { - c.initAccount() - c.initAccountNote() - c.initApplication() - c.initBlock() - c.initBlockIDs() - c.initBoostOfIDs() - c.initDomainAllow() - c.initDomainBlock() - c.initEmoji() - c.initEmojiCategory() - c.initFollow() - c.initFollowIDs() - c.initFollowRequest() - c.initFollowRequestIDs() - c.initInReplyToIDs() - c.initInstance() - c.initList() - c.initListEntry() - c.initMarker() - c.initMedia() - c.initMention() - c.initNotification() - c.initPoll() - c.initPollVote() - c.initPollVoteIDs() - c.initReport() - c.initStatus() - c.initStatusFave() - c.initTag() - c.initThreadMute() - c.initStatusFaveIDs() - c.initTombstone() - c.initUser() - c.initWebfinger() -} - -// Start will attempt to start all of the gtsmodel caches, or panic. -func (c *GTSCaches) Start() { - tryUntil("starting *gtsmodel.Webfinger cache", 5, func() bool { - return c.webfinger.Start(5 * time.Minute) - }) -} - -// Stop will attempt to stop all of the gtsmodel caches, or panic. -func (c *GTSCaches) Stop() { - tryUntil("stopping *gtsmodel.Webfinger cache", 5, c.webfinger.Stop) -} - -// Account provides access to the gtsmodel Account database cache. -func (c *GTSCaches) Account() *result.Cache[*gtsmodel.Account] { - return c.account -} - -// AccountNote provides access to the gtsmodel Note database cache. -func (c *GTSCaches) AccountNote() *result.Cache[*gtsmodel.AccountNote] { - return c.accountNote -} - -// Application provides access to the gtsmodel Application database cache. -func (c *GTSCaches) Application() *result.Cache[*gtsmodel.Application] { - return c.application -} - -// Block provides access to the gtsmodel Block (account) database cache. -func (c *GTSCaches) Block() *result.Cache[*gtsmodel.Block] { - return c.block -} - -// FollowIDs provides access to the block IDs database cache. -func (c *GTSCaches) BlockIDs() *SliceCache[string] { - return c.blockIDs -} - -// BoostOfIDs provides access to the boost of IDs list database cache. -func (c *GTSCaches) BoostOfIDs() *SliceCache[string] { - return c.boostOfIDs -} - -// DomainAllow provides access to the domain allow database cache. -func (c *GTSCaches) DomainAllow() *domain.Cache { - return c.domainAllow -} - -// DomainBlock provides access to the domain block database cache. -func (c *GTSCaches) DomainBlock() *domain.Cache { - return c.domainBlock -} - -// Emoji provides access to the gtsmodel Emoji database cache. -func (c *GTSCaches) Emoji() *result.Cache[*gtsmodel.Emoji] { - return c.emoji -} - -// EmojiCategory provides access to the gtsmodel EmojiCategory database cache. -func (c *GTSCaches) EmojiCategory() *result.Cache[*gtsmodel.EmojiCategory] { - return c.emojiCategory -} - -// Follow provides access to the gtsmodel Follow database cache. -func (c *GTSCaches) Follow() *result.Cache[*gtsmodel.Follow] { - return c.follow -} - -// FollowIDs provides access to the follower / following IDs database cache. -// THIS CACHE IS KEYED AS THE FOLLOWING {prefix}{accountID} WHERE PREFIX IS: -// - '>' for following IDs -// - 'l>' for local following IDs -// - '<' for follower IDs -// - 'l<' for local follower IDs -func (c *GTSCaches) FollowIDs() *SliceCache[string] { - return c.followIDs -} - -// FollowRequest provides access to the gtsmodel FollowRequest database cache. -func (c *GTSCaches) FollowRequest() *result.Cache[*gtsmodel.FollowRequest] { - return c.followRequest -} - -// FollowRequestIDs provides access to the follow requester / requesting IDs database -// cache. THIS CACHE IS KEYED AS THE FOLLOWING {prefix}{accountID} WHERE PREFIX IS: -// - '>' for following IDs -// - '<' for follower IDs -func (c *GTSCaches) FollowRequestIDs() *SliceCache[string] { - return c.followRequestIDs -} - -// Instance provides access to the gtsmodel Instance database cache. -func (c *GTSCaches) Instance() *result.Cache[*gtsmodel.Instance] { - return c.instance -} - -// InReplyToIDs provides access to the status in reply to IDs list database cache. -func (c *GTSCaches) InReplyToIDs() *SliceCache[string] { - return c.inReplyToIDs -} - -// List provides access to the gtsmodel List database cache. -func (c *GTSCaches) List() *result.Cache[*gtsmodel.List] { - return c.list -} - -// ListEntry provides access to the gtsmodel ListEntry database cache. -func (c *GTSCaches) ListEntry() *result.Cache[*gtsmodel.ListEntry] { - return c.listEntry -} - -// Marker provides access to the gtsmodel Marker database cache. -func (c *GTSCaches) Marker() *result.Cache[*gtsmodel.Marker] { - return c.marker -} - -// Media provides access to the gtsmodel Media database cache. -func (c *GTSCaches) Media() *result.Cache[*gtsmodel.MediaAttachment] { - return c.media -} - -// Mention provides access to the gtsmodel Mention database cache. -func (c *GTSCaches) Mention() *result.Cache[*gtsmodel.Mention] { - return c.mention -} - -// Notification provides access to the gtsmodel Notification database cache. -func (c *GTSCaches) Notification() *result.Cache[*gtsmodel.Notification] { - return c.notification -} - -// Poll provides access to the gtsmodel Poll database cache. -func (c *GTSCaches) Poll() *result.Cache[*gtsmodel.Poll] { - return c.poll -} - -// PollVote provides access to the gtsmodel PollVote database cache. -func (c *GTSCaches) PollVote() *result.Cache[*gtsmodel.PollVote] { - return c.pollVote -} - -// PollVoteIDs provides access to the poll vote IDs list database cache. -func (c *GTSCaches) PollVoteIDs() *SliceCache[string] { - return c.pollVoteIDs -} - -// Report provides access to the gtsmodel Report database cache. -func (c *GTSCaches) Report() *result.Cache[*gtsmodel.Report] { - return c.report -} - -// Status provides access to the gtsmodel Status database cache. -func (c *GTSCaches) Status() *result.Cache[*gtsmodel.Status] { - return c.status -} - -// StatusFave provides access to the gtsmodel StatusFave database cache. -func (c *GTSCaches) StatusFave() *result.Cache[*gtsmodel.StatusFave] { - return c.statusFave -} - -// StatusFaveIDs provides access to the status fave IDs list database cache. -func (c *GTSCaches) StatusFaveIDs() *SliceCache[string] { - return c.statusFaveIDs -} - -// Tag provides access to the gtsmodel Tag database cache. -func (c *GTSCaches) Tag() *result.Cache[*gtsmodel.Tag] { - return c.tag -} - -// Tombstone provides access to the gtsmodel Tombstone database cache. -func (c *GTSCaches) Tombstone() *result.Cache[*gtsmodel.Tombstone] { - return c.tombstone -} - -// ThreadMute provides access to the gtsmodel ThreadMute database cache. -func (c *GTSCaches) ThreadMute() *result.Cache[*gtsmodel.ThreadMute] { - return c.threadMute -} - -// User provides access to the gtsmodel User database cache. -func (c *GTSCaches) User() *result.Cache[*gtsmodel.User] { - return c.user -} - -// Webfinger provides access to the webfinger URL cache. -func (c *GTSCaches) Webfinger() *ttl.Cache[string, string] { - return c.webfinger -} - -func (c *GTSCaches) initAccount() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofAccount(), // model in-mem size. - config.GetCacheAccountMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - copyF := func(a1 *gtsmodel.Account) *gtsmodel.Account { - a2 := new(gtsmodel.Account) - *a2 = *a1 - - // Don't include ptr fields that - // will be populated separately. - // See internal/db/bundb/account.go. - a2.AvatarMediaAttachment = nil - a2.HeaderMediaAttachment = nil - a2.Emojis = nil - - return a2 - } - - c.account = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "URI"}, - {Name: "URL"}, - {Name: "Username.Domain", AllowZero: true /* domain can be zero i.e. "" */}, - {Name: "PublicKeyURI"}, - {Name: "InboxURI"}, - {Name: "OutboxURI"}, - {Name: "FollowersURI"}, - {Name: "FollowingURI"}, - }, copyF, cap) - - c.account.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initAccountNote() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofAccountNote(), // model in-mem size. - config.GetCacheAccountNoteMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - copyF := func(n1 *gtsmodel.AccountNote) *gtsmodel.AccountNote { - n2 := new(gtsmodel.AccountNote) - *n2 = *n1 - - // Don't include ptr fields that - // will be populated separately. - // See internal/db/bundb/relationship_note.go. - n2.Account = nil - n2.TargetAccount = nil - - return n2 - } - - c.accountNote = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "AccountID.TargetAccountID"}, - }, copyF, cap) - - c.accountNote.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initApplication() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofApplication(), // model in-mem size. - config.GetCacheApplicationMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - c.application = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "ClientID"}, - }, func(a1 *gtsmodel.Application) *gtsmodel.Application { - a2 := new(gtsmodel.Application) - *a2 = *a1 - return a2 - }, cap) - - c.application.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initBlock() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofBlock(), // model in-mem size. - config.GetCacheBlockMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - copyF := func(b1 *gtsmodel.Block) *gtsmodel.Block { - b2 := new(gtsmodel.Block) - *b2 = *b1 - - // Don't include ptr fields that - // will be populated separately. - // See internal/db/bundb/relationship_block.go. - b2.Account = nil - b2.TargetAccount = nil - - return b2 - } - - c.block = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "URI"}, - {Name: "AccountID.TargetAccountID"}, - {Name: "AccountID", Multi: true}, - {Name: "TargetAccountID", Multi: true}, - }, copyF, cap) - - c.block.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initBlockIDs() { - // Calculate maximum cache size. - cap := calculateSliceCacheMax( - config.GetCacheBlockIDsMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - c.blockIDs = &SliceCache[string]{Cache: simple.New[string, []string]( - 0, - cap, - )} -} - -func (c *GTSCaches) initBoostOfIDs() { - // Calculate maximum cache size. - cap := calculateSliceCacheMax( - config.GetCacheBoostOfIDsMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - c.boostOfIDs = &SliceCache[string]{Cache: simple.New[string, []string]( - 0, - cap, - )} -} - -func (c *GTSCaches) initDomainAllow() { - c.domainAllow = new(domain.Cache) -} - -func (c *GTSCaches) initDomainBlock() { - c.domainBlock = new(domain.Cache) -} - -func (c *GTSCaches) initEmoji() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofEmoji(), // model in-mem size. - config.GetCacheEmojiMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - copyF := func(e1 *gtsmodel.Emoji) *gtsmodel.Emoji { - e2 := new(gtsmodel.Emoji) - *e2 = *e1 - - // Don't include ptr fields that - // will be populated separately. - // See internal/db/bundb/emoji.go. - e2.Category = nil - - return e2 - } - - c.emoji = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "URI"}, - {Name: "Shortcode.Domain", AllowZero: true /* domain can be zero i.e. "" */}, - {Name: "ImageStaticURL"}, - {Name: "CategoryID", Multi: true}, - }, copyF, cap) - - c.emoji.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initEmojiCategory() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofEmojiCategory(), // model in-mem size. - config.GetCacheEmojiCategoryMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - c.emojiCategory = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "Name"}, - }, func(c1 *gtsmodel.EmojiCategory) *gtsmodel.EmojiCategory { - c2 := new(gtsmodel.EmojiCategory) - *c2 = *c1 - return c2 - }, cap) - - c.emojiCategory.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initFollow() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofFollow(), // model in-mem size. - config.GetCacheFollowMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - copyF := func(f1 *gtsmodel.Follow) *gtsmodel.Follow { - f2 := new(gtsmodel.Follow) - *f2 = *f1 - - // Don't include ptr fields that - // will be populated separately. - // See internal/db/bundb/relationship_follow.go. - f2.Account = nil - f2.TargetAccount = nil - - return f2 - } - - c.follow = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "URI"}, - {Name: "AccountID.TargetAccountID"}, - {Name: "AccountID", Multi: true}, - {Name: "TargetAccountID", Multi: true}, - }, copyF, cap) - - c.follow.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initFollowIDs() { - // Calculate maximum cache size. - cap := calculateSliceCacheMax( - config.GetCacheFollowIDsMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - c.followIDs = &SliceCache[string]{Cache: simple.New[string, []string]( - 0, - cap, - )} -} - -func (c *GTSCaches) initFollowRequest() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofFollowRequest(), // model in-mem size. - config.GetCacheFollowRequestMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - copyF := func(f1 *gtsmodel.FollowRequest) *gtsmodel.FollowRequest { - f2 := new(gtsmodel.FollowRequest) - *f2 = *f1 - - // Don't include ptr fields that - // will be populated separately. - // See internal/db/bundb/relationship_follow_req.go. - f2.Account = nil - f2.TargetAccount = nil - - return f2 - } - - c.followRequest = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "URI"}, - {Name: "AccountID.TargetAccountID"}, - {Name: "AccountID", Multi: true}, - {Name: "TargetAccountID", Multi: true}, - }, copyF, cap) - - c.followRequest.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initFollowRequestIDs() { - // Calculate maximum cache size. - cap := calculateSliceCacheMax( - config.GetCacheFollowRequestIDsMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - c.followRequestIDs = &SliceCache[string]{Cache: simple.New[string, []string]( - 0, - cap, - )} -} - -func (c *GTSCaches) initInReplyToIDs() { - // Calculate maximum cache size. - cap := calculateSliceCacheMax( - config.GetCacheInReplyToIDsMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - c.inReplyToIDs = &SliceCache[string]{Cache: simple.New[string, []string]( - 0, - cap, - )} -} - -func (c *GTSCaches) initInstance() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofInstance(), // model in-mem size. - config.GetCacheInstanceMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - copyF := func(i1 *gtsmodel.Instance) *gtsmodel.Instance { - i2 := new(gtsmodel.Instance) - *i2 = *i1 - - // Don't include ptr fields that - // will be populated separately. - // See internal/db/bundb/instance.go. - i2.DomainBlock = nil - i2.ContactAccount = nil - - return i1 - } - - c.instance = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "Domain"}, - }, copyF, cap) - - c.instance.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initList() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofList(), // model in-mem size. - config.GetCacheListMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - copyF := func(l1 *gtsmodel.List) *gtsmodel.List { - l2 := new(gtsmodel.List) - *l2 = *l1 - - // Don't include ptr fields that - // will be populated separately. - // See internal/db/bundb/list.go. - l2.Account = nil - l2.ListEntries = nil - - return l2 - } - - c.list = result.New([]result.Lookup{ - {Name: "ID"}, - }, copyF, cap) - - c.list.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initListEntry() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofListEntry(), // model in-mem size. - config.GetCacheListEntryMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - copyF := func(l1 *gtsmodel.ListEntry) *gtsmodel.ListEntry { - l2 := new(gtsmodel.ListEntry) - *l2 = *l1 - - // Don't include ptr fields that - // will be populated separately. - // See internal/db/bundb/list.go. - l2.Follow = nil - - return l2 - } - - c.listEntry = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "ListID", Multi: true}, - {Name: "FollowID", Multi: true}, - }, copyF, cap) - - c.listEntry.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initMarker() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofMarker(), // model in-mem size. - config.GetCacheMarkerMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - c.marker = result.New([]result.Lookup{ - {Name: "AccountID.Name"}, - }, func(m1 *gtsmodel.Marker) *gtsmodel.Marker { - m2 := new(gtsmodel.Marker) - *m2 = *m1 - return m2 - }, cap) - - c.marker.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initMedia() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofMedia(), // model in-mem size. - config.GetCacheMediaMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - c.media = result.New([]result.Lookup{ - {Name: "ID"}, - }, func(m1 *gtsmodel.MediaAttachment) *gtsmodel.MediaAttachment { - m2 := new(gtsmodel.MediaAttachment) - *m2 = *m1 - return m2 - }, cap) - - c.media.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initMention() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofMention(), // model in-mem size. - config.GetCacheMentionMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - copyF := func(m1 *gtsmodel.Mention) *gtsmodel.Mention { - m2 := new(gtsmodel.Mention) - *m2 = *m1 - - // Don't include ptr fields that - // will be populated separately. - // See internal/db/bundb/mention.go. - m2.Status = nil - m2.OriginAccount = nil - m2.TargetAccount = nil - - return m2 - } - - c.mention = result.New([]result.Lookup{ - {Name: "ID"}, - }, copyF, cap) - - c.mention.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initNotification() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofNotification(), // model in-mem size. - config.GetCacheNotificationMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - copyF := func(n1 *gtsmodel.Notification) *gtsmodel.Notification { - n2 := new(gtsmodel.Notification) - *n2 = *n1 - - // Don't include ptr fields that - // will be populated separately. - // See internal/db/bundb/notification.go. - n2.Status = nil - n2.OriginAccount = nil - n2.TargetAccount = nil - - return n2 - } - - c.notification = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "NotificationType.TargetAccountID.OriginAccountID.StatusID"}, - }, copyF, cap) - - c.notification.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initPoll() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofPoll(), // model in-mem size. - config.GetCachePollMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - copyF := func(p1 *gtsmodel.Poll) *gtsmodel.Poll { - p2 := new(gtsmodel.Poll) - *p2 = *p1 - - // Don't include ptr fields that - // will be populated separately. - // See internal/db/bundb/poll.go. - p2.Status = nil - - // Don't include ephemeral fields - // which are only expected to be - // set on ONE poll instance. - p2.Closing = false - - return p2 - } - - c.poll = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "StatusID"}, - }, copyF, cap) - - c.poll.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initPollVote() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofPollVote(), // model in-mem size. - config.GetCachePollVoteMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - copyF := func(v1 *gtsmodel.PollVote) *gtsmodel.PollVote { - v2 := new(gtsmodel.PollVote) - *v2 = *v1 - - // Don't include ptr fields that - // will be populated separately. - // See internal/db/bundb/poll.go. - v2.Account = nil - v2.Poll = nil - - return v2 - } - - c.pollVote = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "PollID.AccountID"}, - {Name: "PollID", Multi: true}, - }, copyF, cap) - - c.pollVote.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initPollVoteIDs() { - // Calculate maximum cache size. - cap := calculateSliceCacheMax( - config.GetCachePollVoteIDsMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - c.pollVoteIDs = &SliceCache[string]{Cache: simple.New[string, []string]( - 0, - cap, - )} -} - -func (c *GTSCaches) initReport() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofReport(), // model in-mem size. - config.GetCacheReportMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - copyF := func(r1 *gtsmodel.Report) *gtsmodel.Report { - r2 := new(gtsmodel.Report) - *r2 = *r1 - - // Don't include ptr fields that - // will be populated separately. - // See internal/db/bundb/report.go. - r2.Account = nil - r2.TargetAccount = nil - r2.Statuses = nil - r2.Rules = nil - r2.ActionTakenByAccount = nil - - return r2 - } - - c.report = result.New([]result.Lookup{ - {Name: "ID"}, - }, copyF, cap) - - c.report.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initStatus() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofStatus(), // model in-mem size. - config.GetCacheStatusMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - copyF := func(s1 *gtsmodel.Status) *gtsmodel.Status { - s2 := new(gtsmodel.Status) - *s2 = *s1 - - // Don't include ptr fields that - // will be populated separately. - // See internal/db/bundb/status.go. - s2.Account = nil - s2.InReplyTo = nil - s2.InReplyToAccount = nil - s2.BoostOf = nil - s2.BoostOfAccount = nil - s2.Poll = nil - s2.Attachments = nil - s2.Tags = nil - s2.Mentions = nil - s2.Emojis = nil - s2.CreatedWithApplication = nil - - return s2 - } - - c.status = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "URI"}, - {Name: "URL"}, - {Name: "PollID"}, - {Name: "BoostOfID.AccountID"}, - {Name: "ThreadID", Multi: true}, - }, copyF, cap) - - c.status.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initStatusFave() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofStatusFave(), // model in-mem size. - config.GetCacheStatusFaveMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - copyF := func(f1 *gtsmodel.StatusFave) *gtsmodel.StatusFave { - f2 := new(gtsmodel.StatusFave) - *f2 = *f1 - - // Don't include ptr fields that - // will be populated separately. - // See internal/db/bundb/statusfave.go. - f2.Account = nil - f2.TargetAccount = nil - f2.Status = nil - - return f2 - } - - c.statusFave = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "AccountID.StatusID"}, - {Name: "StatusID", Multi: true}, - }, copyF, cap) - - c.statusFave.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initStatusFaveIDs() { - // Calculate maximum cache size. - cap := calculateSliceCacheMax( - config.GetCacheStatusFaveIDsMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - c.statusFaveIDs = &SliceCache[string]{Cache: simple.New[string, []string]( - 0, - cap, - )} -} - -func (c *GTSCaches) initTag() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofTag(), // model in-mem size. - config.GetCacheTagMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - c.tag = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "Name"}, - }, func(m1 *gtsmodel.Tag) *gtsmodel.Tag { - m2 := new(gtsmodel.Tag) - *m2 = *m1 - return m2 - }, cap) - - c.tag.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initThreadMute() { - cap := calculateResultCacheMax( - sizeOfThreadMute(), // model in-mem size. - config.GetCacheThreadMuteMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - c.threadMute = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "ThreadID", Multi: true}, - {Name: "AccountID", Multi: true}, - {Name: "ThreadID.AccountID"}, - }, func(t1 *gtsmodel.ThreadMute) *gtsmodel.ThreadMute { - t2 := new(gtsmodel.ThreadMute) - *t2 = *t1 - return t2 - }, cap) - - c.threadMute.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initTombstone() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofTombstone(), // model in-mem size. - config.GetCacheTombstoneMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - c.tombstone = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "URI"}, - }, func(t1 *gtsmodel.Tombstone) *gtsmodel.Tombstone { - t2 := new(gtsmodel.Tombstone) - *t2 = *t1 - return t2 - }, cap) - - c.tombstone.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initUser() { - // Calculate maximum cache size. - cap := calculateResultCacheMax( - sizeofUser(), // model in-mem size. - config.GetCacheUserMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - copyF := func(u1 *gtsmodel.User) *gtsmodel.User { - u2 := new(gtsmodel.User) - *u2 = *u1 - - // Don't include ptr fields that - // will be populated separately. - // See internal/db/bundb/user.go. - u2.Account = nil - - return u2 - } - - c.user = result.New([]result.Lookup{ - {Name: "ID"}, - {Name: "AccountID"}, - {Name: "Email"}, - {Name: "ConfirmationToken"}, - {Name: "ExternalID"}, - }, copyF, cap) - - c.user.IgnoreErrors(ignoreErrors) -} - -func (c *GTSCaches) initWebfinger() { - // Calculate maximum cache size. - cap := calculateCacheMax( - sizeofURIStr, sizeofURIStr, - config.GetCacheWebfingerMemRatio(), - ) - - log.Infof(nil, "cache size = %d", cap) - - c.webfinger = ttl.New[string, string]( - 0, - cap, - 24*time.Hour, - ) -} diff --git a/internal/cache/invalidate.go b/internal/cache/invalidate.go new file mode 100644 index 000000000..d85c503da --- /dev/null +++ b/internal/cache/invalidate.go @@ -0,0 +1,192 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package cache + +import ( + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// Below are cache invalidation hooks between other caches, +// as an invalidation indicates a database INSERT / UPDATE / DELETE. +// NOTE THEY ARE ONLY CALLED WHEN THE ITEM IS IN THE CACHE, SO FOR +// HOOKS TO BE CALLED ON DELETE YOU MUST FIRST POPULATE IT IN THE CACHE. + +func (c *Caches) OnInvalidateAccount(account *gtsmodel.Account) { + // Invalidate account ID cached visibility. + c.Visibility.Invalidate("ItemID", account.ID) + c.Visibility.Invalidate("RequesterID", account.ID) + + // Invalidate this account's + // following / follower lists. + // (see FollowIDs() comment for details). + c.GTS.FollowIDs.InvalidateAll( + ">"+account.ID, + "l>"+account.ID, + "<"+account.ID, + "l<"+account.ID, + ) + + // Invalidate this account's + // follow requesting / request lists. + // (see FollowRequestIDs() comment for details). + c.GTS.FollowRequestIDs.InvalidateAll( + ">"+account.ID, + "<"+account.ID, + ) + + // Invalidate this account's block lists. + c.GTS.BlockIDs.Invalidate(account.ID) +} + +func (c *Caches) OnInvalidateBlock(block *gtsmodel.Block) { + // Invalidate block origin account ID cached visibility. + c.Visibility.Invalidate("ItemID", block.AccountID) + c.Visibility.Invalidate("RequesterID", block.AccountID) + + // Invalidate block target account ID cached visibility. + c.Visibility.Invalidate("ItemID", block.TargetAccountID) + c.Visibility.Invalidate("RequesterID", block.TargetAccountID) + + // Invalidate source account's block lists. + c.GTS.BlockIDs.Invalidate(block.AccountID) +} + +func (c *Caches) OnInvalidateEmojiCategory(category *gtsmodel.EmojiCategory) { + // Invalidate any emoji in this category. + c.GTS.Emoji.Invalidate("CategoryID", category.ID) +} + +func (c *Caches) OnInvalidateFollow(follow *gtsmodel.Follow) { + // Invalidate follow request with this same ID. + c.GTS.FollowRequest.Invalidate("ID", follow.ID) + + // Invalidate any related list entries. + c.GTS.ListEntry.Invalidate("FollowID", follow.ID) + + // Invalidate follow origin account ID cached visibility. + c.Visibility.Invalidate("ItemID", follow.AccountID) + c.Visibility.Invalidate("RequesterID", follow.AccountID) + + // Invalidate follow target account ID cached visibility. + c.Visibility.Invalidate("ItemID", follow.TargetAccountID) + c.Visibility.Invalidate("RequesterID", follow.TargetAccountID) + + // Invalidate source account's following + // lists, and destination's follwer lists. + // (see FollowIDs() comment for details). + c.GTS.FollowIDs.InvalidateAll( + ">"+follow.AccountID, + "l>"+follow.AccountID, + "<"+follow.AccountID, + "l<"+follow.AccountID, + "<"+follow.TargetAccountID, + "l<"+follow.TargetAccountID, + ">"+follow.TargetAccountID, + "l>"+follow.TargetAccountID, + ) +} + +func (c *Caches) OnInvalidateFollowRequest(followReq *gtsmodel.FollowRequest) { + // Invalidate follow with this same ID. + c.GTS.Follow.Invalidate("ID", followReq.ID) + + // Invalidate source account's followreq + // lists, and destinations follow req lists. + // (see FollowRequestIDs() comment for details). + c.GTS.FollowRequestIDs.InvalidateAll( + ">"+followReq.AccountID, + "<"+followReq.AccountID, + ">"+followReq.TargetAccountID, + "<"+followReq.TargetAccountID, + ) +} + +func (c *Caches) OnInvalidateList(list *gtsmodel.List) { + // Invalidate all cached entries of this list. + c.GTS.ListEntry.Invalidate("ListID", list.ID) +} + +func (c *Caches) OnInvalidateMedia(media *gtsmodel.MediaAttachment) { + if (media.Avatar != nil && *media.Avatar) || + (media.Header != nil && *media.Header) { + // Invalidate cache of attaching account. + c.GTS.Account.Invalidate("ID", media.AccountID) + } + + if media.StatusID != "" { + // Invalidate cache of attaching status. + c.GTS.Status.Invalidate("ID", media.StatusID) + } +} + +func (c *Caches) OnInvalidatePoll(poll *gtsmodel.Poll) { + // Invalidate all cached votes of this poll. + c.GTS.PollVote.Invalidate("PollID", poll.ID) + + // Invalidate cache of poll vote IDs. + c.GTS.PollVoteIDs.Invalidate(poll.ID) +} + +func (c *Caches) OnInvalidatePollVote(vote *gtsmodel.PollVote) { + // Invalidate cached poll (contains no. votes). + c.GTS.Poll.Invalidate("ID", vote.PollID) + + // Invalidate cache of poll vote IDs. + c.GTS.PollVoteIDs.Invalidate(vote.PollID) +} + +func (c *Caches) OnInvalidateStatus(status *gtsmodel.Status) { + // Invalidate status ID cached visibility. + c.Visibility.Invalidate("ItemID", status.ID) + + for _, id := range status.AttachmentIDs { + // Invalidate each media by the IDs we're aware of. + // This must be done as the status table is aware of + // the media IDs in use before the media table is + // aware of the status ID they are linked to. + // + // c.GTS.Media().Invalidate("StatusID") will not work. + c.GTS.Media.Invalidate("ID", id) + } + + if status.BoostOfID != "" { + // Invalidate boost ID list of the original status. + c.GTS.BoostOfIDs.Invalidate(status.BoostOfID) + } + + if status.InReplyToID != "" { + // Invalidate in reply to ID list of original status. + c.GTS.InReplyToIDs.Invalidate(status.InReplyToID) + } + + if status.PollID != "" { + // Invalidate cache of attached poll ID. + c.GTS.Poll.Invalidate("ID", status.PollID) + } +} + +func (c *Caches) OnInvalidateStatusFave(fave *gtsmodel.StatusFave) { + // Invalidate status fave ID list for this status. + c.GTS.StatusFaveIDs.Invalidate(fave.StatusID) +} + +func (c *Caches) OnInvalidateUser(user *gtsmodel.User) { + // Invalidate local account ID cached visibility. + c.Visibility.Invalidate("ItemID", user.AccountID) + c.Visibility.Invalidate("RequesterID", user.AccountID) +} diff --git a/internal/cache/visibility.go b/internal/cache/visibility.go index 8c534206b..878efcdb8 100644 --- a/internal/cache/visibility.go +++ b/internal/cache/visibility.go @@ -18,18 +18,16 @@ package cache import ( - "codeberg.org/gruf/go-cache/v3/result" + "codeberg.org/gruf/go-structr" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/log" ) type VisibilityCache struct { - *result.Cache[*CachedVisibility] + structr.Cache[*CachedVisibility] } -// Init will initialize the visibility cache in this collection. -// NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe. -func (c *VisibilityCache) Init() { +func (c *Caches) initVisibility() { // Calculate maximum cache size. cap := calculateResultCacheMax( sizeofVisibility(), // model in-mem size. @@ -38,25 +36,22 @@ func (c *VisibilityCache) Init() { log.Infof(nil, "Visibility cache size = %d", cap) - c.Cache = result.New([]result.Lookup{ - {Name: "ItemID", Multi: true}, - {Name: "RequesterID", Multi: true}, - {Name: "Type.RequesterID.ItemID"}, - }, func(v1 *CachedVisibility) *CachedVisibility { + copyF := func(v1 *CachedVisibility) *CachedVisibility { v2 := new(CachedVisibility) *v2 = *v1 return v2 - }, cap) + } - c.Cache.IgnoreErrors(ignoreErrors) -} - -// Start will attempt to start the visibility cache, or panic. -func (c *VisibilityCache) Start() { -} - -// Stop will attempt to stop the visibility cache, or panic. -func (c *VisibilityCache) Stop() { + c.Visibility.Init(structr.Config[*CachedVisibility]{ + Indices: []structr.IndexConfig{ + {Fields: "ItemID", Multiple: true}, + {Fields: "RequesterID", Multiple: true}, + {Fields: "Type,RequesterID,ItemID"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) } // VisibilityType represents a visibility lookup type. diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index fdee8cb76..cdb949efa 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -116,7 +116,7 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str return a.getAccount( ctx, - "Username.Domain", + "Username,Domain", func(account *gtsmodel.Account) error { q := a.db.NewSelect(). Model(account) @@ -224,7 +224,7 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, error) { // Fetch account from database cache with loader callback - account, err := a.state.Caches.GTS.Account().Load(lookup, func() (*gtsmodel.Account, error) { + account, err := a.state.Caches.GTS.Account.LoadOne(lookup, func() (*gtsmodel.Account, error) { var account gtsmodel.Account // Not cached! Perform database query @@ -325,7 +325,7 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou } func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) error { - return a.state.Caches.GTS.Account().Store(account, func() error { + return a.state.Caches.GTS.Account.Store(account, func() error { // It is safe to run this database transaction within cache.Store // as the cache does not attempt a mutex lock until AFTER hook. // @@ -354,7 +354,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account columns = append(columns, "updated_at") } - return a.state.Caches.GTS.Account().Store(account, func() error { + return a.state.Caches.GTS.Account.Store(account, func() error { // It is safe to run this database transaction within cache.Store // as the cache does not attempt a mutex lock until AFTER hook. // @@ -393,7 +393,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account } func (a *accountDB) DeleteAccount(ctx context.Context, id string) error { - defer a.state.Caches.GTS.Account().Invalidate("ID", id) + defer a.state.Caches.GTS.Account.Invalidate("ID", id) // Load account into cache before attempting a delete, // as we need it cached in order to trigger the invalidate @@ -635,6 +635,10 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li return nil, err } + if len(statusIDs) == 0 { + return nil, db.ErrNoEntries + } + // If we're paging up, we still want statuses // to be sorted by ID desc, so reverse ids slice. // https://zchee.github.io/golang-wiki/SliceTricks/#reversing @@ -644,7 +648,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li } } - return a.statusesFromIDs(ctx, statusIDs) + return a.state.DB.GetStatusesByIDs(ctx, statusIDs) } func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, error) { @@ -662,7 +666,11 @@ func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID stri return nil, err } - return a.statusesFromIDs(ctx, statusIDs) + if len(statusIDs) == 0 { + return nil, db.ErrNoEntries + } + + return a.state.DB.GetStatusesByIDs(ctx, statusIDs) } func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) { @@ -710,29 +718,9 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, return nil, err } - return a.statusesFromIDs(ctx, statusIDs) -} - -func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, error) { - // Catch case of no statuses early if len(statusIDs) == 0 { return nil, db.ErrNoEntries } - // Allocate return slice (will be at most len statusIDS) - statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) - - for _, id := range statusIDs { - // Fetch from status from database by ID - status, err := a.state.DB.GetStatusByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting status %q: %v", id, err) - continue - } - - // Append to return slice - statuses = append(statuses, status) - } - - return statuses, nil + return a.state.DB.GetStatusesByIDs(ctx, statusIDs) } diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go index f7328e275..2e17a0e94 100644 --- a/internal/db/bundb/application.go +++ b/internal/db/bundb/application.go @@ -53,7 +53,7 @@ func (a *applicationDB) GetApplicationByClientID(ctx context.Context, clientID s } func (a *applicationDB) getApplication(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Application) error, keyParts ...any) (*gtsmodel.Application, error) { - return a.state.Caches.GTS.Application().Load(lookup, func() (*gtsmodel.Application, error) { + return a.state.Caches.GTS.Application.LoadOne(lookup, func() (*gtsmodel.Application, error) { var app gtsmodel.Application // Not cached! Perform database query. @@ -66,7 +66,7 @@ func (a *applicationDB) getApplication(ctx context.Context, lookup string, dbQue } func (a *applicationDB) PutApplication(ctx context.Context, app *gtsmodel.Application) error { - return a.state.Caches.GTS.Application().Store(app, func() error { + return a.state.Caches.GTS.Application.Store(app, func() error { _, err := a.db.NewInsert().Model(app).Exec(ctx) return err }) @@ -91,7 +91,7 @@ func (a *applicationDB) DeleteApplicationByClientID(ctx context.Context, clientI // // Clear application from the cache. - a.state.Caches.GTS.Application().Invalidate("ClientID", clientID) + a.state.Caches.GTS.Application.Invalidate("ClientID", clientID) return nil } diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index d9415eff4..048474782 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -258,7 +258,7 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { state: state, }, Tag: &tagDB{ - conn: db, + db: db, state: state, }, Thread: &threadDB{ diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index dd626bc0a..2398e52c2 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -51,7 +51,7 @@ func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.Domain } // Clear the domain allow cache (for later reload) - d.state.Caches.GTS.DomainAllow().Clear() + d.state.Caches.GTS.DomainAllow.Clear() return nil } @@ -126,7 +126,7 @@ func (d *domainDB) DeleteDomainAllow(ctx context.Context, domain string) error { } // Clear the domain allow cache (for later reload) - d.state.Caches.GTS.DomainAllow().Clear() + d.state.Caches.GTS.DomainAllow.Clear() return nil } @@ -147,7 +147,7 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain } // Clear the domain block cache (for later reload) - d.state.Caches.GTS.DomainBlock().Clear() + d.state.Caches.GTS.DomainBlock.Clear() return nil } @@ -222,7 +222,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) error { } // Clear the domain block cache (for later reload) - d.state.Caches.GTS.DomainBlock().Clear() + d.state.Caches.GTS.DomainBlock.Clear() return nil } @@ -241,7 +241,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er } // Check the cache for an explicit domain allow (hydrating the cache with callback if necessary). - explicitAllow, err := d.state.Caches.GTS.DomainAllow().Matches(domain, func() ([]string, error) { + explicitAllow, err := d.state.Caches.GTS.DomainAllow.Matches(domain, func() ([]string, error) { var domains []string // Scan list of all explicitly allowed domains from DB @@ -259,7 +259,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er } // Check the cache for a domain block (hydrating the cache with callback if necessary) - explicitBlock, err := d.state.Caches.GTS.DomainBlock().Matches(domain, func() ([]string, error) { + explicitBlock, err := d.state.Caches.GTS.DomainBlock.Matches(domain, func() ([]string, error) { var domains []string // Scan list of all blocked domains from DB diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index 34a08b694..31092d0d2 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -21,6 +21,7 @@ import ( "context" "database/sql" "errors" + "slices" "strings" "time" @@ -30,6 +31,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" "github.com/uptrace/bun/dialect" ) @@ -40,7 +42,7 @@ type emojiDB struct { } func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) error { - return e.state.Caches.GTS.Emoji().Store(emoji, func() error { + return e.state.Caches.GTS.Emoji.Store(emoji, func() error { _, err := e.db.NewInsert().Model(emoji).Exec(ctx) return err }) @@ -54,7 +56,7 @@ func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, column } // Update the emoji model in the database. - return e.state.Caches.GTS.Emoji().Store(emoji, func() error { + return e.state.Caches.GTS.Emoji.Store(emoji, func() error { _, err := e.db. NewUpdate(). Model(emoji). @@ -74,21 +76,21 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error { defer func() { // Invalidate cached emoji. e.state.Caches.GTS. - Emoji(). + Emoji. Invalidate("ID", id) - for _, id := range accountIDs { + for _, accountID := range accountIDs { // Invalidate cached account. e.state.Caches.GTS. - Account(). - Invalidate("ID", id) + Account. + Invalidate("ID", accountID) } - for _, id := range statusIDs { + for _, statusID := range statusIDs { // Invalidate cached account. e.state.Caches.GTS. - Status(). - Invalidate("ID", id) + Status. + Invalidate("ID", statusID) } }() @@ -129,26 +131,28 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error { return err } - for _, id := range statusIDs { + for _, statusID := range statusIDs { var emojiIDs []string // Select statuses with ID. if _, err := tx.NewSelect(). Table("statuses"). Column("emojis"). - Where("? = ?", bun.Ident("id"), id). + Where("? = ?", bun.Ident("id"), statusID). Exec(ctx); err != nil && err != sql.ErrNoRows { return err } - // Drop ID from account emojis. - emojiIDs = dropID(emojiIDs, id) + // Delete all instances of this emoji ID from status emojis. + emojiIDs = slices.DeleteFunc(emojiIDs, func(emojiID string) bool { + return emojiID == id + }) // Update status emoji IDs. if _, err := tx.NewUpdate(). Table("statuses"). - Where("? = ?", bun.Ident("id"), id). + Where("? = ?", bun.Ident("id"), statusID). Set("emojis = ?", emojiIDs). Exec(ctx); err != nil && err != sql.ErrNoRows { @@ -156,26 +160,28 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error { } } - for _, id := range accountIDs { + for _, accountID := range accountIDs { var emojiIDs []string // Select account with ID. if _, err := tx.NewSelect(). Table("accounts"). Column("emojis"). - Where("? = ?", bun.Ident("id"), id). + Where("? = ?", bun.Ident("id"), accountID). Exec(ctx); err != nil && err != sql.ErrNoRows { return err } - // Drop ID from account emojis. - emojiIDs = dropID(emojiIDs, id) + // Delete all instances of this emoji ID from account emojis. + emojiIDs = slices.DeleteFunc(emojiIDs, func(emojiID string) bool { + return emojiID == id + }) // Update account emoji IDs. if _, err := tx.NewUpdate(). Table("accounts"). - Where("? = ?", bun.Ident("id"), id). + Where("? = ?", bun.Ident("id"), accountID). Set("emojis = ?", emojiIDs). Exec(ctx); err != nil && err != sql.ErrNoRows { @@ -431,7 +437,7 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, error) { return e.getEmoji( ctx, - "Shortcode.Domain", + "Shortcode,Domain", func(emoji *gtsmodel.Emoji) error { q := e.db. NewSelect(). @@ -468,7 +474,7 @@ func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string } func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) error { - return e.state.Caches.GTS.EmojiCategory().Store(emojiCategory, func() error { + return e.state.Caches.GTS.EmojiCategory.Store(emojiCategory, func() error { _, err := e.db.NewInsert().Model(emojiCategory).Exec(ctx) return err }) @@ -520,7 +526,7 @@ func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gts func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, error) { // Fetch emoji from database cache with loader callback - emoji, err := e.state.Caches.GTS.Emoji().Load(lookup, func() (*gtsmodel.Emoji, error) { + emoji, err := e.state.Caches.GTS.Emoji.LoadOne(lookup, func() (*gtsmodel.Emoji, error) { var emoji gtsmodel.Emoji // Not cached! Perform database query @@ -568,28 +574,72 @@ func (e *emojiDB) PopulateEmoji(ctx context.Context, emoji *gtsmodel.Emoji) erro return errs.Combine() } -func (e *emojiDB) GetEmojisByIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, error) { - if len(emojiIDs) == 0 { +func (e *emojiDB) GetEmojisByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Emoji, error) { + if len(ids) == 0 { return nil, db.ErrNoEntries } - emojis := make([]*gtsmodel.Emoji, 0, len(emojiIDs)) + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) - for _, id := range emojiIDs { - emoji, err := e.GetEmojiByID(ctx, id) - if err != nil { - log.Errorf(ctx, "emojisFromIDs: error getting emoji %q: %v", id, err) - continue - } + // Load all emoji IDs via cache loader callbacks. + emojis, err := e.state.Caches.GTS.Emoji.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, - emojis = append(emojis, emoji) + // Uncached emoji loader function. + func() ([]*gtsmodel.Emoji, error) { + // Preallocate expected length of uncached emojis. + emojis := make([]*gtsmodel.Emoji, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := e.db.NewSelect(). + Model(&emojis). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return emojis, nil + }, + ) + if err != nil { + return nil, err + } + + // Reorder the emojis by their + // IDs to ensure in correct order. + getID := func(e *gtsmodel.Emoji) string { return e.ID } + util.OrderBy(emojis, ids, getID) + + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return emojis, nil } + // Populate all loaded emojis, removing those we fail to + // populate (removes needing so many nil checks everywhere). + emojis = slices.DeleteFunc(emojis, func(emoji *gtsmodel.Emoji) bool { + if err := e.PopulateEmoji(ctx, emoji); err != nil { + log.Errorf(ctx, "error populating emoji %s: %v", emoji.ID, err) + return true + } + return false + }) + return emojis, nil } func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, error) { - return e.state.Caches.GTS.EmojiCategory().Load(lookup, func() (*gtsmodel.EmojiCategory, error) { + return e.state.Caches.GTS.EmojiCategory.LoadOne(lookup, func() (*gtsmodel.EmojiCategory, error) { var category gtsmodel.EmojiCategory // Not cached! Perform database query @@ -601,36 +651,51 @@ func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery f }, keyParts...) } -func (e *emojiDB) GetEmojiCategoriesByIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, error) { - if len(emojiCategoryIDs) == 0 { +func (e *emojiDB) GetEmojiCategoriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.EmojiCategory, error) { + if len(ids) == 0 { return nil, db.ErrNoEntries } - emojiCategories := make([]*gtsmodel.EmojiCategory, 0, len(emojiCategoryIDs)) + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) - for _, id := range emojiCategoryIDs { - emojiCategory, err := e.GetEmojiCategory(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting emoji category %q: %v", id, err) - continue - } + // Load all category IDs via cache loader callbacks. + categories, err := e.state.Caches.GTS.EmojiCategory.Load("ID", - emojiCategories = append(emojiCategories, emojiCategory) - } + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, - return emojiCategories, nil -} + // Uncached emoji loader function. + func() ([]*gtsmodel.EmojiCategory, error) { + // Preallocate expected length of uncached categories. + categories := make([]*gtsmodel.EmojiCategory, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := e.db.NewSelect(). + Model(&categories). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } -// dropIDs drops given ID string from IDs slice. -func dropID(ids []string, id string) []string { - for i := 0; i < len(ids); { - if ids[i] == id { - // Remove this reference. - copy(ids[i:], ids[i+1:]) - ids = ids[:len(ids)-1] - continue - } - i++ + return categories, nil + }, + ) + if err != nil { + return nil, err } - return ids + + // Reorder the categories by their + // IDs to ensure in correct order. + getID := func(c *gtsmodel.EmojiCategory) string { return c.ID } + util.OrderBy(categories, ids, getID) + + return categories, nil } diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go index 567a44ee2..d506e0a31 100644 --- a/internal/db/bundb/instance.go +++ b/internal/db/bundb/instance.go @@ -143,7 +143,7 @@ func (i *instanceDB) GetInstanceByID(ctx context.Context, id string) (*gtsmodel. func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Instance) error, keyParts ...any) (*gtsmodel.Instance, error) { // Fetch instance from database cache with loader callback - instance, err := i.state.Caches.GTS.Instance().Load(lookup, func() (*gtsmodel.Instance, error) { + instance, err := i.state.Caches.GTS.Instance.LoadOne(lookup, func() (*gtsmodel.Instance, error) { var instance gtsmodel.Instance // Not cached! Perform database query. @@ -219,7 +219,7 @@ func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instanc return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err) } - return i.state.Caches.GTS.Instance().Store(instance, func() error { + return i.state.Caches.GTS.Instance.Store(instance, func() error { _, err := i.db.NewInsert().Model(instance).Exec(ctx) return err }) @@ -239,7 +239,7 @@ func (i *instanceDB) UpdateInstance(ctx context.Context, instance *gtsmodel.Inst columns = append(columns, "updated_at") } - return i.state.Caches.GTS.Instance().Store(instance, func() error { + return i.state.Caches.GTS.Instance.Store(instance, func() error { _, err := i.db. NewUpdate(). Model(instance). diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go index 7a117670a..5f95d3c24 100644 --- a/internal/db/bundb/list.go +++ b/internal/db/bundb/list.go @@ -21,6 +21,7 @@ import ( "context" "errors" "fmt" + "slices" "time" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -29,6 +30,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -56,7 +58,7 @@ func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, er } func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmodel.List) error, keyParts ...any) (*gtsmodel.List, error) { - list, err := l.state.Caches.GTS.List().Load(lookup, func() (*gtsmodel.List, error) { + list, err := l.state.Caches.GTS.List.LoadOne(lookup, func() (*gtsmodel.List, error) { var list gtsmodel.List // Not cached! Perform database query. @@ -100,18 +102,8 @@ func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([] return nil, nil } - // Select each list using its ID to ensure cache used. - lists := make([]*gtsmodel.List, 0, len(listIDs)) - for _, id := range listIDs { - list, err := l.state.DB.GetListByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error fetching list %q: %v", id, err) - continue - } - lists = append(lists, list) - } - - return lists, nil + // Return lists by their IDs. + return l.GetListsByIDs(ctx, listIDs) } func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error { @@ -147,7 +139,7 @@ func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error { } func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error { - return l.state.Caches.GTS.List().Store(list, func() error { + return l.state.Caches.GTS.List.Store(list, func() error { _, err := l.db.NewInsert().Model(list).Exec(ctx) return err }) @@ -162,7 +154,7 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns .. defer func() { // Invalidate all entries for this list ID. - l.state.Caches.GTS.ListEntry().Invalidate("ListID", list.ID) + l.state.Caches.GTS.ListEntry.Invalidate("ListID", list.ID) // Invalidate this entire list's timeline. if err := l.state.Timelines.List.RemoveTimeline(ctx, list.ID); err != nil { @@ -170,7 +162,7 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns .. } }() - return l.state.Caches.GTS.List().Store(list, func() error { + return l.state.Caches.GTS.List.Store(list, func() error { _, err := l.db.NewUpdate(). Model(list). Where("? = ?", bun.Ident("list.id"), list.ID). @@ -198,7 +190,7 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error { defer func() { // Invalidate this list from cache. - l.state.Caches.GTS.List().Invalidate("ID", id) + l.state.Caches.GTS.List.Invalidate("ID", id) // Invalidate this entire list's timeline. if err := l.state.Timelines.List.RemoveTimeline(ctx, id); err != nil { @@ -243,7 +235,7 @@ func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.Lis } func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) { - listEntry, err := l.state.Caches.GTS.ListEntry().Load(lookup, func() (*gtsmodel.ListEntry, error) { + listEntry, err := l.state.Caches.GTS.ListEntry.LoadOne(lookup, func() (*gtsmodel.ListEntry, error) { var listEntry gtsmodel.ListEntry // Not cached! Perform database query. @@ -344,18 +336,128 @@ func (l *listDB) GetListEntries(ctx context.Context, } } - // Select each list entry using its ID to ensure cache used. - listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs)) - for _, id := range entryIDs { - listEntry, err := l.state.DB.GetListEntryByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error fetching list entry %q: %v", id, err) - continue + // Return list entries by their IDs. + return l.GetListEntriesByIDs(ctx, entryIDs) +} + +func (l *listDB) GetListsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.List, error) { + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) + + // Load all list IDs via cache loader callbacks. + lists, err := l.state.Caches.GTS.List.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, + + // Uncached list loader function. + func() ([]*gtsmodel.List, error) { + // Preallocate expected length of uncached lists. + lists := make([]*gtsmodel.List, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := l.db.NewSelect(). + Model(&lists). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return lists, nil + }, + ) + if err != nil { + return nil, err + } + + // Reorder the lists by their + // IDs to ensure in correct order. + getID := func(l *gtsmodel.List) string { return l.ID } + util.OrderBy(lists, ids, getID) + + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return lists, nil + } + + // Populate all loaded lists, removing those we fail to + // populate (removes needing so many nil checks everywhere). + lists = slices.DeleteFunc(lists, func(list *gtsmodel.List) bool { + if err := l.PopulateList(ctx, list); err != nil { + log.Errorf(ctx, "error populating list %s: %v", list.ID, err) + return true } - listEntries = append(listEntries, listEntry) + return false + }) + + return lists, nil +} + +func (l *listDB) GetListEntriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.ListEntry, error) { + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) + + // Load all entry IDs via cache loader callbacks. + entries, err := l.state.Caches.GTS.ListEntry.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, + + // Uncached entry loader function. + func() ([]*gtsmodel.ListEntry, error) { + // Preallocate expected length of uncached entries. + entries := make([]*gtsmodel.ListEntry, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := l.db.NewSelect(). + Model(&entries). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return entries, nil + }, + ) + if err != nil { + return nil, err } - return listEntries, nil + // Reorder the entries by their + // IDs to ensure in correct order. + getID := func(e *gtsmodel.ListEntry) string { return e.ID } + util.OrderBy(entries, ids, getID) + + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return entries, nil + } + + // Populate all loaded entries, removing those we fail to + // populate (removes needing so many nil checks everywhere). + entries = slices.DeleteFunc(entries, func(entry *gtsmodel.ListEntry) bool { + if err := l.PopulateListEntry(ctx, entry); err != nil { + log.Errorf(ctx, "error populating entry %s: %v", entry.ID, err) + return true + } + return false + }) + + return entries, nil } func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) { @@ -376,18 +478,8 @@ func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) return nil, nil } - // Select each list entry using its ID to ensure cache used. - listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs)) - for _, id := range entryIDs { - listEntry, err := l.state.DB.GetListEntryByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error fetching list entry %q: %v", id, err) - continue - } - listEntries = append(listEntries, listEntry) - } - - return listEntries, nil + // Return list entries by their IDs. + return l.GetListEntriesByIDs(ctx, entryIDs) } func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error { @@ -409,10 +501,10 @@ func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.List func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEntry) error { defer func() { - // Collect unique list IDs from the entries. - listIDs := collate(func(i int) string { - return entries[i].ListID - }, len(entries)) + // Collect unique list IDs from the provided entries. + listIDs := util.Collate(entries, func(e *gtsmodel.ListEntry) string { + return e.ListID + }) for _, id := range listIDs { // Invalidate the timeline for the list this entry belongs to. @@ -426,7 +518,7 @@ func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEnt return l.db.RunInTx(ctx, func(tx Tx) error { for _, entry := range entries { entry := entry // rescope - if err := l.state.Caches.GTS.ListEntry().Store(entry, func() error { + if err := l.state.Caches.GTS.ListEntry.Store(entry, func() error { _, err := tx. NewInsert(). Model(entry). @@ -459,7 +551,7 @@ func (l *listDB) DeleteListEntry(ctx context.Context, id string) error { defer func() { // Invalidate this list entry upon delete. - l.state.Caches.GTS.ListEntry().Invalidate("ID", id) + l.state.Caches.GTS.ListEntry.Invalidate("ID", id) // Invalidate the timeline for the list this entry belongs to. if err := l.state.Timelines.List.RemoveTimeline(ctx, entry.ListID); err != nil { @@ -514,24 +606,3 @@ func (l *listDB) ListIncludesAccount(ctx context.Context, listID string, account return exists, err } - -// collate will collect the values of type T from an expected slice of length 'len', -// passing the expected index to each call of 'get' and deduplicating the end result. -func collate[T comparable](get func(int) T, len int) []T { - ts := make([]T, 0, len) - tm := make(map[T]struct{}, len) - - for i := 0; i < len; i++ { - // Get next. - t := get(i) - - if _, ok := tm[t]; !ok { - // New value, add - // to map + slice. - ts = append(ts, t) - tm[t] = struct{}{} - } - } - - return ts -} diff --git a/internal/db/bundb/marker.go b/internal/db/bundb/marker.go index 5d365e08a..b1dedb4f1 100644 --- a/internal/db/bundb/marker.go +++ b/internal/db/bundb/marker.go @@ -39,8 +39,8 @@ type markerDB struct { */ func (m *markerDB) GetMarker(ctx context.Context, accountID string, name gtsmodel.MarkerName) (*gtsmodel.Marker, error) { - marker, err := m.state.Caches.GTS.Marker().Load( - "AccountID.Name", + marker, err := m.state.Caches.GTS.Marker.LoadOne( + "AccountID,Name", func() (*gtsmodel.Marker, error) { var marker gtsmodel.Marker @@ -52,9 +52,7 @@ func (m *markerDB) GetMarker(ctx context.Context, accountID string, name gtsmode } return &marker, nil - }, - accountID, - name, + }, accountID, name, ) if err != nil { return nil, err // already processed @@ -74,7 +72,7 @@ func (m *markerDB) UpdateMarker(ctx context.Context, marker *gtsmodel.Marker) er marker.Version = prevMarker.Version + 1 } - return m.state.Caches.GTS.Marker().Store(marker, func() error { + return m.state.Caches.GTS.Marker.Store(marker, func() error { if prevMarker == nil { if _, err := m.db.NewInsert(). Model(marker). diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go index a2603eacc..ce3c90083 100644 --- a/internal/db/bundb/media.go +++ b/internal/db/bundb/media.go @@ -20,14 +20,15 @@ package bundb import ( "context" "errors" + "slices" "time" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -51,25 +52,52 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M } func (m *mediaDB) GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error) { - attachments := make([]*gtsmodel.MediaAttachment, 0, len(ids)) - - for _, id := range ids { - // Attempt fetch from DB - attachment, err := m.GetAttachmentByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting attachment %q: %v", id, err) - continue - } + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) + + // Load all media IDs via cache loader callbacks. + media, err := m.state.Caches.GTS.Media.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, + + // Uncached media loader function. + func() ([]*gtsmodel.MediaAttachment, error) { + // Preallocate expected length of uncached media attachments. + media := make([]*gtsmodel.MediaAttachment, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := m.db.NewSelect(). + Model(&media). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } - // Append attachment - attachments = append(attachments, attachment) + return media, nil + }, + ) + if err != nil { + return nil, err } - return attachments, nil + // Reorder the media by their + // IDs to ensure in correct order. + getID := func(m *gtsmodel.MediaAttachment) string { return m.ID } + util.OrderBy(media, ids, getID) + + return media, nil } func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, error) { - return m.state.Caches.GTS.Media().Load(lookup, func() (*gtsmodel.MediaAttachment, error) { + return m.state.Caches.GTS.Media.LoadOne(lookup, func() (*gtsmodel.MediaAttachment, error) { var attachment gtsmodel.MediaAttachment // Not cached! Perform database query @@ -82,7 +110,7 @@ func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func } func (m *mediaDB) PutAttachment(ctx context.Context, media *gtsmodel.MediaAttachment) error { - return m.state.Caches.GTS.Media().Store(media, func() error { + return m.state.Caches.GTS.Media.Store(media, func() error { _, err := m.db.NewInsert().Model(media).Exec(ctx) return err }) @@ -95,7 +123,7 @@ func (m *mediaDB) UpdateAttachment(ctx context.Context, media *gtsmodel.MediaAtt columns = append(columns, "updated_at") } - return m.state.Caches.GTS.Media().Store(media, func() error { + return m.state.Caches.GTS.Media.Store(media, func() error { _, err := m.db.NewUpdate(). Model(media). Where("? = ?", bun.Ident("media_attachment.id"), media.ID). @@ -119,7 +147,7 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error { } // On return, ensure that media with ID is invalidated. - defer m.state.Caches.GTS.Media().Invalidate("ID", id) + defer m.state.Caches.GTS.Media.Invalidate("ID", id) // Delete media attachment in new transaction. err = m.db.RunInTx(ctx, func(tx Tx) error { @@ -171,8 +199,12 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error { return gtserror.Newf("error selecting status: %w", err) } - if updatedIDs := dropID(status.AttachmentIDs, id); // nocollapse - len(updatedIDs) != len(status.AttachmentIDs) { + // Delete all instances of this deleted media ID from status attachments. + updatedIDs := slices.DeleteFunc(status.AttachmentIDs, func(s string) bool { + return s == id + }) + + if len(updatedIDs) != len(status.AttachmentIDs) { // Note: this handles not found. // // Attachments changed, update the status. diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go index 30a20b0c1..b069423bb 100644 --- a/internal/db/bundb/mention.go +++ b/internal/db/bundb/mention.go @@ -20,6 +20,7 @@ package bundb import ( "context" "errors" + "slices" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" @@ -27,6 +28,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -36,7 +38,7 @@ type mentionDB struct { } func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, error) { - mention, err := m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) { + mention, err := m.state.Caches.GTS.Mention.LoadOne("ID", func() (*gtsmodel.Mention, error) { var mention gtsmodel.Mention q := m.db. @@ -63,21 +65,64 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio } func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, error) { - mentions := make([]*gtsmodel.Mention, 0, len(ids)) + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) + + // Load all mention IDs via cache loader callbacks. + mentions, err := m.state.Caches.GTS.Mention.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, + + // Uncached mention loader function. + func() ([]*gtsmodel.Mention, error) { + // Preallocate expected length of uncached mentions. + mentions := make([]*gtsmodel.Mention, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := m.db.NewSelect(). + Model(&mentions). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return mentions, nil + }, + ) + if err != nil { + return nil, err + } - for _, id := range ids { - // Attempt fetch from DB - mention, err := m.GetMention(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting mention %q: %v", id, err) - continue - } + // Reorder the mentions by their + // IDs to ensure in correct order. + getID := func(m *gtsmodel.Mention) string { return m.ID } + util.OrderBy(mentions, ids, getID) - // Append mention - mentions = append(mentions, mention) + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return mentions, nil } + // Populate all loaded mentions, removing those we fail to + // populate (removes needing so many nil checks everywhere). + mentions = slices.DeleteFunc(mentions, func(mention *gtsmodel.Mention) bool { + if err := m.PopulateMention(ctx, mention); err != nil { + log.Errorf(ctx, "error populating mention %s: %v", mention.ID, err) + return true + } + return false + }) + return mentions, nil + } func (m *mentionDB) PopulateMention(ctx context.Context, mention *gtsmodel.Mention) (err error) { @@ -120,14 +165,14 @@ func (m *mentionDB) PopulateMention(ctx context.Context, mention *gtsmodel.Menti } func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error { - return m.state.Caches.GTS.Mention().Store(mention, func() error { + return m.state.Caches.GTS.Mention.Store(mention, func() error { _, err := m.db.NewInsert().Model(mention).Exec(ctx) return err }) } func (m *mentionDB) DeleteMentionByID(ctx context.Context, id string) error { - defer m.state.Caches.GTS.Mention().Invalidate("ID", id) + defer m.state.Caches.GTS.Mention.Invalidate("ID", id) // Load mention into cache before attempting a delete, // as we need it cached in order to trigger the invalidate diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index 7532b9993..ed34222fb 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -20,6 +20,7 @@ package bundb import ( "context" "errors" + "slices" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" @@ -28,6 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -37,18 +39,17 @@ type notificationDB struct { } func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error) { - return n.state.Caches.GTS.Notification().Load("ID", func() (*gtsmodel.Notification, error) { - var notif gtsmodel.Notification - - q := n.db.NewSelect(). - Model(¬if). - Where("? = ?", bun.Ident("notification.id"), id) - if err := q.Scan(ctx); err != nil { - return nil, err - } - - return ¬if, nil - }, id) + return n.getNotification( + ctx, + "ID", + func(notif *gtsmodel.Notification) error { + return n.db.NewSelect(). + Model(notif). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx) + }, + id, + ) } func (n *notificationDB) GetNotification( @@ -58,42 +59,113 @@ func (n *notificationDB) GetNotification( originAccountID string, statusID string, ) (*gtsmodel.Notification, error) { - notif, err := n.state.Caches.GTS.Notification().Load("NotificationType.TargetAccountID.OriginAccountID.StatusID", func() (*gtsmodel.Notification, error) { - var notif gtsmodel.Notification + return n.getNotification( + ctx, + "NotificationType,TargetAccountID,OriginAccountID,StatusID", + func(notif *gtsmodel.Notification) error { + return n.db.NewSelect(). + Model(notif). + Where("? = ?", bun.Ident("notification_type"), notificationType). + Where("? = ?", bun.Ident("target_account_id"), targetAccountID). + Where("? = ?", bun.Ident("origin_account_id"), originAccountID). + Where("? = ?", bun.Ident("status_id"), statusID). + Scan(ctx) + }, + notificationType, targetAccountID, originAccountID, statusID, + ) +} - q := n.db.NewSelect(). - Model(¬if). - Where("? = ?", bun.Ident("notification_type"), notificationType). - Where("? = ?", bun.Ident("target_account_id"), targetAccountID). - Where("? = ?", bun.Ident("origin_account_id"), originAccountID). - Where("? = ?", bun.Ident("status_id"), statusID) +func (n *notificationDB) getNotification(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Notification) error, keyParts ...any) (*gtsmodel.Notification, error) { + // Fetch notification from cache with loader callback + notif, err := n.state.Caches.GTS.Notification.LoadOne(lookup, func() (*gtsmodel.Notification, error) { + var notif gtsmodel.Notification - if err := q.Scan(ctx); err != nil { + // Not cached! Perform database query + if err := dbQuery(¬if); err != nil { return nil, err } return ¬if, nil - }, notificationType, targetAccountID, originAccountID, statusID) + }, keyParts...) if err != nil { return nil, err } if gtscontext.Barebones(ctx) { - // no need to fully populate. + // Only a barebones model was requested. return notif, nil } - // Further populate the notif fields where applicable. - if err := n.PopulateNotification(ctx, notif); err != nil { + if err := n.state.DB.PopulateNotification(ctx, notif); err != nil { return nil, err } return notif, nil } +func (n *notificationDB) GetNotificationsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Notification, error) { + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) + + // Load all notif IDs via cache loader callbacks. + notifs, err := n.state.Caches.GTS.Notification.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, + + // Uncached notification loader function. + func() ([]*gtsmodel.Notification, error) { + // Preallocate expected length of uncached notifications. + notifs := make([]*gtsmodel.Notification, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := n.db.NewSelect(). + Model(¬ifs). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return notifs, nil + }, + ) + if err != nil { + return nil, err + } + + // Reorder the notifs by their + // IDs to ensure in correct order. + getID := func(n *gtsmodel.Notification) string { return n.ID } + util.OrderBy(notifs, ids, getID) + + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return notifs, nil + } + + // Populate all loaded notifs, removing those we fail to + // populate (removes needing so many nil checks everywhere). + notifs = slices.DeleteFunc(notifs, func(notif *gtsmodel.Notification) bool { + if err := n.PopulateNotification(ctx, notif); err != nil { + log.Errorf(ctx, "error populating notif %s: %v", notif.ID, err) + return true + } + return false + }) + + return notifs, nil +} + func (n *notificationDB) PopulateNotification(ctx context.Context, notif *gtsmodel.Notification) error { var ( - errs = gtserror.NewMultiError(2) + errs gtserror.MultiError err error ) @@ -211,31 +283,19 @@ func (n *notificationDB) GetAccountNotifications( } } - notifs := make([]*gtsmodel.Notification, 0, len(notifIDs)) - for _, id := range notifIDs { - // Attempt fetch from DB - notif, err := n.GetNotificationByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error fetching notification %q: %v", id, err) - continue - } - - // Append notification - notifs = append(notifs, notif) - } - - return notifs, nil + // Fetch notification models by their IDs. + return n.GetNotificationsByIDs(ctx, notifIDs) } func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error { - return n.state.Caches.GTS.Notification().Store(notif, func() error { + return n.state.Caches.GTS.Notification.Store(notif, func() error { _, err := n.db.NewInsert().Model(notif).Exec(ctx) return err }) } func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) error { - defer n.state.Caches.GTS.Notification().Invalidate("ID", id) + defer n.state.Caches.GTS.Notification.Invalidate("ID", id) // Load notif into cache before attempting a delete, // as we need it cached in order to trigger the invalidate @@ -288,7 +348,7 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string defer func() { // Invalidate all IDs on return. for _, id := range notifIDs { - n.state.Caches.GTS.Notification().Invalidate("ID", id) + n.state.Caches.GTS.Notification.Invalidate("ID", id) } }() @@ -326,7 +386,7 @@ func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statu defer func() { // Invalidate all IDs on return. for _, id := range notifIDs { - n.state.Caches.GTS.Notification().Invalidate("ID", id) + n.state.Caches.GTS.Notification.Invalidate("ID", id) } }() diff --git a/internal/db/bundb/poll.go b/internal/db/bundb/poll.go index 3e77fb6c5..0dfb15621 100644 --- a/internal/db/bundb/poll.go +++ b/internal/db/bundb/poll.go @@ -20,6 +20,7 @@ package bundb import ( "context" "errors" + "slices" "time" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -28,6 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -52,7 +54,7 @@ func (p *pollDB) GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, er func (p *pollDB) getPoll(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Poll) error, keyParts ...any) (*gtsmodel.Poll, error) { // Fetch poll from database cache with loader callback - poll, err := p.state.Caches.GTS.Poll().Load(lookup, func() (*gtsmodel.Poll, error) { + poll, err := p.state.Caches.GTS.Poll.LoadOne(lookup, func() (*gtsmodel.Poll, error) { var poll gtsmodel.Poll // Not cached! Perform database query. @@ -140,7 +142,7 @@ func (p *pollDB) PutPoll(ctx context.Context, poll *gtsmodel.Poll) error { // is non nil and set. poll.CheckVotes() - return p.state.Caches.GTS.Poll().Store(poll, func() error { + return p.state.Caches.GTS.Poll.Store(poll, func() error { _, err := p.db.NewInsert().Model(poll).Exec(ctx) return err }) @@ -151,7 +153,7 @@ func (p *pollDB) UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...st // is non nil and set. poll.CheckVotes() - return p.state.Caches.GTS.Poll().Store(poll, func() error { + return p.state.Caches.GTS.Poll.Store(poll, func() error { return p.db.RunInTx(ctx, func(tx Tx) error { // Update the status' "updated_at" field. if _, err := tx.NewUpdate(). @@ -184,8 +186,8 @@ func (p *pollDB) DeletePollByID(ctx context.Context, id string) error { } // Invalidate poll by ID from cache. - p.state.Caches.GTS.Poll().Invalidate("ID", id) - p.state.Caches.GTS.PollVoteIDs().Invalidate(id) + p.state.Caches.GTS.Poll.Invalidate("ID", id) + p.state.Caches.GTS.PollVoteIDs.Invalidate(id) return nil } @@ -207,7 +209,7 @@ func (p *pollDB) GetPollVoteByID(ctx context.Context, id string) (*gtsmodel.Poll func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID string) (*gtsmodel.PollVote, error) { return p.getPollVote( ctx, - "PollID.AccountID", + "PollID,AccountID", func(vote *gtsmodel.PollVote) error { return p.db.NewSelect(). Model(vote). @@ -222,7 +224,7 @@ func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID str func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.PollVote) error, keyParts ...any) (*gtsmodel.PollVote, error) { // Fetch vote from database cache with loader callback - vote, err := p.state.Caches.GTS.PollVote().Load(lookup, func() (*gtsmodel.PollVote, error) { + vote, err := p.state.Caches.GTS.PollVote.LoadOne(lookup, func() (*gtsmodel.PollVote, error) { var vote gtsmodel.PollVote // Not cached! Perform database query. @@ -250,7 +252,9 @@ func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*g } func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.PollVote, error) { - voteIDs, err := p.state.Caches.GTS.PollVoteIDs().Load(pollID, func() ([]string, error) { + + // Load vote IDs known for given poll ID using loader callback. + voteIDs, err := p.state.Caches.GTS.PollVoteIDs.Load(pollID, func() ([]string, error) { var voteIDs []string // Vote IDs not in cache, perform DB query! @@ -266,21 +270,62 @@ func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.P return nil, err } - // Preallocate slice of expected length. - votes := make([]*gtsmodel.PollVote, 0, len(voteIDs)) + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(voteIDs)) - for _, id := range voteIDs { - // Fetch poll vote model for this ID. - vote, err := p.GetPollVoteByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting poll vote %s: %v", id, err) - continue - } + // Load all votes from IDs via cache loader callbacks. + votes, err := p.state.Caches.GTS.PollVote.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range voteIDs { + if !load(id) { + uncached = append(uncached, id) + } + } + }, + + // Uncached poll vote loader function. + func() ([]*gtsmodel.PollVote, error) { + // Preallocate expected length of uncached votes. + votes := make([]*gtsmodel.PollVote, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := p.db.NewSelect(). + Model(&votes). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return votes, nil + }, + ) + if err != nil { + return nil, err + } + + // Reorder the poll votes by their + // IDs to ensure in correct order. + getID := func(v *gtsmodel.PollVote) string { return v.ID } + util.OrderBy(votes, voteIDs, getID) - // Append to return slice. - votes = append(votes, vote) + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return votes, nil } + // Populate all loaded votes, removing those we fail to + // populate (removes needing so many nil checks everywhere). + votes = slices.DeleteFunc(votes, func(vote *gtsmodel.PollVote) bool { + if err := p.PopulatePollVote(ctx, vote); err != nil { + log.Errorf(ctx, "error populating vote %s: %v", vote.ID, err) + return true + } + return false + }) + return votes, nil } @@ -316,7 +361,7 @@ func (p *pollDB) PopulatePollVote(ctx context.Context, vote *gtsmodel.PollVote) } func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error { - return p.state.Caches.GTS.PollVote().Store(vote, func() error { + return p.state.Caches.GTS.PollVote.Store(vote, func() error { return p.db.RunInTx(ctx, func(tx Tx) error { // Try insert vote into database. if _, err := tx.NewInsert(). @@ -416,9 +461,9 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { } // Invalidate poll vote and poll entry from caches. - p.state.Caches.GTS.Poll().Invalidate("ID", pollID) - p.state.Caches.GTS.PollVote().Invalidate("PollID", pollID) - p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID) + p.state.Caches.GTS.Poll.Invalidate("ID", pollID) + p.state.Caches.GTS.PollVote.Invalidate("PollID", pollID) + p.state.Caches.GTS.PollVoteIDs.Invalidate(pollID) return nil } @@ -428,7 +473,7 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID // Slice should only ever be of length // 0 or 1; it's a slice of slices only // because we can't LIMIT deletes to 1. - var choicesSl [][]int + var choicesSlice [][]int // Delete vote in poll by account, // returning the ID + choices of the vote. @@ -437,17 +482,19 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID Where("? = ?", bun.Ident("poll_id"), pollID). Where("? = ?", bun.Ident("account_id"), accountID). Returning("?", bun.Ident("choices")). - Scan(ctx, &choicesSl); err != nil { + Scan(ctx, &choicesSlice); err != nil { // irrecoverable. return err } - if len(choicesSl) != 1 { + if len(choicesSlice) != 1 { // No poll votes by this // acct on this poll. return nil } - choices := choicesSl[0] + + // Extract the *actual* choices. + choices := choicesSlice[0] // Select current poll counts from DB, // taking minimal columns needed to @@ -489,9 +536,9 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID } // Invalidate poll vote and poll entry from caches. - p.state.Caches.GTS.Poll().Invalidate("ID", pollID) - p.state.Caches.GTS.PollVote().Invalidate("PollID.AccountID", pollID, accountID) - p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID) + p.state.Caches.GTS.Poll.Invalidate("ID", pollID) + p.state.Caches.GTS.PollVote.Invalidate("PollID,AccountID", pollID, accountID) + p.state.Caches.GTS.PollVoteIDs.Invalidate(pollID) return nil } diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index 138a5aa17..4c50862a1 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -194,7 +194,7 @@ func (r *relationshipDB) CountAccountBlocks(ctx context.Context, accountID strin } func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { - return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), ">"+accountID, page, func() ([]string, error) { + return loadPagedIDs(r.state.Caches.GTS.FollowIDs, ">"+accountID, page, func() ([]string, error) { var followIDs []string // Follow IDs not in cache, perform DB query! @@ -209,7 +209,7 @@ func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID stri } func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) { - return r.state.Caches.GTS.FollowIDs().Load("l>"+accountID, func() ([]string, error) { + return r.state.Caches.GTS.FollowIDs.Load("l>"+accountID, func() ([]string, error) { var followIDs []string // Follow IDs not in cache, perform DB query! @@ -224,7 +224,7 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID } func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { - return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), "<"+accountID, page, func() ([]string, error) { + return loadPagedIDs(r.state.Caches.GTS.FollowIDs, "<"+accountID, page, func() ([]string, error) { var followIDs []string // Follow IDs not in cache, perform DB query! @@ -239,7 +239,7 @@ func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID st } func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) { - return r.state.Caches.GTS.FollowIDs().Load("l<"+accountID, func() ([]string, error) { + return r.state.Caches.GTS.FollowIDs.Load("l<"+accountID, func() ([]string, error) { var followIDs []string // Follow IDs not in cache, perform DB query! @@ -254,7 +254,7 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account } func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { - return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), ">"+accountID, page, func() ([]string, error) { + return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs, ">"+accountID, page, func() ([]string, error) { var followReqIDs []string // Follow request IDs not in cache, perform DB query! @@ -269,7 +269,7 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account } func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { - return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), "<"+accountID, page, func() ([]string, error) { + return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs, "<"+accountID, page, func() ([]string, error) { var followReqIDs []string // Follow request IDs not in cache, perform DB query! @@ -284,7 +284,7 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco } func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { - return loadPagedIDs(r.state.Caches.GTS.BlockIDs(), accountID, page, func() ([]string, error) { + return loadPagedIDs(r.state.Caches.GTS.BlockIDs, accountID, page, func() ([]string, error) { var blockIDs []string // Block IDs not in cache, perform DB query! diff --git a/internal/db/bundb/relationship_block.go b/internal/db/bundb/relationship_block.go index efaa6d1a9..178de6aa7 100644 --- a/internal/db/bundb/relationship_block.go +++ b/internal/db/bundb/relationship_block.go @@ -20,12 +20,14 @@ package bundb import ( "context" "errors" + "slices" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -86,7 +88,7 @@ func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmod func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Block, error) { return r.getBlock( ctx, - "AccountID.TargetAccountID", + "AccountID,TargetAccountID", func(block *gtsmodel.Block) error { return r.db.NewSelect().Model(block). Where("? = ?", bun.Ident("block.account_id"), sourceAccountID). @@ -99,27 +101,68 @@ func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, t } func (r *relationshipDB) GetBlocksByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Block, error) { - // Preallocate slice of expected length. - blocks := make([]*gtsmodel.Block, 0, len(ids)) + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) + + // Load all blocks IDs via cache loader callbacks. + blocks, err := r.state.Caches.GTS.Block.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, - for _, id := range ids { - // Fetch block model for this ID. - block, err := r.GetBlockByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting block %q: %v", id, err) - continue - } + // Uncached block loader function. + func() ([]*gtsmodel.Block, error) { + // Preallocate expected length of uncached blocks. + blocks := make([]*gtsmodel.Block, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := r.db.NewSelect(). + Model(&blocks). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return blocks, nil + }, + ) + if err != nil { + return nil, err + } + + // Reorder the blocks by their + // IDs to ensure in correct order. + getID := func(b *gtsmodel.Block) string { return b.ID } + util.OrderBy(blocks, ids, getID) - // Append to return slice. - blocks = append(blocks, block) + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return blocks, nil } + // Populate all loaded blocks, removing those we fail to + // populate (removes needing so many nil checks everywhere). + blocks = slices.DeleteFunc(blocks, func(block *gtsmodel.Block) bool { + if err := r.PopulateBlock(ctx, block); err != nil { + log.Errorf(ctx, "error populating block %s: %v", block.ID, err) + return true + } + return false + }) + return blocks, nil } func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) { // Fetch block from cache with loader callback - block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) { + block, err := r.state.Caches.GTS.Block.LoadOne(lookup, func() (*gtsmodel.Block, error) { var block gtsmodel.Block // Not cached! Perform database query @@ -148,8 +191,8 @@ func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery fu func (r *relationshipDB) PopulateBlock(ctx context.Context, block *gtsmodel.Block) error { var ( + errs gtserror.MultiError err error - errs = gtserror.NewMultiError(2) ) if block.Account == nil { @@ -178,7 +221,7 @@ func (r *relationshipDB) PopulateBlock(ctx context.Context, block *gtsmodel.Bloc } func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error { - return r.state.Caches.GTS.Block().Store(block, func() error { + return r.state.Caches.GTS.Block.Store(block, func() error { _, err := r.db.NewInsert().Model(block).Exec(ctx) return err }) @@ -198,7 +241,7 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error { } // Drop this now-cached block on return after delete. - defer r.state.Caches.GTS.Block().Invalidate("ID", id) + defer r.state.Caches.GTS.Block.Invalidate("ID", id) // Finally delete block from DB. _, err = r.db.NewDelete(). @@ -222,7 +265,7 @@ func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error } // Drop this now-cached block on return after delete. - defer r.state.Caches.GTS.Block().Invalidate("URI", uri) + defer r.state.Caches.GTS.Block.Invalidate("URI", uri) // Finally delete block from DB. _, err = r.db.NewDelete(). @@ -251,22 +294,20 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri defer func() { // Invalidate all account's incoming / outoing blocks on return. - r.state.Caches.GTS.Block().Invalidate("AccountID", accountID) - r.state.Caches.GTS.Block().Invalidate("TargetAccountID", accountID) + r.state.Caches.GTS.Block.Invalidate("AccountID", accountID) + r.state.Caches.GTS.Block.Invalidate("TargetAccountID", accountID) }() // Load all blocks into cache, this *really* isn't great // but it is the only way we can ensure we invalidate all // related caches correctly (e.g. visibility). - for _, id := range blockIDs { - _, err := r.GetBlockByID(ctx, id) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return err - } + _, err := r.GetAccountBlocks(ctx, accountID, nil) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return err } // Finally delete all from DB. - _, err := r.db.NewDelete(). + _, err = r.db.NewDelete(). Table("blocks"). Where("? IN (?)", bun.Ident("id"), bun.In(blockIDs)). Exec(ctx) diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go index 6c5a75e4c..93ee69bd7 100644 --- a/internal/db/bundb/relationship_follow.go +++ b/internal/db/bundb/relationship_follow.go @@ -21,6 +21,7 @@ import ( "context" "errors" "fmt" + "slices" "time" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -28,6 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -62,7 +64,7 @@ func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmo func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) { return r.getFollow( ctx, - "AccountID.TargetAccountID", + "AccountID,TargetAccountID", func(follow *gtsmodel.Follow) error { return r.db.NewSelect(). Model(follow). @@ -76,21 +78,62 @@ func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, } func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Follow, error) { - // Preallocate slice of expected length. - follows := make([]*gtsmodel.Follow, 0, len(ids)) + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) + + // Load all follow IDs via cache loader callbacks. + follows, err := r.state.Caches.GTS.Follow.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, - for _, id := range ids { - // Fetch follow model for this ID. - follow, err := r.GetFollowByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting follow %q: %v", id, err) - continue - } + // Uncached follow loader function. + func() ([]*gtsmodel.Follow, error) { + // Preallocate expected length of uncached follows. + follows := make([]*gtsmodel.Follow, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := r.db.NewSelect(). + Model(&follows). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return follows, nil + }, + ) + if err != nil { + return nil, err + } + + // Reorder the follows by their + // IDs to ensure in correct order. + getID := func(f *gtsmodel.Follow) string { return f.ID } + util.OrderBy(follows, ids, getID) - // Append to return slice. - follows = append(follows, follow) + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return follows, nil } + // Populate all loaded follows, removing those we fail to + // populate (removes needing so many nil checks everywhere). + follows = slices.DeleteFunc(follows, func(follow *gtsmodel.Follow) bool { + if err := r.PopulateFollow(ctx, follow); err != nil { + log.Errorf(ctx, "error populating follow %s: %v", follow.ID, err) + return true + } + return false + }) + return follows, nil } @@ -130,7 +173,7 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 strin func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Follow) error, keyParts ...any) (*gtsmodel.Follow, error) { // Fetch follow from database cache with loader callback - follow, err := r.state.Caches.GTS.Follow().Load(lookup, func() (*gtsmodel.Follow, error) { + follow, err := r.state.Caches.GTS.Follow.LoadOne(lookup, func() (*gtsmodel.Follow, error) { var follow gtsmodel.Follow // Not cached! Perform database query @@ -189,7 +232,7 @@ func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Fo } func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error { - return r.state.Caches.GTS.Follow().Store(follow, func() error { + return r.state.Caches.GTS.Follow.Store(follow, func() error { _, err := r.db.NewInsert().Model(follow).Exec(ctx) return err }) @@ -202,7 +245,7 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll columns = append(columns, "updated_at") } - return r.state.Caches.GTS.Follow().Store(follow, func() error { + return r.state.Caches.GTS.Follow.Store(follow, func() error { if _, err := r.db.NewUpdate(). Model(follow). Where("? = ?", bun.Ident("follow.id"), follow.ID). @@ -250,7 +293,7 @@ func (r *relationshipDB) DeleteFollow(ctx context.Context, sourceAccountID strin } // Drop this now-cached follow on return after delete. - defer r.state.Caches.GTS.Follow().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) + defer r.state.Caches.GTS.Follow.Invalidate("AccountID,TargetAccountID", sourceAccountID, targetAccountID) // Finally delete follow from DB. return r.deleteFollow(ctx, follow.ID) @@ -270,7 +313,7 @@ func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error } // Drop this now-cached follow on return after delete. - defer r.state.Caches.GTS.Follow().Invalidate("ID", id) + defer r.state.Caches.GTS.Follow.Invalidate("ID", id) // Finally delete follow from DB. return r.deleteFollow(ctx, follow.ID) @@ -290,7 +333,7 @@ func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) erro } // Drop this now-cached follow on return after delete. - defer r.state.Caches.GTS.Follow().Invalidate("URI", uri) + defer r.state.Caches.GTS.Follow.Invalidate("URI", uri) // Finally delete follow from DB. return r.deleteFollow(ctx, follow.ID) @@ -316,22 +359,30 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str defer func() { // Invalidate all account's incoming / outoing follows on return. - r.state.Caches.GTS.Follow().Invalidate("AccountID", accountID) - r.state.Caches.GTS.Follow().Invalidate("TargetAccountID", accountID) + r.state.Caches.GTS.Follow.Invalidate("AccountID", accountID) + r.state.Caches.GTS.Follow.Invalidate("TargetAccountID", accountID) }() // Load all follows into cache, this *really* isn't great // but it is the only way we can ensure we invalidate all // related caches correctly (e.g. visibility). - for _, id := range followIDs { - follow, err := r.GetFollowByID(ctx, id) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return err - } + _, err := r.GetAccountFollows(ctx, accountID, nil) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return err + } - // Delete each follow from DB. - if err := r.deleteFollow(ctx, follow.ID); err != nil && - !errors.Is(err, db.ErrNoEntries) { + // Delete all follows from DB. + _, err = r.db.NewDelete(). + Table("follows"). + Where("? IN (?)", bun.Ident("id"), bun.In(followIDs)). + Exec(ctx) + if err != nil { + return err + } + + for _, id := range followIDs { + // Finally, delete all list entries associated with each follow ID. + if err := r.state.DB.DeleteListEntriesForFollowID(ctx, id); err != nil { return err } } diff --git a/internal/db/bundb/relationship_follow_req.go b/internal/db/bundb/relationship_follow_req.go index 51aceafe1..690b97cf0 100644 --- a/internal/db/bundb/relationship_follow_req.go +++ b/internal/db/bundb/relationship_follow_req.go @@ -20,6 +20,7 @@ package bundb import ( "context" "errors" + "slices" "time" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -27,6 +28,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -61,7 +63,7 @@ func (r *relationshipDB) GetFollowRequestByURI(ctx context.Context, uri string) func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) { return r.getFollowRequest( ctx, - "AccountID.TargetAccountID", + "AccountID,TargetAccountID", func(followReq *gtsmodel.FollowRequest) error { return r.db.NewSelect(). Model(followReq). @@ -75,22 +77,63 @@ func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID s } func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.FollowRequest, error) { - // Preallocate slice of expected length. - followReqs := make([]*gtsmodel.FollowRequest, 0, len(ids)) + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) + + // Load all follow IDs via cache loader callbacks. + follows, err := r.state.Caches.GTS.FollowRequest.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, - for _, id := range ids { - // Fetch follow request model for this ID. - followReq, err := r.GetFollowRequestByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting follow request %q: %v", id, err) - continue - } + // Uncached follow req loader function. + func() ([]*gtsmodel.FollowRequest, error) { + // Preallocate expected length of uncached followReqs. + follows := make([]*gtsmodel.FollowRequest, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := r.db.NewSelect(). + Model(&follows). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return follows, nil + }, + ) + if err != nil { + return nil, err + } + + // Reorder the requests by their + // IDs to ensure in correct order. + getID := func(f *gtsmodel.FollowRequest) string { return f.ID } + util.OrderBy(follows, ids, getID) - // Append to return slice. - followReqs = append(followReqs, followReq) + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return follows, nil } - return followReqs, nil + // Populate all loaded followreqs, removing those we fail to + // populate (removes needing so many nil checks everywhere). + follows = slices.DeleteFunc(follows, func(follow *gtsmodel.FollowRequest) bool { + if err := r.PopulateFollowRequest(ctx, follow); err != nil { + log.Errorf(ctx, "error populating follow request %s: %v", follow.ID, err) + return true + } + return false + }) + + return follows, nil } func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) { @@ -107,7 +150,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, dbQuery func(*gtsmodel.FollowRequest) error, keyParts ...any) (*gtsmodel.FollowRequest, error) { // Fetch follow request from database cache with loader callback - followReq, err := r.state.Caches.GTS.FollowRequest().Load(lookup, func() (*gtsmodel.FollowRequest, error) { + followReq, err := r.state.Caches.GTS.FollowRequest.LoadOne(lookup, func() (*gtsmodel.FollowRequest, error) { var followReq gtsmodel.FollowRequest // Not cached! Perform database query @@ -166,7 +209,7 @@ func (r *relationshipDB) PopulateFollowRequest(ctx context.Context, follow *gtsm } func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error { - return r.state.Caches.GTS.FollowRequest().Store(follow, func() error { + return r.state.Caches.GTS.FollowRequest.Store(follow, func() error { _, err := r.db.NewInsert().Model(follow).Exec(ctx) return err }) @@ -179,7 +222,7 @@ func (r *relationshipDB) UpdateFollowRequest(ctx context.Context, followRequest columns = append(columns, "updated_at") } - return r.state.Caches.GTS.FollowRequest().Store(followRequest, func() error { + return r.state.Caches.GTS.FollowRequest.Store(followRequest, func() error { if _, err := r.db.NewUpdate(). Model(followRequest). Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). @@ -212,7 +255,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI Notify: followReq.Notify, } - if err := r.state.Caches.GTS.Follow().Store(follow, func() error { + if err := r.state.Caches.GTS.Follow.Store(follow, func() error { // If the follow already exists, just // replace the URI with the new one. _, err := r.db. @@ -274,7 +317,7 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI } // Drop this now-cached follow request on return after delete. - defer r.state.Caches.GTS.FollowRequest().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) + defer r.state.Caches.GTS.FollowRequest.Invalidate("AccountID,TargetAccountID", sourceAccountID, targetAccountID) // Finally delete followreq from DB. _, err = r.db.NewDelete(). @@ -298,7 +341,7 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) } // Drop this now-cached follow request on return after delete. - defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", id) + defer r.state.Caches.GTS.FollowRequest.Invalidate("ID", id) // Finally delete followreq from DB. _, err = r.db.NewDelete(). @@ -322,7 +365,7 @@ func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri strin } // Drop this now-cached follow request on return after delete. - defer r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri) + defer r.state.Caches.GTS.FollowRequest.Invalidate("URI", uri) // Finally delete followreq from DB. _, err = r.db.NewDelete(). @@ -352,22 +395,20 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun defer func() { // Invalidate all account's incoming / outoing follow requests on return. - r.state.Caches.GTS.FollowRequest().Invalidate("AccountID", accountID) - r.state.Caches.GTS.FollowRequest().Invalidate("TargetAccountID", accountID) + r.state.Caches.GTS.FollowRequest.Invalidate("AccountID", accountID) + r.state.Caches.GTS.FollowRequest.Invalidate("TargetAccountID", accountID) }() // Load all followreqs into cache, this *really* isn't // great but it is the only way we can ensure we invalidate // all related caches correctly (e.g. visibility). - for _, id := range followReqIDs { - _, err := r.GetFollowRequestByID(ctx, id) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return err - } + _, err := r.GetAccountFollowRequests(ctx, accountID, nil) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return err } // Finally delete all from DB. - _, err := r.db.NewDelete(). + _, err = r.db.NewDelete(). Table("follow_requests"). Where("? IN (?)", bun.Ident("id"), bun.In(followReqIDs)). Exec(ctx) diff --git a/internal/db/bundb/relationship_note.go b/internal/db/bundb/relationship_note.go index f7d15f8b7..126ea0cd1 100644 --- a/internal/db/bundb/relationship_note.go +++ b/internal/db/bundb/relationship_note.go @@ -30,7 +30,7 @@ import ( func (r *relationshipDB) GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error) { return r.getNote( ctx, - "AccountID.TargetAccountID", + "AccountID,TargetAccountID", func(note *gtsmodel.AccountNote) error { return r.db.NewSelect().Model(note). Where("? = ?", bun.Ident("account_id"), sourceAccountID). @@ -44,7 +44,7 @@ func (r *relationshipDB) GetNote(ctx context.Context, sourceAccountID string, ta func (r *relationshipDB) getNote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.AccountNote) error, keyParts ...any) (*gtsmodel.AccountNote, error) { // Fetch note from cache with loader callback - note, err := r.state.Caches.GTS.AccountNote().Load(lookup, func() (*gtsmodel.AccountNote, error) { + note, err := r.state.Caches.GTS.AccountNote.LoadOne(lookup, func() (*gtsmodel.AccountNote, error) { var note gtsmodel.AccountNote // Not cached! Perform database query @@ -105,7 +105,7 @@ func (r *relationshipDB) PopulateNote(ctx context.Context, note *gtsmodel.Accoun func (r *relationshipDB) PutNote(ctx context.Context, note *gtsmodel.AccountNote) error { note.UpdatedAt = time.Now() - return r.state.Caches.GTS.AccountNote().Store(note, func() error { + return r.state.Caches.GTS.AccountNote.Store(note, func() error { _, err := r.db. NewInsert(). Model(note). diff --git a/internal/db/bundb/report.go b/internal/db/bundb/report.go index 9e4ba5b29..5b0ae17f3 100644 --- a/internal/db/bundb/report.go +++ b/internal/db/bundb/report.go @@ -120,7 +120,7 @@ func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID str func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Report) error, keyParts ...any) (*gtsmodel.Report, error) { // Fetch report from database cache with loader callback - report, err := r.state.Caches.GTS.Report().Load(lookup, func() (*gtsmodel.Report, error) { + report, err := r.state.Caches.GTS.Report.LoadOne(lookup, func() (*gtsmodel.Report, error) { var report gtsmodel.Report // Not cached! Perform database query @@ -215,7 +215,7 @@ func (r *reportDB) PopulateReport(ctx context.Context, report *gtsmodel.Report) } func (r *reportDB) PutReport(ctx context.Context, report *gtsmodel.Report) error { - return r.state.Caches.GTS.Report().Store(report, func() error { + return r.state.Caches.GTS.Report.Store(report, func() error { _, err := r.db.NewInsert().Model(report).Exec(ctx) return err }) @@ -237,12 +237,12 @@ func (r *reportDB) UpdateReport(ctx context.Context, report *gtsmodel.Report, co return nil, err } - r.state.Caches.GTS.Report().Invalidate("ID", report.ID) + r.state.Caches.GTS.Report.Invalidate("ID", report.ID) return report, nil } func (r *reportDB) DeleteReportByID(ctx context.Context, id string) error { - defer r.state.Caches.GTS.Report().Invalidate("ID", id) + defer r.state.Caches.GTS.Report.Invalidate("ID", id) // Load status into cache before attempting a delete, // as we need it cached in order to trigger the invalidate diff --git a/internal/db/bundb/rule.go b/internal/db/bundb/rule.go index 79825923b..ebfa89d15 100644 --- a/internal/db/bundb/rule.go +++ b/internal/db/bundb/rule.go @@ -125,7 +125,7 @@ func (r *ruleDB) PutRule(ctx context.Context, rule *gtsmodel.Rule) error { } // invalidate cached local instance response, so it gets updated with the new rules - r.state.Caches.GTS.Instance().Invalidate("Domain", config.GetHost()) + r.state.Caches.GTS.Instance.Invalidate("Domain", config.GetHost()) return nil } @@ -143,7 +143,7 @@ func (r *ruleDB) UpdateRule(ctx context.Context, rule *gtsmodel.Rule) (*gtsmodel } // invalidate cached local instance response, so it gets updated with the new rules - r.state.Caches.GTS.Instance().Invalidate("Domain", config.GetHost()) + r.state.Caches.GTS.Instance.Invalidate("Domain", config.GetHost()) return rule, nil } diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index da252c7f7..07a09050a 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -20,6 +20,7 @@ package bundb import ( "context" "errors" + "slices" "time" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -28,6 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -48,20 +50,62 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat } func (s *statusDB) GetStatusesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Status, error) { - statuses := make([]*gtsmodel.Status, 0, len(ids)) + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) - for _, id := range ids { - // Attempt to fetch status from DB. - status, err := s.GetStatusByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting status %q: %v", id, err) - continue - } + // Load all status IDs via cache loader callbacks. + statuses, err := s.state.Caches.GTS.Status.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, + + // Uncached statuses loader function. + func() ([]*gtsmodel.Status, error) { + // Preallocate expected length of uncached statuses. + statuses := make([]*gtsmodel.Status, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) status IDs. + if err := s.db.NewSelect(). + Model(&statuses). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return statuses, nil + }, + ) + if err != nil { + return nil, err + } + + // Reorder the statuses by their + // IDs to ensure in correct order. + getID := func(s *gtsmodel.Status) string { return s.ID } + util.OrderBy(statuses, ids, getID) - // Append status to return slice. - statuses = append(statuses, status) + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return statuses, nil } + // Populate all loaded statuses, removing those we fail to + // populate (removes needing so many nil checks everywhere). + statuses = slices.DeleteFunc(statuses, func(status *gtsmodel.Status) bool { + if err := s.PopulateStatus(ctx, status); err != nil { + log.Errorf(ctx, "error populating status %s: %v", status.ID, err) + return true + } + return false + }) + return statuses, nil } @@ -101,7 +145,7 @@ func (s *statusDB) GetStatusByPollID(ctx context.Context, pollID string) (*gtsmo func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccountID string) (*gtsmodel.Status, error) { return s.getStatus( ctx, - "BoostOfID.AccountID", + "BoostOfID,AccountID", func(status *gtsmodel.Status) error { return s.db.NewSelect().Model(status). Where("status.boost_of_id = ?", boostOfID). @@ -120,7 +164,7 @@ func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccou func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, error) { // Fetch status from database cache with loader callback - status, err := s.state.Caches.GTS.Status().Load(lookup, func() (*gtsmodel.Status, error) { + status, err := s.state.Caches.GTS.Status.LoadOne(lookup, func() (*gtsmodel.Status, error) { var status gtsmodel.Status // Not cached! Perform database query. @@ -282,7 +326,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) } func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error { - return s.state.Caches.GTS.Status().Store(status, func() error { + return s.state.Caches.GTS.Status.Store(status, func() error { // It is safe to run this database transaction within cache.Store // as the cache does not attempt a mutex lock until AFTER hook. // @@ -366,7 +410,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co columns = append(columns, "updated_at") } - return s.state.Caches.GTS.Status().Store(status, func() error { + return s.state.Caches.GTS.Status.Store(status, func() error { // It is safe to run this database transaction within cache.Store // as the cache does not attempt a mutex lock until AFTER hook. // @@ -463,7 +507,7 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error { } // On return ensure status invalidated from cache. - defer s.state.Caches.GTS.Status().Invalidate("ID", id) + defer s.state.Caches.GTS.Status.Invalidate("ID", id) return s.db.RunInTx(ctx, func(tx Tx) error { // delete links between this status and any emojis it uses @@ -585,7 +629,7 @@ func (s *statusDB) CountStatusReplies(ctx context.Context, statusID string) (int } func (s *statusDB) getStatusReplyIDs(ctx context.Context, statusID string) ([]string, error) { - return s.state.Caches.GTS.InReplyToIDs().Load(statusID, func() ([]string, error) { + return s.state.Caches.GTS.InReplyToIDs.Load(statusID, func() ([]string, error) { var statusIDs []string // Status reply IDs not in cache, perform DB query! @@ -629,7 +673,7 @@ func (s *statusDB) CountStatusBoosts(ctx context.Context, statusID string) (int, } func (s *statusDB) getStatusBoostIDs(ctx context.Context, statusID string) ([]string, error) { - return s.state.Caches.GTS.BoostOfIDs().Load(statusID, func() ([]string, error) { + return s.state.Caches.GTS.BoostOfIDs.Load(statusID, func() ([]string, error) { var statusIDs []string // Status boost IDs not in cache, perform DB query! diff --git a/internal/db/bundb/statusfave.go b/internal/db/bundb/statusfave.go index 73ac62fe7..e0f018b68 100644 --- a/internal/db/bundb/statusfave.go +++ b/internal/db/bundb/statusfave.go @@ -22,6 +22,7 @@ import ( "database/sql" "errors" "fmt" + "slices" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" @@ -29,6 +30,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -40,7 +42,7 @@ type statusFaveDB struct { func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, error) { return s.getStatusFave( ctx, - "AccountID.StatusID", + "AccountID,StatusID", func(fave *gtsmodel.StatusFave) error { return s.db. NewSelect(). @@ -77,7 +79,7 @@ func (s *statusFaveDB) GetStatusFaveByID(ctx context.Context, id string) (*gtsmo func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery func(*gtsmodel.StatusFave) error, keyParts ...any) (*gtsmodel.StatusFave, error) { // Fetch status fave from database cache with loader callback - fave, err := s.state.Caches.GTS.StatusFave().Load(lookup, func() (*gtsmodel.StatusFave, error) { + fave, err := s.state.Caches.GTS.StatusFave.LoadOne(lookup, func() (*gtsmodel.StatusFave, error) { var fave gtsmodel.StatusFave // Not cached! Perform database query. @@ -111,19 +113,62 @@ func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]* return nil, err } - // Preallocate a slice of expected status fave capacity. - faves := make([]*gtsmodel.StatusFave, 0, len(faveIDs)) + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(faveIDs)) - for _, id := range faveIDs { - // Fetch status fave model for each ID. - fave, err := s.GetStatusFaveByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting status fave %q: %v", id, err) - continue - } - faves = append(faves, fave) + // Load all fave IDs via cache loader callbacks. + faves, err := s.state.Caches.GTS.StatusFave.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range faveIDs { + if !load(id) { + uncached = append(uncached, id) + } + } + }, + + // Uncached status faves loader function. + func() ([]*gtsmodel.StatusFave, error) { + // Preallocate expected length of uncached faves. + faves := make([]*gtsmodel.StatusFave, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) fave IDs. + if err := s.db.NewSelect(). + Model(&faves). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return faves, nil + }, + ) + if err != nil { + return nil, err } + // Reorder the statuses by their + // IDs to ensure in correct order. + getID := func(f *gtsmodel.StatusFave) string { return f.ID } + util.OrderBy(faves, faveIDs, getID) + + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return faves, nil + } + + // Populate all loaded faves, removing those we fail to + // populate (removes needing so many nil checks everywhere). + faves = slices.DeleteFunc(faves, func(fave *gtsmodel.StatusFave) bool { + if err := s.PopulateStatusFave(ctx, fave); err != nil { + log.Errorf(ctx, "error populating fave %s: %v", fave.ID, err) + return true + } + return false + }) + return faves, nil } @@ -141,7 +186,7 @@ func (s *statusFaveDB) CountStatusFaves(ctx context.Context, statusID string) (i } func (s *statusFaveDB) getStatusFaveIDs(ctx context.Context, statusID string) ([]string, error) { - return s.state.Caches.GTS.StatusFaveIDs().Load(statusID, func() ([]string, error) { + return s.state.Caches.GTS.StatusFaveIDs.Load(statusID, func() ([]string, error) { var faveIDs []string // Status fave IDs not in cache, perform DB query! @@ -201,7 +246,7 @@ func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmo } func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) error { - return s.state.Caches.GTS.StatusFave().Store(fave, func() error { + return s.state.Caches.GTS.StatusFave.Store(fave, func() error { _, err := s.db. NewInsert(). Model(fave). @@ -230,10 +275,10 @@ func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) erro if statusID != "" { // Invalidate any cached status faves for this status. - s.state.Caches.GTS.StatusFave().Invalidate("ID", id) + s.state.Caches.GTS.StatusFave.Invalidate("ID", id) // Invalidate any cached status fave IDs for this status. - s.state.Caches.GTS.StatusFaveIDs().Invalidate(statusID) + s.state.Caches.GTS.StatusFaveIDs.Invalidate(statusID) } return nil @@ -270,17 +315,15 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st return err } - // Collate (deduplicating) status IDs. - statusIDs = collate(func(i int) string { - return statusIDs[i] - }, len(statusIDs)) + // Deduplicate determined status IDs. + statusIDs = util.Deduplicate(statusIDs) for _, id := range statusIDs { // Invalidate any cached status faves for this status. - s.state.Caches.GTS.StatusFave().Invalidate("ID", id) + s.state.Caches.GTS.StatusFave.Invalidate("ID", id) // Invalidate any cached status fave IDs for this status. - s.state.Caches.GTS.StatusFaveIDs().Invalidate(id) + s.state.Caches.GTS.StatusFaveIDs.Invalidate(id) } return nil @@ -296,10 +339,10 @@ func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID } // Invalidate any cached status faves for this status. - s.state.Caches.GTS.StatusFave().Invalidate("ID", statusID) + s.state.Caches.GTS.StatusFave.Invalidate("ID", statusID) // Invalidate any cached status fave IDs for this status. - s.state.Caches.GTS.StatusFaveIDs().Invalidate(statusID) + s.state.Caches.GTS.StatusFaveIDs.Invalidate(statusID) return nil } diff --git a/internal/db/bundb/tag.go b/internal/db/bundb/tag.go index fac621f0a..66ee8cb3a 100644 --- a/internal/db/bundb/tag.go +++ b/internal/db/bundb/tag.go @@ -22,21 +22,21 @@ import ( "strings" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) type tagDB struct { - conn *DB + db *DB state *state.State } -func (m *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) { - return m.state.Caches.GTS.Tag().Load("ID", func() (*gtsmodel.Tag, error) { +func (t *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) { + return t.state.Caches.GTS.Tag.LoadOne("ID", func() (*gtsmodel.Tag, error) { var tag gtsmodel.Tag - q := m.conn. + q := t.db. NewSelect(). Model(&tag). Where("? = ?", bun.Ident("tag.id"), id) @@ -49,15 +49,15 @@ func (m *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) { }, id) } -func (m *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, error) { +func (t *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, error) { // Normalize 'name' string. name = strings.TrimSpace(name) name = strings.ToLower(name) - return m.state.Caches.GTS.Tag().Load("Name", func() (*gtsmodel.Tag, error) { + return t.state.Caches.GTS.Tag.LoadOne("Name", func() (*gtsmodel.Tag, error) { var tag gtsmodel.Tag - q := m.conn. + q := t.db. NewSelect(). Model(&tag). Where("? = ?", bun.Ident("tag.name"), name) @@ -70,25 +70,52 @@ func (m *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, e }, name) } -func (m *tagDB) GetTags(ctx context.Context, ids []string) ([]*gtsmodel.Tag, error) { - tags := make([]*gtsmodel.Tag, 0, len(ids)) - - for _, id := range ids { - // Attempt fetch from DB - tag, err := m.GetTag(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting tag %q: %v", id, err) - continue - } - - // Append tag - tags = append(tags, tag) +func (t *tagDB) GetTags(ctx context.Context, ids []string) ([]*gtsmodel.Tag, error) { + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) + + // Load all tag IDs via cache loader callbacks. + tags, err := t.state.Caches.GTS.Tag.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, + + // Uncached tag loader function. + func() ([]*gtsmodel.Tag, error) { + // Preallocate expected length of uncached tags. + tags := make([]*gtsmodel.Tag, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := t.db.NewSelect(). + Model(&tags). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return tags, nil + }, + ) + if err != nil { + return nil, err } + // Reorder the tags by their + // IDs to ensure in correct order. + getID := func(t *gtsmodel.Tag) string { return t.ID } + util.OrderBy(tags, ids, getID) + return tags, nil } -func (m *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error { +func (t *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error { // Normalize 'name' string before it enters // the db, without changing tag we were given. // @@ -101,8 +128,8 @@ func (m *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error { t2.Name = strings.ToLower(t2.Name) // Insert the copy. - if err := m.state.Caches.GTS.Tag().Store(t2, func() error { - _, err := m.conn.NewInsert().Model(t2).Exec(ctx) + if err := t.state.Caches.GTS.Tag.Store(t2, func() error { + _, err := t.db.NewInsert().Model(t2).Exec(ctx) return err }); err != nil { return err // err already processed diff --git a/internal/db/bundb/thread.go b/internal/db/bundb/thread.go index e6d6154d4..34c5f783a 100644 --- a/internal/db/bundb/thread.go +++ b/internal/db/bundb/thread.go @@ -42,7 +42,7 @@ func (t *threadDB) PutThread(ctx context.Context, thread *gtsmodel.Thread) error } func (t *threadDB) GetThreadMute(ctx context.Context, id string) (*gtsmodel.ThreadMute, error) { - return t.state.Caches.GTS.ThreadMute().Load("ID", func() (*gtsmodel.ThreadMute, error) { + return t.state.Caches.GTS.ThreadMute.LoadOne("ID", func() (*gtsmodel.ThreadMute, error) { var threadMute gtsmodel.ThreadMute q := t.db. @@ -63,7 +63,7 @@ func (t *threadDB) GetThreadMutedByAccount( threadID string, accountID string, ) (*gtsmodel.ThreadMute, error) { - return t.state.Caches.GTS.ThreadMute().Load("ThreadID.AccountID", func() (*gtsmodel.ThreadMute, error) { + return t.state.Caches.GTS.ThreadMute.LoadOne("ThreadID,AccountID", func() (*gtsmodel.ThreadMute, error) { var threadMute gtsmodel.ThreadMute q := t.db. @@ -98,7 +98,7 @@ func (t *threadDB) IsThreadMutedByAccount( } func (t *threadDB) PutThreadMute(ctx context.Context, threadMute *gtsmodel.ThreadMute) error { - return t.state.Caches.GTS.ThreadMute().Store(threadMute, func() error { + return t.state.Caches.GTS.ThreadMute.Store(threadMute, func() error { _, err := t.db.NewInsert().Model(threadMute).Exec(ctx) return err }) @@ -112,6 +112,6 @@ func (t *threadDB) DeleteThreadMute(ctx context.Context, id string) error { return err } - t.state.Caches.GTS.ThreadMute().Invalidate("ID", id) + t.state.Caches.GTS.ThreadMute.Invalidate("ID", id) return nil } diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go index 4af17fb7f..f2ba2a9d1 100644 --- a/internal/db/bundb/timeline.go +++ b/internal/db/bundb/timeline.go @@ -29,7 +29,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" - "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/uptrace/bun" ) @@ -155,20 +154,8 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI } } - statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) - for _, id := range statusIDs { - // Fetch status from db for ID - status, err := t.state.DB.GetStatusByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error fetching status %q: %v", id, err) - continue - } - - // Append status to slice - statuses = append(statuses, status) - } - - return statuses, nil + // Return status IDs loaded from cache + db. + return t.state.DB.GetStatusesByIDs(ctx, statusIDs) } func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) { @@ -256,20 +243,8 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI } } - statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) - for _, id := range statusIDs { - // Fetch status from db for ID - status, err := t.state.DB.GetStatusByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error fetching status %q: %v", id, err) - continue - } - - // Append status to slice - statuses = append(statuses, status) - } - - return statuses, nil + // Return status IDs loaded from cache + db. + return t.state.DB.GetStatusesByIDs(ctx, statusIDs) } // TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20! @@ -323,18 +298,15 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max } }) - statuses := make([]*gtsmodel.Status, 0, len(faves)) - - for _, fave := range faves { - // Fetch status from db for corresponding favourite - status, err := t.state.DB.GetStatusByID(ctx, fave.StatusID) - if err != nil { - log.Errorf(ctx, "error fetching status for fave %q: %v", fave.ID, err) - continue - } + // Convert fave IDs to status IDs. + statusIDs := make([]string, len(faves)) + for i, fave := range faves { + statusIDs[i] = fave.StatusID + } - // Append status to slice - statuses = append(statuses, status) + statuses, err := t.state.DB.GetStatusesByIDs(ctx, statusIDs) + if err != nil { + return nil, "", "", err } nextMaxID := faves[len(faves)-1].ID @@ -453,20 +425,8 @@ func (t *timelineDB) GetListTimeline( } } - statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) - for _, id := range statusIDs { - // Fetch status from db for ID - status, err := t.state.DB.GetStatusByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error fetching status %q: %v", id, err) - continue - } - - // Append status to slice - statuses = append(statuses, status) - } - - return statuses, nil + // Return status IDs loaded from cache + db. + return t.state.DB.GetStatusesByIDs(ctx, statusIDs) } func (t *timelineDB) GetTagTimeline( @@ -561,18 +521,6 @@ func (t *timelineDB) GetTagTimeline( } } - statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) - for _, id := range statusIDs { - // Fetch status from db for ID - status, err := t.state.DB.GetStatusByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error fetching status %q: %v", id, err) - continue - } - - // Append status to slice - statuses = append(statuses, status) - } - - return statuses, nil + // Return status IDs loaded from cache + db. + return t.state.DB.GetStatusesByIDs(ctx, statusIDs) } diff --git a/internal/db/bundb/tombstone.go b/internal/db/bundb/tombstone.go index f9882d1c6..c0e439720 100644 --- a/internal/db/bundb/tombstone.go +++ b/internal/db/bundb/tombstone.go @@ -32,7 +32,7 @@ type tombstoneDB struct { } func (t *tombstoneDB) GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, error) { - return t.state.Caches.GTS.Tombstone().Load("URI", func() (*gtsmodel.Tombstone, error) { + return t.state.Caches.GTS.Tombstone.LoadOne("URI", func() (*gtsmodel.Tombstone, error) { var tomb gtsmodel.Tombstone q := t.db. @@ -57,7 +57,7 @@ func (t *tombstoneDB) TombstoneExistsWithURI(ctx context.Context, uri string) (b } func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) error { - return t.state.Caches.GTS.Tombstone().Store(tombstone, func() error { + return t.state.Caches.GTS.Tombstone.Store(tombstone, func() error { _, err := t.db. NewInsert(). Model(tombstone). @@ -67,7 +67,7 @@ func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tomb } func (t *tombstoneDB) DeleteTombstone(ctx context.Context, id string) error { - defer t.state.Caches.GTS.Tombstone().Invalidate("ID", id) + defer t.state.Caches.GTS.Tombstone.Invalidate("ID", id) // Delete tombstone from DB. _, err := t.db.NewDelete(). diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go index 46b3c568f..a6fa142f2 100644 --- a/internal/db/bundb/user.go +++ b/internal/db/bundb/user.go @@ -116,7 +116,7 @@ func (u *userDB) GetUserByConfirmationToken(ctx context.Context, token string) ( func (u *userDB) getUser(ctx context.Context, lookup string, dbQuery func(*gtsmodel.User) error, keyParts ...any) (*gtsmodel.User, error) { // Fetch user from database cache with loader callback. - user, err := u.state.Caches.GTS.User().Load(lookup, func() (*gtsmodel.User, error) { + user, err := u.state.Caches.GTS.User.LoadOne(lookup, func() (*gtsmodel.User, error) { var user gtsmodel.User // Not cached! perform database query. @@ -179,7 +179,7 @@ func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) { } func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) error { - return u.state.Caches.GTS.User().Store(user, func() error { + return u.state.Caches.GTS.User.Store(user, func() error { _, err := u.db. NewInsert(). Model(user). @@ -197,7 +197,7 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns .. columns = append(columns, "updated_at") } - return u.state.Caches.GTS.User().Store(user, func() error { + return u.state.Caches.GTS.User.Store(user, func() error { _, err := u.db. NewUpdate(). Model(user). @@ -209,7 +209,7 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns .. } func (u *userDB) DeleteUserByID(ctx context.Context, userID string) error { - defer u.state.Caches.GTS.User().Invalidate("ID", userID) + defer u.state.Caches.GTS.User.Invalidate("ID", userID) // Load user into cache before attempting a delete, // as we need it cached in order to trigger the invalidate diff --git a/internal/db/list.go b/internal/db/list.go index 91a540486..16a0207de 100644 --- a/internal/db/list.go +++ b/internal/db/list.go @@ -27,6 +27,9 @@ type List interface { // GetListByID gets one list with the given id. GetListByID(ctx context.Context, id string) (*gtsmodel.List, error) + // GetListsByIDs fetches all lists with the provided IDs. + GetListsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.List, error) + // GetListsForAccountID gets all lists owned by the given accountID. GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) @@ -46,6 +49,9 @@ type List interface { // GetListEntryByID gets one list entry with the given ID. GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error) + // GetListEntriesyIDs fetches all list entries with the provided IDs. + GetListEntriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.ListEntry, error) + // GetListEntries gets list entries from the given listID, using the given parameters. GetListEntries(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.ListEntry, error) diff --git a/internal/db/notification.go b/internal/db/notification.go index ab8b5cc6d..9ff459b9c 100644 --- a/internal/db/notification.go +++ b/internal/db/notification.go @@ -33,6 +33,9 @@ type Notification interface { // GetNotification returns one notification according to its id. GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error) + // GetNotificationsByIDs returns a slice of notifications of the the provided IDs. + GetNotificationsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Notification, error) + // GetNotification gets one notification according to the provided parameters, if it exists. // Since not all notifications are about a status, statusID can be an empty string. GetNotification(ctx context.Context, notificationType gtsmodel.NotificationType, targetAccountID string, originAccountID string, statusID string) (*gtsmodel.Notification, error) diff --git a/internal/federation/dereferencing/announce.go b/internal/federation/dereferencing/announce.go index 8e880dad5..8d082105b 100644 --- a/internal/federation/dereferencing/announce.go +++ b/internal/federation/dereferencing/announce.go @@ -107,19 +107,21 @@ func (d *Dereferencer) EnrichAnnounce( // All good baby. case errors.Is(err, db.ErrAlreadyExists): + uri := boost.URI + // DATA RACE! We likely lost out to another goroutine // in a call to db.Put(Status). Look again in DB by URI. - boost, err = d.state.DB.GetStatusByURI(ctx, boost.URI) + boost, err = d.state.DB.GetStatusByURI(ctx, uri) if err != nil { - err = gtserror.Newf( + return nil, gtserror.Newf( "error getting boost wrapper status %s from database after race: %w", - boost.URI, err, + uri, err, ) } default: // Proper database error. - err = gtserror.Newf("db error inserting status: %w", err) + return nil, gtserror.Newf("db error inserting status: %w", err) } return boost, err diff --git a/internal/federation/federatingdb/announce_test.go b/internal/federation/federatingdb/announce_test.go index d8de2e49c..8dd5ce9da 100644 --- a/internal/federation/federatingdb/announce_test.go +++ b/internal/federation/federatingdb/announce_test.go @@ -79,9 +79,7 @@ func (suite *AnnounceTestSuite) TestAnnounceTwice() { // Insert the boost-of status into the // DB cache to emulate processor handling boost.ID, _ = id.NewULIDFromTime(boost.CreatedAt) - suite.state.Caches.GTS.Status().Store(boost, func() error { - return nil - }) + suite.state.Caches.GTS.Status.Put(boost) // only the URI will be set for the boosted status // because it still needs to be dereferenced diff --git a/internal/media/refetch.go b/internal/media/refetch.go index 03f0fbf34..8e5188178 100644 --- a/internal/media/refetch.go +++ b/internal/media/refetch.go @@ -55,7 +55,6 @@ func (m *Manager) RefetchEmojis(ctx context.Context, domain string, dereferenceM emojis, err := m.state.DB.GetEmojisBy(ctx, domain, false, true, "", maxShortcodeDomain, "", 20) if err != nil { if !errors.Is(err, db.ErrNoEntries) { - // an actual error has occurred log.Errorf(ctx, "error fetching emojis from database: %s", err) } break diff --git a/internal/processing/status/create.go b/internal/processing/status/create.go index fbe1fbd64..b450e4bdd 100644 --- a/internal/processing/status/create.go +++ b/internal/processing/status/create.go @@ -229,6 +229,7 @@ func (p *Processor) processMediaIDs(ctx context.Context, form *apimodel.Advanced attachments := []*gtsmodel.MediaAttachment{} attachmentIDs := []string{} + for _, mediaID := range form.MediaIDs { attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaID) if err != nil && !errors.Is(err, db.ErrNoEntries) { diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go index e2dbc4829..71b065719 100644 --- a/internal/transport/deliver.go +++ b/internal/transport/deliver.go @@ -82,7 +82,7 @@ func (t *transport) BatchDeliver(ctx context.Context, b []byte, recipients []*ur // Attempt to deliver data to recipient. if err := t.deliver(ctx, b, to); err != nil { mutex.Lock() // safely append err to accumulator. - errs.Appendf("error delivering to %s: %v", to, err) + errs.Appendf("error delivering to %s: %w", to, err) mutex.Unlock() } } diff --git a/internal/transport/finger.go b/internal/transport/finger.go index 49648c7e9..385af5e1c 100644 --- a/internal/transport/finger.go +++ b/internal/transport/finger.go @@ -36,7 +36,8 @@ import ( func (t *transport) webfingerURLFor(targetDomain string) (string, bool) { url := "https://" + targetDomain + "/.well-known/webfinger" - wc := t.controller.state.Caches.GTS.Webfinger() + wc := t.controller.state.Caches.GTS.Webfinger + // We're doing the manual locking/unlocking here to be able to // safely call Cache.Get instead of Get, as the latter updates the // item expiry which we don't want to do here @@ -95,7 +96,7 @@ func (t *transport) Finger(ctx context.Context, targetUsername string, targetDom // If we got a response we consider successful on a cached URL, i.e one set // by us later on when a host-meta based webfinger request succeeded, set it // again here to renew the TTL - t.controller.state.Caches.GTS.Webfinger().Set(targetDomain, url) + t.controller.state.Caches.GTS.Webfinger.Set(targetDomain, url) } if rsp.StatusCode == http.StatusGone { return nil, fmt.Errorf("account has been deleted/is gone") @@ -151,7 +152,7 @@ func (t *transport) Finger(ctx context.Context, targetUsername string, targetDom // we asked for is gone. This means the endpoint itself is valid and we should // cache it for future queries to the same domain if rsp.StatusCode == http.StatusGone { - t.controller.state.Caches.GTS.Webfinger().Set(targetDomain, host) + t.controller.state.Caches.GTS.Webfinger.Set(targetDomain, host) return nil, fmt.Errorf("account has been deleted/is gone") } // We've reached the end of the line here, both the original request @@ -162,7 +163,7 @@ func (t *transport) Finger(ctx context.Context, targetUsername string, targetDom // Set the URL in cache here, since host-meta told us this should be the // valid one, it's different from the default and our request to it did // not fail in any manner - t.controller.state.Caches.GTS.Webfinger().Set(targetDomain, host) + t.controller.state.Caches.GTS.Webfinger.Set(targetDomain, host) return io.ReadAll(rsp.Body) } diff --git a/internal/transport/finger_test.go b/internal/transport/finger_test.go index c012af62c..380a4aff9 100644 --- a/internal/transport/finger_test.go +++ b/internal/transport/finger_test.go @@ -31,7 +31,7 @@ type FingerTestSuite struct { } func (suite *FingerTestSuite) TestFinger() { - wc := suite.state.Caches.GTS.Webfinger() + wc := suite.state.Caches.GTS.Webfinger suite.Equal(0, wc.Len(), "expect webfinger cache to be empty") _, err := suite.transport.Finger(context.TODO(), "brand_new_person", "unknown-instance.com") @@ -43,7 +43,7 @@ func (suite *FingerTestSuite) TestFinger() { } func (suite *FingerTestSuite) TestFingerWithHostMeta() { - wc := suite.state.Caches.GTS.Webfinger() + wc := suite.state.Caches.GTS.Webfinger suite.Equal(0, wc.Len(), "expect webfinger cache to be empty") _, err := suite.transport.Finger(context.TODO(), "someone", "misconfigured-instance.com") @@ -60,7 +60,7 @@ func (suite *FingerTestSuite) TestFingerWithHostMetaCacheStrategy() { suite.T().Skip("this test is flaky on CI for as of yet unknown reasons") } - wc := suite.state.Caches.GTS.Webfinger() + wc := suite.state.Caches.GTS.Webfinger // Reset the sweep frequency so nothing interferes with the test wc.Stop() diff --git a/internal/typeutils/astointernal.go b/internal/typeutils/astointernal.go index 5bcb35d75..ec17527c4 100644 --- a/internal/typeutils/astointernal.go +++ b/internal/typeutils/astointernal.go @@ -794,7 +794,6 @@ func (c *Converter) getASAttributedToAccount(ctx context.Context, id string, wit } return account, nil - } func (c *Converter) getASObjectAccount(ctx context.Context, id string, with ap.WithObject) (*gtsmodel.Account, error) { diff --git a/internal/typeutils/internaltoas.go b/internal/typeutils/internaltoas.go index c88fd2e11..dc25babaa 100644 --- a/internal/typeutils/internaltoas.go +++ b/internal/typeutils/internaltoas.go @@ -491,7 +491,7 @@ func (c *Converter) StatusToAS(ctx context.Context, s *gtsmodel.Status) (ap.Stat // tag -- mentions mentions := s.Mentions - if len(s.MentionIDs) > len(mentions) { + if len(s.MentionIDs) != len(mentions) { mentions, err = c.state.DB.GetMentions(ctx, s.MentionIDs) if err != nil { return nil, gtserror.Newf("error getting mentions: %w", err) @@ -507,14 +507,10 @@ func (c *Converter) StatusToAS(ctx context.Context, s *gtsmodel.Status) (ap.Stat // tag -- emojis emojis := s.Emojis - if len(s.EmojiIDs) > len(emojis) { - emojis = []*gtsmodel.Emoji{} - for _, emojiID := range s.EmojiIDs { - emoji, err := c.state.DB.GetEmojiByID(ctx, emojiID) - if err != nil { - return nil, gtserror.Newf("error getting emoji %s from database: %w", emojiID, err) - } - emojis = append(emojis, emoji) + if len(s.EmojiIDs) != len(emojis) { + emojis, err = c.state.DB.GetEmojisByIDs(ctx, s.EmojiIDs) + if err != nil { + return nil, gtserror.Newf("error getting emojis from database: %w", err) } } for _, emoji := range emojis { @@ -527,7 +523,7 @@ func (c *Converter) StatusToAS(ctx context.Context, s *gtsmodel.Status) (ap.Stat // tag -- hashtags hashtags := s.Tags - if len(s.TagIDs) > len(hashtags) { + if len(s.TagIDs) != len(hashtags) { hashtags, err = c.state.DB.GetTags(ctx, s.TagIDs) if err != nil { return nil, gtserror.Newf("error getting tags: %w", err) @@ -623,14 +619,10 @@ func (c *Converter) StatusToAS(ctx context.Context, s *gtsmodel.Status) (ap.Stat // attachments attachmentProp := streams.NewActivityStreamsAttachmentProperty() attachments := s.Attachments - if len(s.AttachmentIDs) > len(attachments) { - attachments = []*gtsmodel.MediaAttachment{} - for _, attachmentID := range s.AttachmentIDs { - attachment, err := c.state.DB.GetAttachmentByID(ctx, attachmentID) - if err != nil { - return nil, gtserror.Newf("error getting attachment %s from database: %w", attachmentID, err) - } - attachments = append(attachments, attachment) + if len(s.AttachmentIDs) != len(attachments) { + attachments, err = c.state.DB.GetAttachmentsByIDs(ctx, s.AttachmentIDs) + if err != nil { + return nil, gtserror.Newf("error getting attachments from database: %w", err) } } for _, a := range attachments { diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go index 75247f411..941b9e866 100644 --- a/internal/typeutils/internaltofrontend.go +++ b/internal/typeutils/internaltofrontend.go @@ -1563,20 +1563,15 @@ func (c *Converter) PollToAPIPoll(ctx context.Context, requester *gtsmodel.Accou func (c *Converter) convertAttachmentsToAPIAttachments(ctx context.Context, attachments []*gtsmodel.MediaAttachment, attachmentIDs []string) ([]*apimodel.Attachment, error) { var errs gtserror.MultiError - if len(attachments) == 0 { + if len(attachments) == 0 && len(attachmentIDs) > 0 { // GTS model attachments were not populated - // Preallocate expected GTS slice - attachments = make([]*gtsmodel.MediaAttachment, 0, len(attachmentIDs)) + var err error // Fetch GTS models for attachment IDs - for _, id := range attachmentIDs { - attachment, err := c.state.DB.GetAttachmentByID(ctx, id) - if err != nil { - errs.Appendf("error fetching attachment %s from database: %v", id, err) - continue - } - attachments = append(attachments, attachment) + attachments, err = c.state.DB.GetAttachmentsByIDs(ctx, attachmentIDs) + if err != nil { + errs.Appendf("error fetching attachments from database: %w", err) } } @@ -1587,7 +1582,7 @@ func (c *Converter) convertAttachmentsToAPIAttachments(ctx context.Context, atta for _, attachment := range attachments { apiAttachment, err := c.AttachmentToAPIAttachment(ctx, attachment) if err != nil { - errs.Appendf("error converting attchment %s to api attachment: %v", attachment.ID, err) + errs.Appendf("error converting attchment %s to api attachment: %w", attachment.ID, err) continue } apiAttachments = append(apiAttachments, &apiAttachment) @@ -1600,20 +1595,15 @@ func (c *Converter) convertAttachmentsToAPIAttachments(ctx context.Context, atta func (c *Converter) convertEmojisToAPIEmojis(ctx context.Context, emojis []*gtsmodel.Emoji, emojiIDs []string) ([]apimodel.Emoji, error) { var errs gtserror.MultiError - if len(emojis) == 0 { + if len(emojis) == 0 && len(emojiIDs) > 0 { // GTS model attachments were not populated - // Preallocate expected GTS slice - emojis = make([]*gtsmodel.Emoji, 0, len(emojiIDs)) + var err error // Fetch GTS models for emoji IDs - for _, id := range emojiIDs { - emoji, err := c.state.DB.GetEmojiByID(ctx, id) - if err != nil { - errs.Appendf("error fetching emoji %s from database: %v", id, err) - continue - } - emojis = append(emojis, emoji) + emojis, err = c.state.DB.GetEmojisByIDs(ctx, emojiIDs) + if err != nil { + errs.Appendf("error fetching emojis from database: %w", err) } } @@ -1624,7 +1614,7 @@ func (c *Converter) convertEmojisToAPIEmojis(ctx context.Context, emojis []*gtsm for _, emoji := range emojis { apiEmoji, err := c.EmojiToAPIEmoji(ctx, emoji) if err != nil { - errs.Appendf("error converting emoji %s to api emoji: %v", emoji.ID, err) + errs.Appendf("error converting emoji %s to api emoji: %w", emoji.ID, err) continue } apiEmojis = append(apiEmojis, apiEmoji) @@ -1637,7 +1627,7 @@ func (c *Converter) convertEmojisToAPIEmojis(ctx context.Context, emojis []*gtsm func (c *Converter) convertMentionsToAPIMentions(ctx context.Context, mentions []*gtsmodel.Mention, mentionIDs []string) ([]apimodel.Mention, error) { var errs gtserror.MultiError - if len(mentions) == 0 { + if len(mentions) == 0 && len(mentionIDs) > 0 { var err error // GTS model mentions were not populated @@ -1645,7 +1635,7 @@ func (c *Converter) convertMentionsToAPIMentions(ctx context.Context, mentions [ // Fetch GTS models for mention IDs mentions, err = c.state.DB.GetMentions(ctx, mentionIDs) if err != nil { - errs.Appendf("error fetching mentions from database: %v", err) + errs.Appendf("error fetching mentions from database: %w", err) } } @@ -1656,7 +1646,7 @@ func (c *Converter) convertMentionsToAPIMentions(ctx context.Context, mentions [ for _, mention := range mentions { apiMention, err := c.MentionToAPIMention(ctx, mention) if err != nil { - errs.Appendf("error converting mention %s to api mention: %v", mention.ID, err) + errs.Appendf("error converting mention %s to api mention: %w", mention.ID, err) continue } apiMentions = append(apiMentions, apiMention) @@ -1669,12 +1659,12 @@ func (c *Converter) convertMentionsToAPIMentions(ctx context.Context, mentions [ func (c *Converter) convertTagsToAPITags(ctx context.Context, tags []*gtsmodel.Tag, tagIDs []string) ([]apimodel.Tag, error) { var errs gtserror.MultiError - if len(tags) == 0 { + if len(tags) == 0 && len(tagIDs) > 0 { var err error tags, err = c.state.DB.GetTags(ctx, tagIDs) if err != nil { - errs.Appendf("error fetching tags from database: %v", err) + errs.Appendf("error fetching tags from database: %w", err) } } @@ -1685,7 +1675,7 @@ func (c *Converter) convertTagsToAPITags(ctx context.Context, tags []*gtsmodel.T for _, tag := range tags { apiTag, err := c.TagToAPITag(ctx, tag, false) if err != nil { - errs.Appendf("error converting tag %s to api tag: %v", tag.ID, err) + errs.Appendf("error converting tag %s to api tag: %w", tag.ID, err) continue } apiTags = append(apiTags, apiTag) diff --git a/internal/util/deduplicate.go b/internal/util/slices.go index 099ec96b5..51d560dbd 100644 --- a/internal/util/deduplicate.go +++ b/internal/util/slices.go @@ -61,3 +61,75 @@ func DeduplicateFunc[T any, C comparable](in []T, key func(v T) C) []T { return deduped } + +// Collate will collect the values of type K from input type []T, +// passing each item to 'get' and deduplicating the end result. +// Compared to Deduplicate() this returns []K, NOT input type []T. +func Collate[T any, K comparable](in []T, get func(T) K) []K { + ks := make([]K, 0, len(in)) + km := make(map[K]struct{}, len(in)) + + for i := 0; i < len(in); i++ { + // Get next k. + k := get(in[i]) + + if _, ok := km[k]; !ok { + // New value, add + // to map + slice. + ks = append(ks, k) + km[k] = struct{}{} + } + } + + return ks +} + +// OrderBy orders a slice of given type by the provided alternative slice of comparable type. +func OrderBy[T any, K comparable](in []T, keys []K, key func(T) K) { + var ( + start int + offset int + ) + + for i := 0; i < len(keys); i++ { + var ( + // key at index. + k = keys[i] + + // sentinel + // idx value. + idx = -1 + ) + + // Look for model with key in slice. + for j := start; j < len(in); j++ { + if key(in[j]) == k { + idx = j + break + } + } + + if idx == -1 { + // model with key + // was not found. + offset++ + continue + } + + // Update + // start + start++ + + // Expected ID index. + exp := i - offset + + if idx == exp { + // Model is in expected + // location, keep going. + continue + } + + // Swap models at current and expected. + in[idx], in[exp] = in[exp], in[idx] + } +} diff --git a/internal/visibility/account.go b/internal/visibility/account.go index 4d42b5973..410daa1ce 100644 --- a/internal/visibility/account.go +++ b/internal/visibility/account.go @@ -39,7 +39,7 @@ func (f *Filter) AccountVisible(ctx context.Context, requester *gtsmodel.Account requesterID = requester.ID } - visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) { + visibility, err := f.state.Caches.Visibility.LoadOne("Type,RequesterID,ItemID", func() (*cache.CachedVisibility, error) { // Visibility not yet cached, perform visibility lookup. visible, err := f.isAccountVisibleTo(ctx, requester, account) if err != nil { diff --git a/internal/visibility/home_timeline.go b/internal/visibility/home_timeline.go index 273ca8457..ab7b83d55 100644 --- a/internal/visibility/home_timeline.go +++ b/internal/visibility/home_timeline.go @@ -42,7 +42,7 @@ func (f *Filter) StatusHomeTimelineable(ctx context.Context, owner *gtsmodel.Acc requesterID = owner.ID } - visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) { + visibility, err := f.state.Caches.Visibility.LoadOne("Type,RequesterID,ItemID", func() (*cache.CachedVisibility, error) { // Visibility not yet cached, perform timeline visibility lookup. visible, err := f.isStatusHomeTimelineable(ctx, owner, status) if err != nil { diff --git a/internal/visibility/public_timeline.go b/internal/visibility/public_timeline.go index 63e802614..b2c05d51f 100644 --- a/internal/visibility/public_timeline.go +++ b/internal/visibility/public_timeline.go @@ -40,7 +40,7 @@ func (f *Filter) StatusPublicTimelineable(ctx context.Context, requester *gtsmod requesterID = requester.ID } - visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) { + visibility, err := f.state.Caches.Visibility.LoadOne("Type,RequesterID,ItemID", func() (*cache.CachedVisibility, error) { // Visibility not yet cached, perform timeline visibility lookup. visible, err := f.isStatusPublicTimelineable(ctx, requester, status) if err != nil { diff --git a/internal/visibility/status.go b/internal/visibility/status.go index 3684bae4f..5e2052ae4 100644 --- a/internal/visibility/status.go +++ b/internal/visibility/status.go @@ -53,7 +53,7 @@ func (f *Filter) StatusVisible(ctx context.Context, requester *gtsmodel.Account, requesterID = requester.ID } - visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) { + visibility, err := f.state.Caches.Visibility.LoadOne("Type,RequesterID,ItemID", func() (*cache.CachedVisibility, error) { // Visibility not yet cached, perform visibility lookup. visible, err := f.isStatusVisible(ctx, requester, status) if err != nil { |