diff options
Diffstat (limited to 'internal/db/bundb/list.go')
-rw-r--r-- | internal/db/bundb/list.go | 203 |
1 files changed, 137 insertions, 66 deletions
diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go index 7a117670a..5f95d3c24 100644 --- a/internal/db/bundb/list.go +++ b/internal/db/bundb/list.go @@ -21,6 +21,7 @@ import ( "context" "errors" "fmt" + "slices" "time" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -29,6 +30,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -56,7 +58,7 @@ func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, er } func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmodel.List) error, keyParts ...any) (*gtsmodel.List, error) { - list, err := l.state.Caches.GTS.List().Load(lookup, func() (*gtsmodel.List, error) { + list, err := l.state.Caches.GTS.List.LoadOne(lookup, func() (*gtsmodel.List, error) { var list gtsmodel.List // Not cached! Perform database query. @@ -100,18 +102,8 @@ func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([] return nil, nil } - // Select each list using its ID to ensure cache used. - lists := make([]*gtsmodel.List, 0, len(listIDs)) - for _, id := range listIDs { - list, err := l.state.DB.GetListByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error fetching list %q: %v", id, err) - continue - } - lists = append(lists, list) - } - - return lists, nil + // Return lists by their IDs. + return l.GetListsByIDs(ctx, listIDs) } func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error { @@ -147,7 +139,7 @@ func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error { } func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error { - return l.state.Caches.GTS.List().Store(list, func() error { + return l.state.Caches.GTS.List.Store(list, func() error { _, err := l.db.NewInsert().Model(list).Exec(ctx) return err }) @@ -162,7 +154,7 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns .. defer func() { // Invalidate all entries for this list ID. - l.state.Caches.GTS.ListEntry().Invalidate("ListID", list.ID) + l.state.Caches.GTS.ListEntry.Invalidate("ListID", list.ID) // Invalidate this entire list's timeline. if err := l.state.Timelines.List.RemoveTimeline(ctx, list.ID); err != nil { @@ -170,7 +162,7 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns .. } }() - return l.state.Caches.GTS.List().Store(list, func() error { + return l.state.Caches.GTS.List.Store(list, func() error { _, err := l.db.NewUpdate(). Model(list). Where("? = ?", bun.Ident("list.id"), list.ID). @@ -198,7 +190,7 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error { defer func() { // Invalidate this list from cache. - l.state.Caches.GTS.List().Invalidate("ID", id) + l.state.Caches.GTS.List.Invalidate("ID", id) // Invalidate this entire list's timeline. if err := l.state.Timelines.List.RemoveTimeline(ctx, id); err != nil { @@ -243,7 +235,7 @@ func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.Lis } func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) { - listEntry, err := l.state.Caches.GTS.ListEntry().Load(lookup, func() (*gtsmodel.ListEntry, error) { + listEntry, err := l.state.Caches.GTS.ListEntry.LoadOne(lookup, func() (*gtsmodel.ListEntry, error) { var listEntry gtsmodel.ListEntry // Not cached! Perform database query. @@ -344,18 +336,128 @@ func (l *listDB) GetListEntries(ctx context.Context, } } - // Select each list entry using its ID to ensure cache used. - listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs)) - for _, id := range entryIDs { - listEntry, err := l.state.DB.GetListEntryByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error fetching list entry %q: %v", id, err) - continue + // Return list entries by their IDs. + return l.GetListEntriesByIDs(ctx, entryIDs) +} + +func (l *listDB) GetListsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.List, error) { + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) + + // Load all list IDs via cache loader callbacks. + lists, err := l.state.Caches.GTS.List.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, + + // Uncached list loader function. + func() ([]*gtsmodel.List, error) { + // Preallocate expected length of uncached lists. + lists := make([]*gtsmodel.List, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := l.db.NewSelect(). + Model(&lists). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return lists, nil + }, + ) + if err != nil { + return nil, err + } + + // Reorder the lists by their + // IDs to ensure in correct order. + getID := func(l *gtsmodel.List) string { return l.ID } + util.OrderBy(lists, ids, getID) + + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return lists, nil + } + + // Populate all loaded lists, removing those we fail to + // populate (removes needing so many nil checks everywhere). + lists = slices.DeleteFunc(lists, func(list *gtsmodel.List) bool { + if err := l.PopulateList(ctx, list); err != nil { + log.Errorf(ctx, "error populating list %s: %v", list.ID, err) + return true } - listEntries = append(listEntries, listEntry) + return false + }) + + return lists, nil +} + +func (l *listDB) GetListEntriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.ListEntry, error) { + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) + + // Load all entry IDs via cache loader callbacks. + entries, err := l.state.Caches.GTS.ListEntry.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, + + // Uncached entry loader function. + func() ([]*gtsmodel.ListEntry, error) { + // Preallocate expected length of uncached entries. + entries := make([]*gtsmodel.ListEntry, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := l.db.NewSelect(). + Model(&entries). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return entries, nil + }, + ) + if err != nil { + return nil, err } - return listEntries, nil + // Reorder the entries by their + // IDs to ensure in correct order. + getID := func(e *gtsmodel.ListEntry) string { return e.ID } + util.OrderBy(entries, ids, getID) + + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return entries, nil + } + + // Populate all loaded entries, removing those we fail to + // populate (removes needing so many nil checks everywhere). + entries = slices.DeleteFunc(entries, func(entry *gtsmodel.ListEntry) bool { + if err := l.PopulateListEntry(ctx, entry); err != nil { + log.Errorf(ctx, "error populating entry %s: %v", entry.ID, err) + return true + } + return false + }) + + return entries, nil } func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) { @@ -376,18 +478,8 @@ func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) return nil, nil } - // Select each list entry using its ID to ensure cache used. - listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs)) - for _, id := range entryIDs { - listEntry, err := l.state.DB.GetListEntryByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error fetching list entry %q: %v", id, err) - continue - } - listEntries = append(listEntries, listEntry) - } - - return listEntries, nil + // Return list entries by their IDs. + return l.GetListEntriesByIDs(ctx, entryIDs) } func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error { @@ -409,10 +501,10 @@ func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.List func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEntry) error { defer func() { - // Collect unique list IDs from the entries. - listIDs := collate(func(i int) string { - return entries[i].ListID - }, len(entries)) + // Collect unique list IDs from the provided entries. + listIDs := util.Collate(entries, func(e *gtsmodel.ListEntry) string { + return e.ListID + }) for _, id := range listIDs { // Invalidate the timeline for the list this entry belongs to. @@ -426,7 +518,7 @@ func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEnt return l.db.RunInTx(ctx, func(tx Tx) error { for _, entry := range entries { entry := entry // rescope - if err := l.state.Caches.GTS.ListEntry().Store(entry, func() error { + if err := l.state.Caches.GTS.ListEntry.Store(entry, func() error { _, err := tx. NewInsert(). Model(entry). @@ -459,7 +551,7 @@ func (l *listDB) DeleteListEntry(ctx context.Context, id string) error { defer func() { // Invalidate this list entry upon delete. - l.state.Caches.GTS.ListEntry().Invalidate("ID", id) + l.state.Caches.GTS.ListEntry.Invalidate("ID", id) // Invalidate the timeline for the list this entry belongs to. if err := l.state.Timelines.List.RemoveTimeline(ctx, entry.ListID); err != nil { @@ -514,24 +606,3 @@ func (l *listDB) ListIncludesAccount(ctx context.Context, listID string, account return exists, err } - -// collate will collect the values of type T from an expected slice of length 'len', -// passing the expected index to each call of 'get' and deduplicating the end result. -func collate[T comparable](get func(int) T, len int) []T { - ts := make([]T, 0, len) - tm := make(map[T]struct{}, len) - - for i := 0; i < len; i++ { - // Get next. - t := get(i) - - if _, ok := tm[t]; !ok { - // New value, add - // to map + slice. - ts = append(ts, t) - tm[t] = struct{}{} - } - } - - return ts -} |