summaryrefslogtreecommitdiff
path: root/internal/db/bundb/list.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/list.go')
-rw-r--r--internal/db/bundb/list.go203
1 files changed, 137 insertions, 66 deletions
diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go
index 7a117670a..5f95d3c24 100644
--- a/internal/db/bundb/list.go
+++ b/internal/db/bundb/list.go
@@ -21,6 +21,7 @@ import (
"context"
"errors"
"fmt"
+ "slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
@@ -29,6 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@@ -56,7 +58,7 @@ func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, er
}
func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmodel.List) error, keyParts ...any) (*gtsmodel.List, error) {
- list, err := l.state.Caches.GTS.List().Load(lookup, func() (*gtsmodel.List, error) {
+ list, err := l.state.Caches.GTS.List.LoadOne(lookup, func() (*gtsmodel.List, error) {
var list gtsmodel.List
// Not cached! Perform database query.
@@ -100,18 +102,8 @@ func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]
return nil, nil
}
- // Select each list using its ID to ensure cache used.
- lists := make([]*gtsmodel.List, 0, len(listIDs))
- for _, id := range listIDs {
- list, err := l.state.DB.GetListByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error fetching list %q: %v", id, err)
- continue
- }
- lists = append(lists, list)
- }
-
- return lists, nil
+ // Return lists by their IDs.
+ return l.GetListsByIDs(ctx, listIDs)
}
func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {
@@ -147,7 +139,7 @@ func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {
}
func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error {
- return l.state.Caches.GTS.List().Store(list, func() error {
+ return l.state.Caches.GTS.List.Store(list, func() error {
_, err := l.db.NewInsert().Model(list).Exec(ctx)
return err
})
@@ -162,7 +154,7 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ..
defer func() {
// Invalidate all entries for this list ID.
- l.state.Caches.GTS.ListEntry().Invalidate("ListID", list.ID)
+ l.state.Caches.GTS.ListEntry.Invalidate("ListID", list.ID)
// Invalidate this entire list's timeline.
if err := l.state.Timelines.List.RemoveTimeline(ctx, list.ID); err != nil {
@@ -170,7 +162,7 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ..
}
}()
- return l.state.Caches.GTS.List().Store(list, func() error {
+ return l.state.Caches.GTS.List.Store(list, func() error {
_, err := l.db.NewUpdate().
Model(list).
Where("? = ?", bun.Ident("list.id"), list.ID).
@@ -198,7 +190,7 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error {
defer func() {
// Invalidate this list from cache.
- l.state.Caches.GTS.List().Invalidate("ID", id)
+ l.state.Caches.GTS.List.Invalidate("ID", id)
// Invalidate this entire list's timeline.
if err := l.state.Timelines.List.RemoveTimeline(ctx, id); err != nil {
@@ -243,7 +235,7 @@ func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.Lis
}
func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) {
- listEntry, err := l.state.Caches.GTS.ListEntry().Load(lookup, func() (*gtsmodel.ListEntry, error) {
+ listEntry, err := l.state.Caches.GTS.ListEntry.LoadOne(lookup, func() (*gtsmodel.ListEntry, error) {
var listEntry gtsmodel.ListEntry
// Not cached! Perform database query.
@@ -344,18 +336,128 @@ func (l *listDB) GetListEntries(ctx context.Context,
}
}
- // Select each list entry using its ID to ensure cache used.
- listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs))
- for _, id := range entryIDs {
- listEntry, err := l.state.DB.GetListEntryByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error fetching list entry %q: %v", id, err)
- continue
+ // Return list entries by their IDs.
+ return l.GetListEntriesByIDs(ctx, entryIDs)
+}
+
+func (l *listDB) GetListsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.List, error) {
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(ids))
+
+ // Load all list IDs via cache loader callbacks.
+ lists, err := l.state.Caches.GTS.List.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range ids {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
+
+ // Uncached list loader function.
+ func() ([]*gtsmodel.List, error) {
+ // Preallocate expected length of uncached lists.
+ lists := make([]*gtsmodel.List, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := l.db.NewSelect().
+ Model(&lists).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return lists, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Reorder the lists by their
+ // IDs to ensure in correct order.
+ getID := func(l *gtsmodel.List) string { return l.ID }
+ util.OrderBy(lists, ids, getID)
+
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return lists, nil
+ }
+
+ // Populate all loaded lists, removing those we fail to
+ // populate (removes needing so many nil checks everywhere).
+ lists = slices.DeleteFunc(lists, func(list *gtsmodel.List) bool {
+ if err := l.PopulateList(ctx, list); err != nil {
+ log.Errorf(ctx, "error populating list %s: %v", list.ID, err)
+ return true
}
- listEntries = append(listEntries, listEntry)
+ return false
+ })
+
+ return lists, nil
+}
+
+func (l *listDB) GetListEntriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.ListEntry, error) {
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(ids))
+
+ // Load all entry IDs via cache loader callbacks.
+ entries, err := l.state.Caches.GTS.ListEntry.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range ids {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
+
+ // Uncached entry loader function.
+ func() ([]*gtsmodel.ListEntry, error) {
+ // Preallocate expected length of uncached entries.
+ entries := make([]*gtsmodel.ListEntry, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := l.db.NewSelect().
+ Model(&entries).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return entries, nil
+ },
+ )
+ if err != nil {
+ return nil, err
}
- return listEntries, nil
+ // Reorder the entries by their
+ // IDs to ensure in correct order.
+ getID := func(e *gtsmodel.ListEntry) string { return e.ID }
+ util.OrderBy(entries, ids, getID)
+
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return entries, nil
+ }
+
+ // Populate all loaded entries, removing those we fail to
+ // populate (removes needing so many nil checks everywhere).
+ entries = slices.DeleteFunc(entries, func(entry *gtsmodel.ListEntry) bool {
+ if err := l.PopulateListEntry(ctx, entry); err != nil {
+ log.Errorf(ctx, "error populating entry %s: %v", entry.ID, err)
+ return true
+ }
+ return false
+ })
+
+ return entries, nil
}
func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) {
@@ -376,18 +478,8 @@ func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string)
return nil, nil
}
- // Select each list entry using its ID to ensure cache used.
- listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs))
- for _, id := range entryIDs {
- listEntry, err := l.state.DB.GetListEntryByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error fetching list entry %q: %v", id, err)
- continue
- }
- listEntries = append(listEntries, listEntry)
- }
-
- return listEntries, nil
+ // Return list entries by their IDs.
+ return l.GetListEntriesByIDs(ctx, entryIDs)
}
func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error {
@@ -409,10 +501,10 @@ func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.List
func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEntry) error {
defer func() {
- // Collect unique list IDs from the entries.
- listIDs := collate(func(i int) string {
- return entries[i].ListID
- }, len(entries))
+ // Collect unique list IDs from the provided entries.
+ listIDs := util.Collate(entries, func(e *gtsmodel.ListEntry) string {
+ return e.ListID
+ })
for _, id := range listIDs {
// Invalidate the timeline for the list this entry belongs to.
@@ -426,7 +518,7 @@ func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEnt
return l.db.RunInTx(ctx, func(tx Tx) error {
for _, entry := range entries {
entry := entry // rescope
- if err := l.state.Caches.GTS.ListEntry().Store(entry, func() error {
+ if err := l.state.Caches.GTS.ListEntry.Store(entry, func() error {
_, err := tx.
NewInsert().
Model(entry).
@@ -459,7 +551,7 @@ func (l *listDB) DeleteListEntry(ctx context.Context, id string) error {
defer func() {
// Invalidate this list entry upon delete.
- l.state.Caches.GTS.ListEntry().Invalidate("ID", id)
+ l.state.Caches.GTS.ListEntry.Invalidate("ID", id)
// Invalidate the timeline for the list this entry belongs to.
if err := l.state.Timelines.List.RemoveTimeline(ctx, entry.ListID); err != nil {
@@ -514,24 +606,3 @@ func (l *listDB) ListIncludesAccount(ctx context.Context, listID string, account
return exists, err
}
-
-// collate will collect the values of type T from an expected slice of length 'len',
-// passing the expected index to each call of 'get' and deduplicating the end result.
-func collate[T comparable](get func(int) T, len int) []T {
- ts := make([]T, 0, len)
- tm := make(map[T]struct{}, len)
-
- for i := 0; i < len; i++ {
- // Get next.
- t := get(i)
-
- if _, ok := tm[t]; !ok {
- // New value, add
- // to map + slice.
- ts = append(ts, t)
- tm[t] = struct{}{}
- }
- }
-
- return ts
-}