diff options
Diffstat (limited to 'internal/db/bundb/list.go')
| -rw-r--r-- | internal/db/bundb/list.go | 565 |
1 files changed, 221 insertions, 344 deletions
diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go index 0ed0f1b15..03dff95e3 100644 --- a/internal/db/bundb/list.go +++ b/internal/db/bundb/list.go @@ -29,6 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" @@ -85,39 +86,52 @@ func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmo return list, nil } -func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) { - // Fetch IDs of all lists owned by this account. - var listIDs []string - if err := l.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("lists"), bun.Ident("list")). - Column("list.id"). - Where("? = ?", bun.Ident("list.account_id"), accountID). - Order("list.id DESC"). - Scan(ctx, &listIDs); err != nil { +func (l *listDB) GetListsByAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) { + listIDs, err := l.getListIDsByAccountID(ctx, accountID) + if err != nil { return nil, err } + return l.GetListsByIDs(ctx, listIDs) +} - if len(listIDs) == 0 { - return nil, nil - } +func (l *listDB) CountListsByAccountID(ctx context.Context, accountID string) (int, error) { + listIDs, err := l.getListIDsByAccountID(ctx, accountID) + return len(listIDs), err +} - // Return lists by their IDs. +func (l *listDB) GetListsContainingFollowID(ctx context.Context, followID string) ([]*gtsmodel.List, error) { + listIDs, err := l.getListIDsWithFollowID(ctx, followID) + if err != nil { + return nil, err + } return l.GetListsByIDs(ctx, listIDs) } -func (l *listDB) CountListsForAccountID(ctx context.Context, accountID string) (int, error) { - return l.db. - NewSelect(). - Table("lists"). - Where("? = ?", bun.Ident("account_id"), accountID). - Count(ctx) +func (l *listDB) GetFollowsInList(ctx context.Context, listID string, page *paging.Page) ([]*gtsmodel.Follow, error) { + followIDs, err := l.GetFollowIDsInList(ctx, listID, page) + if err != nil { + return nil, err + } + return l.state.DB.GetFollowsByIDs(ctx, followIDs) +} + +func (l *listDB) GetAccountsInList(ctx context.Context, listID string, page *paging.Page) ([]*gtsmodel.Account, error) { + accountIDs, err := l.GetAccountIDsInList(ctx, listID, page) + if err != nil { + return nil, err + } + return l.state.DB.GetAccountsByIDs(ctx, accountIDs) +} + +func (l *listDB) IsAccountInList(ctx context.Context, listID string, accountID string) (bool, error) { + accountIDs, err := l.GetAccountIDsInList(ctx, listID, nil) + return slices.Contains(accountIDs, accountID), err } func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error { var ( err error - errs = gtserror.NewMultiError(2) + errs gtserror.MultiError ) if list.Account == nil { @@ -131,22 +145,12 @@ func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error { } } - if list.ListEntries == nil { - // List entries are not set, fetch from the database. - list.ListEntries, err = l.state.DB.GetListEntries( - gtscontext.SetBarebones(ctx), - list.ID, - "", "", "", 0, - ) - if err != nil { - errs.Appendf("error populating list entries: %w", err) - } - } - return errs.Combine() } func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error { + // note that inserting list will call OnInvalidateList() + // which will handle clearing caches other than List cache. return l.state.Caches.DB.List.Store(list, func() error { _, err := l.db.NewInsert().Model(list).Exec(ctx) return err @@ -160,192 +164,146 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns .. columns = append(columns, "updated_at") } - defer func() { - // Invalidate all entries for this list ID. - l.state.Caches.DB.ListEntry.Invalidate("ListID", list.ID) - - // Invalidate this entire list's timeline. - if err := l.state.Timelines.List.RemoveTimeline(ctx, list.ID); err != nil { - log.Errorf(ctx, "error invalidating list timeline: %q", err) - } - }() - - return l.state.Caches.DB.List.Store(list, func() error { + // Update list in the database, invalidating main list cache. + if err := l.state.Caches.DB.List.Store(list, func() error { _, err := l.db.NewUpdate(). Model(list). Where("? = ?", bun.Ident("list.id"), list.ID). Column(columns...). Exec(ctx) return err - }) -} - -func (l *listDB) DeleteListByID(ctx context.Context, id string) error { - // Load list by ID into cache to ensure we can perform - // all necessary cache invalidation hooks on removal. - _, err := l.GetListByID( - // Don't populate the entry; - // we only want the list ID. - gtscontext.SetBarebones(ctx), - id, - ) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - // NOTE: even if db.ErrNoEntries is returned, we - // still run the below transaction to ensure related - // objects are appropriately deleted. + }); err != nil { return err } - defer func() { - // Invalidate this list from cache. - l.state.Caches.DB.List.Invalidate("ID", id) + // Invalidate this entire list's timeline. + if err := l.state.Timelines.List.RemoveTimeline(ctx, list.ID); err != nil { + log.Errorf(ctx, "error invalidating list timeline: %q", err) + } - // Invalidate this entire list's timeline. - if err := l.state.Timelines.List.RemoveTimeline(ctx, id); err != nil { - log.Errorf(ctx, "error invalidating list timeline: %q", err) - } - }() + return nil +} + +func (l *listDB) DeleteListByID(ctx context.Context, id string) error { + // Acquire list owner ID. + var accountID string + + // Gather follow IDs of all + // entries contained in list. + var followIDs []string - return l.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { - // Delete all entries attached to list. + // Delete all list entries associated with list, and list itself in transaction. + if err := l.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { if _, err := tx.NewDelete(). Table("list_entries"). Where("? = ?", bun.Ident("list_id"), id). - Exec(ctx); err != nil { + Returning("?", bun.Ident("follow_id")). + Exec(ctx, &followIDs); err != nil { return err } - // Delete the list itself. _, err := tx.NewDelete(). Table("lists"). Where("? = ?", bun.Ident("id"), id). - Exec(ctx) + Returning("?", bun.Ident("account_id")). + Exec(ctx, &accountID) return err - }) -} + }); err != nil { + return err + } -/* - LIST ENTRY functions -*/ + // Invalidate the main list database cache. + l.state.Caches.DB.List.Invalidate("ID", id) -func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error) { - return l.getListEntry( - ctx, - "ID", - func(listEntry *gtsmodel.ListEntry) error { - return l.db.NewSelect(). - Model(listEntry). - Where("? = ?", bun.Ident("list_entry.id"), id). - Scan(ctx) - }, - id, - ) + // Invalidate cache of list IDs owned by account. + l.state.Caches.DB.ListIDs.Invalidate("a" + accountID) + + // Invalidate all related entry caches for this list. + l.invalidateEntryCaches(ctx, []string{id}, followIDs) + + return nil } -func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) { - listEntry, err := l.state.Caches.DB.ListEntry.LoadOne(lookup, func() (*gtsmodel.ListEntry, error) { - var listEntry gtsmodel.ListEntry +func (l *listDB) getListIDsByAccountID(ctx context.Context, accountID string) ([]string, error) { + return l.state.Caches.DB.ListIDs.Load("a"+accountID, func() ([]string, error) { + var listIDs []string - // Not cached! Perform database query. - if err := dbQuery(&listEntry); err != nil { + // List IDs not in cache. + // Perform the DB query. + if _, err := l.db.NewSelect(). + Table("lists"). + Column("id"). + Where("? = ?", bun.Ident("account_id"), accountID). + OrderExpr("? DESC", bun.Ident("created_at")). + Exec(ctx, &listIDs); err != nil && + !errors.Is(err, db.ErrNoEntries) { return nil, err } - return &listEntry, nil - }, keyParts...) - if err != nil { - return nil, err // already processed - } - - if gtscontext.Barebones(ctx) { - // Only a barebones model was requested. - return listEntry, nil - } - - // Further populate the list entry fields where applicable. - if err := l.state.DB.PopulateListEntry(ctx, listEntry); err != nil { - return nil, err - } - - return listEntry, nil + return listIDs, nil + }) } -func (l *listDB) GetListEntries(ctx context.Context, - listID string, - maxID string, - sinceID string, - minID string, - limit int, -) ([]*gtsmodel.ListEntry, error) { - // Ensure reasonable - if limit < 0 { - limit = 0 - } - - // Make educated guess for slice size - var ( - entryIDs = make([]string, 0, limit) - frontToBack = true - ) - - q := l.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")). - // Select only IDs from table - Column("entry.id"). - // Select only entries belonging to listID. - Where("? = ?", bun.Ident("entry.list_id"), listID) - - if maxID != "" { - // return only entries LOWER (ie., older) than maxID - q = q.Where("? < ?", bun.Ident("entry.id"), maxID) - } - - if sinceID != "" { - // return only entries HIGHER (ie., newer) than sinceID - q = q.Where("? > ?", bun.Ident("entry.id"), sinceID) - } +func (l *listDB) getListIDsWithFollowID(ctx context.Context, followID string) ([]string, error) { + return l.state.Caches.DB.ListIDs.Load("f"+followID, func() ([]string, error) { + var listIDs []string - if minID != "" { - // return only entries HIGHER (ie., newer) than minID - q = q.Where("? > ?", bun.Ident("entry.id"), minID) + // List IDs not in cache. + // Perform the DB query. + if _, err := l.db.NewSelect(). + Table("list_entries"). + Column("list_id"). + Where("? = ?", bun.Ident("follow_id"), followID). + OrderExpr("? DESC", bun.Ident("created_at")). + Exec(ctx, &listIDs); err != nil && + !errors.Is(err, db.ErrNoEntries) { + return nil, err + } - // page up - frontToBack = false - } + return listIDs, nil + }) +} - if limit > 0 { - // limit amount of entries returned - q = q.Limit(limit) - } +func (l *listDB) GetFollowIDsInList(ctx context.Context, listID string, page *paging.Page) ([]string, error) { + return loadPagedIDs(&l.state.Caches.DB.ListedIDs, "f"+listID, page, func() ([]string, error) { + var followIDs []string - if frontToBack { - // Page down. - q = q.Order("entry.id DESC") - } else { - // Page up. - q = q.Order("entry.id ASC") - } + // Follow IDs not in cache. + // Perform the DB query. + _, err := l.db.NewSelect(). + Table("list_entries"). + Column("follow_id"). + Where("? = ?", bun.Ident("list_id"), listID). + OrderExpr("? DESC", bun.Ident("created_at")). + Exec(ctx, &followIDs) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return nil, err + } - if err := q.Scan(ctx, &entryIDs); err != nil { - return nil, err - } + return followIDs, nil + }) +} - if len(entryIDs) == 0 { - return nil, nil - } +func (l *listDB) GetAccountIDsInList(ctx context.Context, listID string, page *paging.Page) ([]string, error) { + return loadPagedIDs(&l.state.Caches.DB.ListedIDs, "a"+listID, page, func() ([]string, error) { + var accountIDs []string - // If we're paging up, we still want entries - // to be sorted by ID desc, so reverse ids slice. - // https://zchee.github.io/golang-wiki/SliceTricks/#reversing - if !frontToBack { - for l, r := 0, len(entryIDs)-1; l < r; l, r = l+1, r-1 { - entryIDs[l], entryIDs[r] = entryIDs[r], entryIDs[l] + // Account IDs not in cache. + // Perform the DB query. + _, err := l.db.NewSelect(). + Table("follows"). + Column("follows.target_account_id"). + Join("INNER JOIN ?", bun.Ident("list_entries")). + JoinOn("? = ?", bun.Ident("follows.id"), bun.Ident("list_entries.follow_id")). + Where("? = ?", bun.Ident("list_entries.list_id"), listID). + OrderExpr("? DESC", bun.Ident("list_entries.id")). + Exec(ctx, &accountIDs) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return nil, err } - } - // Return list entries by their IDs. - return l.GetListEntriesByIDs(ctx, entryIDs) + return accountIDs, nil + }) } func (l *listDB) GetListsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.List, error) { @@ -353,15 +311,8 @@ func (l *listDB) GetListsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.L lists, err := l.state.Caches.DB.List.LoadIDs("ID", ids, func(uncached []string) ([]*gtsmodel.List, error) { - // Avoid querying - // if none uncached. - count := len(uncached) - if count == 0 { - return nil, nil - } - // Preallocate expected length of uncached lists. - lists := make([]*gtsmodel.List, 0, count) + lists := make([]*gtsmodel.List, 0, len(uncached)) // Perform database query scanning // the remaining (uncached) IDs. @@ -402,82 +353,6 @@ func (l *listDB) GetListsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.L return lists, nil } -func (l *listDB) GetListEntriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.ListEntry, error) { - // Load all entry IDs via cache loader callbacks. - entries, err := l.state.Caches.DB.ListEntry.LoadIDs("ID", - ids, - func(uncached []string) ([]*gtsmodel.ListEntry, error) { - // Avoid querying - // if none uncached. - count := len(uncached) - if count == 0 { - return nil, nil - } - - // Preallocate expected length of uncached entries. - entries := make([]*gtsmodel.ListEntry, 0, count) - - // Perform database query scanning - // the remaining (uncached) IDs. - if err := l.db.NewSelect(). - Model(&entries). - Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). - Scan(ctx); err != nil { - return nil, err - } - - return entries, nil - }, - ) - if err != nil { - return nil, err - } - - // Reorder the entries by their - // IDs to ensure in correct order. - getID := func(e *gtsmodel.ListEntry) string { return e.ID } - util.OrderBy(entries, ids, getID) - - if gtscontext.Barebones(ctx) { - // no need to fully populate. - return entries, nil - } - - // Populate all loaded entries, removing those we fail to - // populate (removes needing so many nil checks everywhere). - entries = slices.DeleteFunc(entries, func(entry *gtsmodel.ListEntry) bool { - if err := l.PopulateListEntry(ctx, entry); err != nil { - log.Errorf(ctx, "error populating entry %s: %v", entry.ID, err) - return true - } - return false - }) - - return entries, nil -} - -func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) { - var entryIDs []string - - if err := l.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")). - // Select only IDs from table - Column("entry.id"). - // Select only entries belonging with given followID. - Where("? = ?", bun.Ident("entry.follow_id"), followID). - Scan(ctx, &entryIDs); err != nil { - return nil, err - } - - if len(entryIDs) == 0 { - return nil, nil - } - - // Return list entries by their IDs. - return l.GetListEntriesByIDs(ctx, entryIDs) -} - func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error { var err error @@ -496,109 +371,111 @@ func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.List } func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEntry) error { - defer func() { - // Collect unique list IDs from the provided entries. - listIDs := util.Collate(entries, func(e *gtsmodel.ListEntry) string { - return e.ListID - }) - - for _, id := range listIDs { - // Invalidate the timeline for the list this entry belongs to. - if err := l.state.Timelines.List.RemoveTimeline(ctx, id); err != nil { - log.Errorf(ctx, "error invalidating list timeline: %q", err) - } - } - }() - - // Finally, insert each list entry into the database. - return l.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + // Insert all entries into the database in a single transaction (all or nothing!). + if err := l.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { for _, entry := range entries { - entry := entry // rescope - if err := l.state.Caches.DB.ListEntry.Store(entry, func() error { - _, err := tx. - NewInsert(). - Model(entry). - Exec(ctx) - return err - }); err != nil { + if _, err := tx. + NewInsert(). + Model(entry). + Exec(ctx); err != nil { return err } } return nil + }); err != nil { + return err + } + + // Collect unique list IDs from the provided list entries. + listIDs := util.Collate(entries, func(e *gtsmodel.ListEntry) string { + return e.ListID + }) + + // Collect unique follow IDs from the provided list entries. + followIDs := util.Collate(entries, func(e *gtsmodel.ListEntry) string { + return e.FollowID }) + + // Invalidate all related list entry caches. + l.invalidateEntryCaches(ctx, listIDs, followIDs) + + return nil } -func (l *listDB) DeleteListEntry(ctx context.Context, id string) error { - // Load list entry into cache to ensure we can perform - // all necessary cache invalidation hooks on removal. - entry, err := l.GetListEntryByID( - // Don't populate the entry; - // we only want the list ID. - gtscontext.SetBarebones(ctx), - id, - ) - if err != nil { - if errors.Is(err, db.ErrNoEntries) { - // Already gone. - return nil - } +func (l *listDB) DeleteListEntry(ctx context.Context, listID string, followID string) error { + // Delete list entry with given + // ID, returning its list ID. + if _, err := l.db.NewDelete(). + Table("list_entries"). + Where("? = ?", bun.Ident("list_id"), listID). + Where("? = ?", bun.Ident("follow_id"), followID). + Exec(ctx, &listID); err != nil && + !errors.Is(err, db.ErrNoEntries) { return err } - defer func() { - // Invalidate this list entry upon delete. - l.state.Caches.DB.ListEntry.Invalidate("ID", id) + // Invalidate all related list entry caches. + l.invalidateEntryCaches(ctx, []string{listID}, + []string{followID}) - // Invalidate the timeline for the list this entry belongs to. - if err := l.state.Timelines.List.RemoveTimeline(ctx, entry.ListID); err != nil { - log.Errorf(ctx, "error invalidating list timeline: %q", err) - } - }() - - // Finally delete the list entry. - _, err = l.db.NewDelete(). - Table("list_entries"). - Where("? = ?", bun.Ident("id"), id). - Exec(ctx) - return err + return nil } -func (l *listDB) DeleteListEntriesForFollowID(ctx context.Context, followID string) error { - var entryIDs []string +func (l *listDB) DeleteAllListEntriesByFollows(ctx context.Context, followIDs ...string) error { + var listIDs []string + + // Check for empty list. + if len(followIDs) == 0 { + return nil + } - // Fetch entry IDs for follow ID. - if err := l.db. - NewSelect(). + // Delete all entries with follow + // ID, returning IDs and list IDs. + if _, err := l.db.NewDelete(). Table("list_entries"). - Column("id"). - Where("? = ?", bun.Ident("follow_id"), followID). - Order("id DESC"). - Scan(ctx, &entryIDs); err != nil { + Where("? IN (?)", bun.Ident("follow_id"), bun.In(followIDs)). + Returning("?", bun.Ident("list_id")). + Exec(ctx, &listIDs); err != nil && + !errors.Is(err, db.ErrNoEntries) { return err } - for _, id := range entryIDs { - // Delete each separately to trigger cache invalidations. - if err := l.DeleteListEntry(ctx, id); err != nil { - return err - } - } + // Deduplicate IDs before invalidate. + listIDs = util.Deduplicate(listIDs) + + // Invalidate all related list entry caches. + l.invalidateEntryCaches(ctx, listIDs, followIDs) return nil } -func (l *listDB) ListIncludesAccount(ctx context.Context, listID string, accountID string) (bool, error) { - exists, err := l.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("list_entry")). - Join( - "JOIN ? AS ? ON ? = ?", - bun.Ident("follows"), bun.Ident("follow"), - bun.Ident("list_entry.follow_id"), bun.Ident("follow.id"), - ). - Where("? = ?", bun.Ident("list_entry.list_id"), listID). - Where("? = ?", bun.Ident("follow.target_account_id"), accountID). - Exists(ctx) +// invalidateEntryCaches will invalidate all related ListEntry caches for given list IDs and follow IDs, including timelines. +func (l *listDB) invalidateEntryCaches(ctx context.Context, listIDs, followIDs []string) { + var keys []string + + // Generate ListedID keys to invalidate. + keys = slices.Grow(keys[:0], 2*len(listIDs)) + for _, listID := range listIDs { + keys = append(keys, + "a"+listID, + "f"+listID, + ) + + // Invalidate the timeline for the list this entry belongs to. + if err := l.state.Timelines.List.RemoveTimeline(ctx, listID); err != nil { + log.Errorf(ctx, "error invalidating list timeline: %q", err) + } + } + + // Invalidate ListedID slice cache entries. + l.state.Caches.DB.ListedIDs.Invalidate(keys...) + + // Generate ListID keys to invalidate. + keys = slices.Grow(keys[:0], len(followIDs)) + for _, followID := range followIDs { + keys = append(keys, "f"+followID) + } - return exists, err + // Invalidate ListID slice cache entries. + l.state.Caches.DB.ListIDs.Invalidate(keys...) } |
