diff options
Diffstat (limited to 'internal/db')
| -rw-r--r-- | internal/db/bundb/list.go | 240 | ||||
| -rw-r--r-- | internal/db/bundb/relationship_follow.go | 3 | 
2 files changed, 149 insertions, 94 deletions
| diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go index 38701cc07..837dfac27 100644 --- a/internal/db/bundb/list.go +++ b/internal/db/bundb/list.go @@ -41,6 +41,20 @@ type listDB struct {  	LIST FUNCTIONS  */ +func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, error) { +	return l.getList( +		ctx, +		"ID", +		func(list *gtsmodel.List) error { +			return l.conn.NewSelect(). +				Model(list). +				Where("? = ?", bun.Ident("list.id"), id). +				Scan(ctx) +		}, +		id, +	) +} +  func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmodel.List) error, keyParts ...any) (*gtsmodel.List, error) {  	list, err := l.state.Caches.GTS.List().Load(lookup, func() (*gtsmodel.List, error) {  		var list gtsmodel.List @@ -53,7 +67,8 @@ func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmo  		return &list, nil  	}, keyParts...)  	if err != nil { -		return nil, err // already processed +		// already processed +		return nil, err  	}  	if gtscontext.Barebones(ctx) { @@ -68,20 +83,6 @@ func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmo  	return list, nil  } -func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, error) { -	return l.getList( -		ctx, -		"ID", -		func(list *gtsmodel.List) error { -			return l.conn.NewSelect(). -				Model(list). -				Where("? = ?", bun.Ident("list.id"), id). -				Scan(ctx) -		}, -		id, -	) -} -  func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) {  	// Fetch IDs of all lists owned by this account.  	var listIDs []string @@ -107,8 +108,6 @@ func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]  			log.Errorf(ctx, "error fetching list %q: %v", id, err)  			continue  		} - -		// Append list.  		lists = append(lists, list)  	} @@ -161,49 +160,89 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ..  		columns = append(columns, "updated_at")  	} +	defer func() { +		// Invalidate all entries for this list ID. +		l.state.Caches.GTS.ListEntry().Invalidate("ListID", list.ID) + +		// Invalidate this entire list's timeline. +		if err := l.state.Timelines.List.RemoveTimeline(ctx, list.ID); err != nil { +			log.Errorf(ctx, "error invalidating list timeline: %q", err) +		} +	}() +  	return l.state.Caches.GTS.List().Store(list, func() error { -		if _, err := l.conn.NewUpdate(). +		_, err := l.conn.NewUpdate().  			Model(list).  			Where("? = ?", bun.Ident("list.id"), list.ID).  			Column(columns...). -			Exec(ctx); err != nil { -			return l.conn.ProcessError(err) -		} - -		return nil +			Exec(ctx) +		return l.conn.ProcessError(err)  	})  }  func (l *listDB) DeleteListByID(ctx context.Context, id string) error { -	defer l.state.Caches.GTS.List().Invalidate("ID", id) - -	// Select all entries that belong to this list. -	listEntries, err := l.state.DB.GetListEntries(ctx, id, "", "", "", 0) +	// Load list by ID into cache to ensure we can perform +	// all necessary cache invalidation hooks on removal. +	_, err := l.GetListByID( +		// Don't populate the entry; +		// we only want the list ID. +		gtscontext.SetBarebones(ctx), +		id, +	)  	if err != nil { -		return fmt.Errorf("error selecting entries from list %q: %w", id, err) +		if errors.Is(err, db.ErrNoEntries) { +			// Already gone. +			return nil +		} +		return err  	} -	// Delete each list entry. This will -	// invalidate the list timeline too. -	for _, listEntry := range listEntries { -		err := l.state.DB.DeleteListEntry(ctx, listEntry.ID) -		if err != nil && !errors.Is(err, db.ErrNoEntries) { +	defer func() { +		// Invalidate this list from cache. +		l.state.Caches.GTS.List().Invalidate("ID", id) + +		// Invalidate this entire list's timeline. +		if err := l.state.Timelines.List.RemoveTimeline(ctx, id); err != nil { +			log.Errorf(ctx, "error invalidating list timeline: %q", err) +		} +	}() + +	return l.conn.RunInTx(ctx, func(tx bun.Tx) error { +		// Delete all entries attached to list. +		if _, err := tx.NewDelete(). +			Table("list_entries"). +			Where("? = ?", bun.Ident("list_id"), id). +			Exec(ctx); err != nil {  			return err  		} -	} -	// Finally delete list itself from DB. -	_, err = l.conn.NewDelete(). -		Table("lists"). -		Where("? = ?", bun.Ident("id"), id). -		Exec(ctx) -	return l.conn.ProcessError(err) +		// Delete the list itself. +		_, err := tx.NewDelete(). +			Table("lists"). +			Where("? = ?", bun.Ident("id"), id). +			Exec(ctx) +		return err +	})  }  /*  	LIST ENTRY functions  */ +func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error) { +	return l.getListEntry( +		ctx, +		"ID", +		func(listEntry *gtsmodel.ListEntry) error { +			return l.conn.NewSelect(). +				Model(listEntry). +				Where("? = ?", bun.Ident("list_entry.id"), id). +				Scan(ctx) +		}, +		id, +	) +} +  func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) {  	listEntry, err := l.state.Caches.GTS.ListEntry().Load(lookup, func() (*gtsmodel.ListEntry, error) {  		var listEntry gtsmodel.ListEntry @@ -232,20 +271,6 @@ func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*  	return listEntry, nil  } -func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error) { -	return l.getListEntry( -		ctx, -		"ID", -		func(listEntry *gtsmodel.ListEntry) error { -			return l.conn.NewSelect(). -				Model(listEntry). -				Where("? = ?", bun.Ident("list_entry.id"), id). -				Scan(ctx) -		}, -		id, -	) -} -  func (l *listDB) GetListEntries(ctx context.Context,  	listID string,  	maxID string, @@ -328,8 +353,6 @@ func (l *listDB) GetListEntries(ctx context.Context,  			log.Errorf(ctx, "error fetching list entry %q: %v", id, err)  			continue  		} - -		// Append list entries.  		listEntries = append(listEntries, listEntry)  	} @@ -337,7 +360,7 @@ func (l *listDB) GetListEntries(ctx context.Context,  }  func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) { -	entryIDs := []string{} +	var entryIDs []string  	if err := l.conn.  		NewSelect(). @@ -362,8 +385,6 @@ func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string)  			log.Errorf(ctx, "error fetching list entry %q: %v", id, err)  			continue  		} - -		// Append list entries.  		listEntries = append(listEntries, listEntry)  	} @@ -387,33 +408,42 @@ func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.List  	return nil  } -func (l *listDB) PutListEntries(ctx context.Context, listEntries []*gtsmodel.ListEntry) error { -	return l.conn.RunInTx(ctx, func(tx bun.Tx) error { -		for _, listEntry := range listEntries { -			if _, err := tx. -				NewInsert(). -				Model(listEntry). -				Exec(ctx); err != nil { -				return err -			} +func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEntry) error { +	defer func() { +		// Collect unique list IDs from the entries. +		listIDs := collate(func(i int) string { +			return entries[i].ListID +		}, len(entries)) +		for _, id := range listIDs {  			// Invalidate the timeline for the list this entry belongs to. -			if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil { -				log.Errorf(ctx, "PutListEntries: error invalidating list timeline: %q", err) +			if err := l.state.Timelines.List.RemoveTimeline(ctx, id); err != nil { +				log.Errorf(ctx, "error invalidating list timeline: %q", err)  			}  		} +	}() +	// Finally, insert each list entry into the database. +	return l.conn.RunInTx(ctx, func(tx bun.Tx) error { +		for _, entry := range entries { +			if err := l.state.Caches.GTS.ListEntry().Store(entry, func() error { +				_, err := tx. +					NewInsert(). +					Model(entry). +					Exec(ctx) +				return err +			}); err != nil { +				return err +			} +		}  		return nil  	})  }  func (l *listDB) DeleteListEntry(ctx context.Context, id string) error { -	defer l.state.Caches.GTS.ListEntry().Invalidate("ID", id) - -	// Load list entry into cache before attempting a delete, -	// as we need the followID from it in order to trigger -	// timeline invalidation. -	listEntry, err := l.GetListEntryByID( +	// Load list entry into cache to ensure we can perform +	// all necessary cache invalidation hooks on removal. +	entry, err := l.GetListEntryByID(  		// Don't populate the entry;  		// we only want the list ID.  		gtscontext.SetBarebones(ctx), @@ -428,36 +458,39 @@ func (l *listDB) DeleteListEntry(ctx context.Context, id string) error {  	}  	defer func() { +		// Invalidate this list entry upon delete. +		l.state.Caches.GTS.ListEntry().Invalidate("ID", id) +  		// Invalidate the timeline for the list this entry belongs to. -		if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil { -			log.Errorf(ctx, "DeleteListEntry: error invalidating list timeline: %q", err) +		if err := l.state.Timelines.List.RemoveTimeline(ctx, entry.ListID); err != nil { +			log.Errorf(ctx, "error invalidating list timeline: %q", err)  		}  	}() -	if _, err := l.conn.NewDelete(). +	// Finally delete the list entry. +	_, err = l.conn.NewDelete().  		Table("list_entries"). -		Where("? = ?", bun.Ident("id"), listEntry.ID). -		Exec(ctx); err != nil { -		return l.conn.ProcessError(err) -	} - -	return nil +		Where("? = ?", bun.Ident("id"), id). +		Exec(ctx) +	return err  }  func (l *listDB) DeleteListEntriesForFollowID(ctx context.Context, followID string) error { -	// Fetch IDs of all entries that pertain to this follow. -	var listEntryIDs []string +	var entryIDs []string + +	// Fetch entry IDs for follow ID.  	if err := l.conn.  		NewSelect(). -		TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("list_entry")). -		Column("list_entry.id"). -		Where("? = ?", bun.Ident("list_entry.follow_id"), followID). -		Order("list_entry.id DESC"). -		Scan(ctx, &listEntryIDs); err != nil { +		Table("list_entries"). +		Column("id"). +		Where("? = ?", bun.Ident("follow_id"), followID). +		Order("id DESC"). +		Scan(ctx, &entryIDs); err != nil {  		return l.conn.ProcessError(err)  	} -	for _, id := range listEntryIDs { +	for _, id := range entryIDs { +		// Delete each separately to trigger cache invalidations.  		if err := l.DeleteListEntry(ctx, id); err != nil {  			return err  		} @@ -465,3 +498,24 @@ func (l *listDB) DeleteListEntriesForFollowID(ctx context.Context, followID stri  	return nil  } + +// collate will collect the values of type T from an expected slice of length 'len', +// passing the expected index to each call of 'get' and deduplicating the end result. +func collate[T comparable](get func(int) T, len int) []T { +	ts := make([]T, 0, len) +	tm := make(map[T]struct{}, len) + +	for i := 0; i < len; i++ { +		// Get next. +		t := get(i) + +		if _, ok := tm[t]; !ok { +			// New value, add +			// to map + slice. +			ts = append(ts, t) +			tm[t] = struct{}{} +		} +	} + +	return ts +} diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go index 88850e72a..349c1ef43 100644 --- a/internal/db/bundb/relationship_follow.go +++ b/internal/db/bundb/relationship_follow.go @@ -328,7 +328,8 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str  		}  		// Delete each follow from DB. -		if err := r.deleteFollow(ctx, follow.ID); err != nil && !errors.Is(err, db.ErrNoEntries) { +		if err := r.deleteFollow(ctx, follow.ID); err != nil && +			!errors.Is(err, db.ErrNoEntries) {  			return err  		}  	} | 
