diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/bundb/list.go | 240 | ||||
-rw-r--r-- | internal/db/bundb/relationship_follow.go | 3 |
2 files changed, 149 insertions, 94 deletions
diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go index 38701cc07..837dfac27 100644 --- a/internal/db/bundb/list.go +++ b/internal/db/bundb/list.go @@ -41,6 +41,20 @@ type listDB struct { LIST FUNCTIONS */ +func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, error) { + return l.getList( + ctx, + "ID", + func(list *gtsmodel.List) error { + return l.conn.NewSelect(). + Model(list). + Where("? = ?", bun.Ident("list.id"), id). + Scan(ctx) + }, + id, + ) +} + func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmodel.List) error, keyParts ...any) (*gtsmodel.List, error) { list, err := l.state.Caches.GTS.List().Load(lookup, func() (*gtsmodel.List, error) { var list gtsmodel.List @@ -53,7 +67,8 @@ func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmo return &list, nil }, keyParts...) if err != nil { - return nil, err // already processed + // already processed + return nil, err } if gtscontext.Barebones(ctx) { @@ -68,20 +83,6 @@ func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmo return list, nil } -func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, error) { - return l.getList( - ctx, - "ID", - func(list *gtsmodel.List) error { - return l.conn.NewSelect(). - Model(list). - Where("? = ?", bun.Ident("list.id"), id). - Scan(ctx) - }, - id, - ) -} - func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) { // Fetch IDs of all lists owned by this account. var listIDs []string @@ -107,8 +108,6 @@ func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([] log.Errorf(ctx, "error fetching list %q: %v", id, err) continue } - - // Append list. lists = append(lists, list) } @@ -161,49 +160,89 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns .. columns = append(columns, "updated_at") } + defer func() { + // Invalidate all entries for this list ID. + l.state.Caches.GTS.ListEntry().Invalidate("ListID", list.ID) + + // Invalidate this entire list's timeline. + if err := l.state.Timelines.List.RemoveTimeline(ctx, list.ID); err != nil { + log.Errorf(ctx, "error invalidating list timeline: %q", err) + } + }() + return l.state.Caches.GTS.List().Store(list, func() error { - if _, err := l.conn.NewUpdate(). + _, err := l.conn.NewUpdate(). Model(list). Where("? = ?", bun.Ident("list.id"), list.ID). Column(columns...). - Exec(ctx); err != nil { - return l.conn.ProcessError(err) - } - - return nil + Exec(ctx) + return l.conn.ProcessError(err) }) } func (l *listDB) DeleteListByID(ctx context.Context, id string) error { - defer l.state.Caches.GTS.List().Invalidate("ID", id) - - // Select all entries that belong to this list. - listEntries, err := l.state.DB.GetListEntries(ctx, id, "", "", "", 0) + // Load list by ID into cache to ensure we can perform + // all necessary cache invalidation hooks on removal. + _, err := l.GetListByID( + // Don't populate the entry; + // we only want the list ID. + gtscontext.SetBarebones(ctx), + id, + ) if err != nil { - return fmt.Errorf("error selecting entries from list %q: %w", id, err) + if errors.Is(err, db.ErrNoEntries) { + // Already gone. + return nil + } + return err } - // Delete each list entry. This will - // invalidate the list timeline too. - for _, listEntry := range listEntries { - err := l.state.DB.DeleteListEntry(ctx, listEntry.ID) - if err != nil && !errors.Is(err, db.ErrNoEntries) { + defer func() { + // Invalidate this list from cache. + l.state.Caches.GTS.List().Invalidate("ID", id) + + // Invalidate this entire list's timeline. + if err := l.state.Timelines.List.RemoveTimeline(ctx, id); err != nil { + log.Errorf(ctx, "error invalidating list timeline: %q", err) + } + }() + + return l.conn.RunInTx(ctx, func(tx bun.Tx) error { + // Delete all entries attached to list. + if _, err := tx.NewDelete(). + Table("list_entries"). + Where("? = ?", bun.Ident("list_id"), id). + Exec(ctx); err != nil { return err } - } - // Finally delete list itself from DB. - _, err = l.conn.NewDelete(). - Table("lists"). - Where("? = ?", bun.Ident("id"), id). - Exec(ctx) - return l.conn.ProcessError(err) + // Delete the list itself. + _, err := tx.NewDelete(). + Table("lists"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx) + return err + }) } /* LIST ENTRY functions */ +func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error) { + return l.getListEntry( + ctx, + "ID", + func(listEntry *gtsmodel.ListEntry) error { + return l.conn.NewSelect(). + Model(listEntry). + Where("? = ?", bun.Ident("list_entry.id"), id). + Scan(ctx) + }, + id, + ) +} + func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) { listEntry, err := l.state.Caches.GTS.ListEntry().Load(lookup, func() (*gtsmodel.ListEntry, error) { var listEntry gtsmodel.ListEntry @@ -232,20 +271,6 @@ func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(* return listEntry, nil } -func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error) { - return l.getListEntry( - ctx, - "ID", - func(listEntry *gtsmodel.ListEntry) error { - return l.conn.NewSelect(). - Model(listEntry). - Where("? = ?", bun.Ident("list_entry.id"), id). - Scan(ctx) - }, - id, - ) -} - func (l *listDB) GetListEntries(ctx context.Context, listID string, maxID string, @@ -328,8 +353,6 @@ func (l *listDB) GetListEntries(ctx context.Context, log.Errorf(ctx, "error fetching list entry %q: %v", id, err) continue } - - // Append list entries. listEntries = append(listEntries, listEntry) } @@ -337,7 +360,7 @@ func (l *listDB) GetListEntries(ctx context.Context, } func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) { - entryIDs := []string{} + var entryIDs []string if err := l.conn. NewSelect(). @@ -362,8 +385,6 @@ func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) log.Errorf(ctx, "error fetching list entry %q: %v", id, err) continue } - - // Append list entries. listEntries = append(listEntries, listEntry) } @@ -387,33 +408,42 @@ func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.List return nil } -func (l *listDB) PutListEntries(ctx context.Context, listEntries []*gtsmodel.ListEntry) error { - return l.conn.RunInTx(ctx, func(tx bun.Tx) error { - for _, listEntry := range listEntries { - if _, err := tx. - NewInsert(). - Model(listEntry). - Exec(ctx); err != nil { - return err - } +func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEntry) error { + defer func() { + // Collect unique list IDs from the entries. + listIDs := collate(func(i int) string { + return entries[i].ListID + }, len(entries)) + for _, id := range listIDs { // Invalidate the timeline for the list this entry belongs to. - if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil { - log.Errorf(ctx, "PutListEntries: error invalidating list timeline: %q", err) + if err := l.state.Timelines.List.RemoveTimeline(ctx, id); err != nil { + log.Errorf(ctx, "error invalidating list timeline: %q", err) } } + }() + // Finally, insert each list entry into the database. + return l.conn.RunInTx(ctx, func(tx bun.Tx) error { + for _, entry := range entries { + if err := l.state.Caches.GTS.ListEntry().Store(entry, func() error { + _, err := tx. + NewInsert(). + Model(entry). + Exec(ctx) + return err + }); err != nil { + return err + } + } return nil }) } func (l *listDB) DeleteListEntry(ctx context.Context, id string) error { - defer l.state.Caches.GTS.ListEntry().Invalidate("ID", id) - - // Load list entry into cache before attempting a delete, - // as we need the followID from it in order to trigger - // timeline invalidation. - listEntry, err := l.GetListEntryByID( + // Load list entry into cache to ensure we can perform + // all necessary cache invalidation hooks on removal. + entry, err := l.GetListEntryByID( // Don't populate the entry; // we only want the list ID. gtscontext.SetBarebones(ctx), @@ -428,36 +458,39 @@ func (l *listDB) DeleteListEntry(ctx context.Context, id string) error { } defer func() { + // Invalidate this list entry upon delete. + l.state.Caches.GTS.ListEntry().Invalidate("ID", id) + // Invalidate the timeline for the list this entry belongs to. - if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil { - log.Errorf(ctx, "DeleteListEntry: error invalidating list timeline: %q", err) + if err := l.state.Timelines.List.RemoveTimeline(ctx, entry.ListID); err != nil { + log.Errorf(ctx, "error invalidating list timeline: %q", err) } }() - if _, err := l.conn.NewDelete(). + // Finally delete the list entry. + _, err = l.conn.NewDelete(). Table("list_entries"). - Where("? = ?", bun.Ident("id"), listEntry.ID). - Exec(ctx); err != nil { - return l.conn.ProcessError(err) - } - - return nil + Where("? = ?", bun.Ident("id"), id). + Exec(ctx) + return err } func (l *listDB) DeleteListEntriesForFollowID(ctx context.Context, followID string) error { - // Fetch IDs of all entries that pertain to this follow. - var listEntryIDs []string + var entryIDs []string + + // Fetch entry IDs for follow ID. if err := l.conn. NewSelect(). - TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("list_entry")). - Column("list_entry.id"). - Where("? = ?", bun.Ident("list_entry.follow_id"), followID). - Order("list_entry.id DESC"). - Scan(ctx, &listEntryIDs); err != nil { + Table("list_entries"). + Column("id"). + Where("? = ?", bun.Ident("follow_id"), followID). + Order("id DESC"). + Scan(ctx, &entryIDs); err != nil { return l.conn.ProcessError(err) } - for _, id := range listEntryIDs { + for _, id := range entryIDs { + // Delete each separately to trigger cache invalidations. if err := l.DeleteListEntry(ctx, id); err != nil { return err } @@ -465,3 +498,24 @@ func (l *listDB) DeleteListEntriesForFollowID(ctx context.Context, followID stri return nil } + +// collate will collect the values of type T from an expected slice of length 'len', +// passing the expected index to each call of 'get' and deduplicating the end result. +func collate[T comparable](get func(int) T, len int) []T { + ts := make([]T, 0, len) + tm := make(map[T]struct{}, len) + + for i := 0; i < len; i++ { + // Get next. + t := get(i) + + if _, ok := tm[t]; !ok { + // New value, add + // to map + slice. + ts = append(ts, t) + tm[t] = struct{}{} + } + } + + return ts +} diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go index 88850e72a..349c1ef43 100644 --- a/internal/db/bundb/relationship_follow.go +++ b/internal/db/bundb/relationship_follow.go @@ -328,7 +328,8 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str } // Delete each follow from DB. - if err := r.deleteFollow(ctx, follow.ID); err != nil && !errors.Is(err, db.ErrNoEntries) { + if err := r.deleteFollow(ctx, follow.ID); err != nil && + !errors.Is(err, db.ErrNoEntries) { return err } } |