diff options
Diffstat (limited to 'internal/db/bundb')
26 files changed, 1018 insertions, 517 deletions
| diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index fdee8cb76..cdb949efa 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -116,7 +116,7 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str  	return a.getAccount(  		ctx, -		"Username.Domain", +		"Username,Domain",  		func(account *gtsmodel.Account) error {  			q := a.db.NewSelect().  				Model(account) @@ -224,7 +224,7 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts  func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, error) {  	// Fetch account from database cache with loader callback -	account, err := a.state.Caches.GTS.Account().Load(lookup, func() (*gtsmodel.Account, error) { +	account, err := a.state.Caches.GTS.Account.LoadOne(lookup, func() (*gtsmodel.Account, error) {  		var account gtsmodel.Account  		// Not cached! Perform database query @@ -325,7 +325,7 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou  }  func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) error { -	return a.state.Caches.GTS.Account().Store(account, func() error { +	return a.state.Caches.GTS.Account.Store(account, func() error {  		// It is safe to run this database transaction within cache.Store  		// as the cache does not attempt a mutex lock until AFTER hook.  		// @@ -354,7 +354,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account  		columns = append(columns, "updated_at")  	} -	return a.state.Caches.GTS.Account().Store(account, func() error { +	return a.state.Caches.GTS.Account.Store(account, func() error {  		// It is safe to run this database transaction within cache.Store  		// as the cache does not attempt a mutex lock until AFTER hook.  		// @@ -393,7 +393,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account  }  func (a *accountDB) DeleteAccount(ctx context.Context, id string) error { -	defer a.state.Caches.GTS.Account().Invalidate("ID", id) +	defer a.state.Caches.GTS.Account.Invalidate("ID", id)  	// Load account into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate @@ -635,6 +635,10 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li  		return nil, err  	} +	if len(statusIDs) == 0 { +		return nil, db.ErrNoEntries +	} +  	// If we're paging up, we still want statuses  	// to be sorted by ID desc, so reverse ids slice.  	// https://zchee.github.io/golang-wiki/SliceTricks/#reversing @@ -644,7 +648,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li  		}  	} -	return a.statusesFromIDs(ctx, statusIDs) +	return a.state.DB.GetStatusesByIDs(ctx, statusIDs)  }  func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, error) { @@ -662,7 +666,11 @@ func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID stri  		return nil, err  	} -	return a.statusesFromIDs(ctx, statusIDs) +	if len(statusIDs) == 0 { +		return nil, db.ErrNoEntries +	} + +	return a.state.DB.GetStatusesByIDs(ctx, statusIDs)  }  func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) { @@ -710,29 +718,9 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,  		return nil, err  	} -	return a.statusesFromIDs(ctx, statusIDs) -} - -func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, error) { -	// Catch case of no statuses early  	if len(statusIDs) == 0 {  		return nil, db.ErrNoEntries  	} -	// Allocate return slice (will be at most len statusIDS) -	statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) - -	for _, id := range statusIDs { -		// Fetch from status from database by ID -		status, err := a.state.DB.GetStatusByID(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error getting status %q: %v", id, err) -			continue -		} - -		// Append to return slice -		statuses = append(statuses, status) -	} - -	return statuses, nil +	return a.state.DB.GetStatusesByIDs(ctx, statusIDs)  } diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go index f7328e275..2e17a0e94 100644 --- a/internal/db/bundb/application.go +++ b/internal/db/bundb/application.go @@ -53,7 +53,7 @@ func (a *applicationDB) GetApplicationByClientID(ctx context.Context, clientID s  }  func (a *applicationDB) getApplication(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Application) error, keyParts ...any) (*gtsmodel.Application, error) { -	return a.state.Caches.GTS.Application().Load(lookup, func() (*gtsmodel.Application, error) { +	return a.state.Caches.GTS.Application.LoadOne(lookup, func() (*gtsmodel.Application, error) {  		var app gtsmodel.Application  		// Not cached! Perform database query. @@ -66,7 +66,7 @@ func (a *applicationDB) getApplication(ctx context.Context, lookup string, dbQue  }  func (a *applicationDB) PutApplication(ctx context.Context, app *gtsmodel.Application) error { -	return a.state.Caches.GTS.Application().Store(app, func() error { +	return a.state.Caches.GTS.Application.Store(app, func() error {  		_, err := a.db.NewInsert().Model(app).Exec(ctx)  		return err  	}) @@ -91,7 +91,7 @@ func (a *applicationDB) DeleteApplicationByClientID(ctx context.Context, clientI  	//  	// Clear application from the cache. -	a.state.Caches.GTS.Application().Invalidate("ClientID", clientID) +	a.state.Caches.GTS.Application.Invalidate("ClientID", clientID)  	return nil  } diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index d9415eff4..048474782 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -258,7 +258,7 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {  			state: state,  		},  		Tag: &tagDB{ -			conn:  db, +			db:    db,  			state: state,  		},  		Thread: &threadDB{ diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index dd626bc0a..2398e52c2 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -51,7 +51,7 @@ func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.Domain  	}  	// Clear the domain allow cache (for later reload) -	d.state.Caches.GTS.DomainAllow().Clear() +	d.state.Caches.GTS.DomainAllow.Clear()  	return nil  } @@ -126,7 +126,7 @@ func (d *domainDB) DeleteDomainAllow(ctx context.Context, domain string) error {  	}  	// Clear the domain allow cache (for later reload) -	d.state.Caches.GTS.DomainAllow().Clear() +	d.state.Caches.GTS.DomainAllow.Clear()  	return nil  } @@ -147,7 +147,7 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain  	}  	// Clear the domain block cache (for later reload) -	d.state.Caches.GTS.DomainBlock().Clear() +	d.state.Caches.GTS.DomainBlock.Clear()  	return nil  } @@ -222,7 +222,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) error {  	}  	// Clear the domain block cache (for later reload) -	d.state.Caches.GTS.DomainBlock().Clear() +	d.state.Caches.GTS.DomainBlock.Clear()  	return nil  } @@ -241,7 +241,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er  	}  	// Check the cache for an explicit domain allow (hydrating the cache with callback if necessary). -	explicitAllow, err := d.state.Caches.GTS.DomainAllow().Matches(domain, func() ([]string, error) { +	explicitAllow, err := d.state.Caches.GTS.DomainAllow.Matches(domain, func() ([]string, error) {  		var domains []string  		// Scan list of all explicitly allowed domains from DB @@ -259,7 +259,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er  	}  	// Check the cache for a domain block (hydrating the cache with callback if necessary) -	explicitBlock, err := d.state.Caches.GTS.DomainBlock().Matches(domain, func() ([]string, error) { +	explicitBlock, err := d.state.Caches.GTS.DomainBlock.Matches(domain, func() ([]string, error) {  		var domains []string  		// Scan list of all blocked domains from DB diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index 34a08b694..31092d0d2 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -21,6 +21,7 @@ import (  	"context"  	"database/sql"  	"errors" +	"slices"  	"strings"  	"time" @@ -30,6 +31,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/superseriousbusiness/gotosocial/internal/util"  	"github.com/uptrace/bun"  	"github.com/uptrace/bun/dialect"  ) @@ -40,7 +42,7 @@ type emojiDB struct {  }  func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) error { -	return e.state.Caches.GTS.Emoji().Store(emoji, func() error { +	return e.state.Caches.GTS.Emoji.Store(emoji, func() error {  		_, err := e.db.NewInsert().Model(emoji).Exec(ctx)  		return err  	}) @@ -54,7 +56,7 @@ func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, column  	}  	// Update the emoji model in the database. -	return e.state.Caches.GTS.Emoji().Store(emoji, func() error { +	return e.state.Caches.GTS.Emoji.Store(emoji, func() error {  		_, err := e.db.  			NewUpdate().  			Model(emoji). @@ -74,21 +76,21 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {  	defer func() {  		// Invalidate cached emoji.  		e.state.Caches.GTS. -			Emoji(). +			Emoji.  			Invalidate("ID", id) -		for _, id := range accountIDs { +		for _, accountID := range accountIDs {  			// Invalidate cached account.  			e.state.Caches.GTS. -				Account(). -				Invalidate("ID", id) +				Account. +				Invalidate("ID", accountID)  		} -		for _, id := range statusIDs { +		for _, statusID := range statusIDs {  			// Invalidate cached account.  			e.state.Caches.GTS. -				Status(). -				Invalidate("ID", id) +				Status. +				Invalidate("ID", statusID)  		}  	}() @@ -129,26 +131,28 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {  			return err  		} -		for _, id := range statusIDs { +		for _, statusID := range statusIDs {  			var emojiIDs []string  			// Select statuses with ID.  			if _, err := tx.NewSelect().  				Table("statuses").  				Column("emojis"). -				Where("? = ?", bun.Ident("id"), id). +				Where("? = ?", bun.Ident("id"), statusID).  				Exec(ctx); err != nil &&  				err != sql.ErrNoRows {  				return err  			} -			// Drop ID from account emojis. -			emojiIDs = dropID(emojiIDs, id) +			// Delete all instances of this emoji ID from status emojis. +			emojiIDs = slices.DeleteFunc(emojiIDs, func(emojiID string) bool { +				return emojiID == id +			})  			// Update status emoji IDs.  			if _, err := tx.NewUpdate().  				Table("statuses"). -				Where("? = ?", bun.Ident("id"), id). +				Where("? = ?", bun.Ident("id"), statusID).  				Set("emojis = ?", emojiIDs).  				Exec(ctx); err != nil &&  				err != sql.ErrNoRows { @@ -156,26 +160,28 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {  			}  		} -		for _, id := range accountIDs { +		for _, accountID := range accountIDs {  			var emojiIDs []string  			// Select account with ID.  			if _, err := tx.NewSelect().  				Table("accounts").  				Column("emojis"). -				Where("? = ?", bun.Ident("id"), id). +				Where("? = ?", bun.Ident("id"), accountID).  				Exec(ctx); err != nil &&  				err != sql.ErrNoRows {  				return err  			} -			// Drop ID from account emojis. -			emojiIDs = dropID(emojiIDs, id) +			// Delete all instances of this emoji ID from account emojis. +			emojiIDs = slices.DeleteFunc(emojiIDs, func(emojiID string) bool { +				return emojiID == id +			})  			// Update account emoji IDs.  			if _, err := tx.NewUpdate().  				Table("accounts"). -				Where("? = ?", bun.Ident("id"), id). +				Where("? = ?", bun.Ident("id"), accountID).  				Set("emojis = ?", emojiIDs).  				Exec(ctx); err != nil &&  				err != sql.ErrNoRows { @@ -431,7 +437,7 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj  func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, error) {  	return e.getEmoji(  		ctx, -		"Shortcode.Domain", +		"Shortcode,Domain",  		func(emoji *gtsmodel.Emoji) error {  			q := e.db.  				NewSelect(). @@ -468,7 +474,7 @@ func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string  }  func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) error { -	return e.state.Caches.GTS.EmojiCategory().Store(emojiCategory, func() error { +	return e.state.Caches.GTS.EmojiCategory.Store(emojiCategory, func() error {  		_, err := e.db.NewInsert().Model(emojiCategory).Exec(ctx)  		return err  	}) @@ -520,7 +526,7 @@ func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gts  func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, error) {  	// Fetch emoji from database cache with loader callback -	emoji, err := e.state.Caches.GTS.Emoji().Load(lookup, func() (*gtsmodel.Emoji, error) { +	emoji, err := e.state.Caches.GTS.Emoji.LoadOne(lookup, func() (*gtsmodel.Emoji, error) {  		var emoji gtsmodel.Emoji  		// Not cached! Perform database query @@ -568,28 +574,72 @@ func (e *emojiDB) PopulateEmoji(ctx context.Context, emoji *gtsmodel.Emoji) erro  	return errs.Combine()  } -func (e *emojiDB) GetEmojisByIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, error) { -	if len(emojiIDs) == 0 { +func (e *emojiDB) GetEmojisByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Emoji, error) { +	if len(ids) == 0 {  		return nil, db.ErrNoEntries  	} -	emojis := make([]*gtsmodel.Emoji, 0, len(emojiIDs)) +	// Preallocate at-worst possible length. +	uncached := make([]string, 0, len(ids)) -	for _, id := range emojiIDs { -		emoji, err := e.GetEmojiByID(ctx, id) -		if err != nil { -			log.Errorf(ctx, "emojisFromIDs: error getting emoji %q: %v", id, err) -			continue -		} +	// Load all emoji IDs via cache loader callbacks. +	emojis, err := e.state.Caches.GTS.Emoji.Load("ID", + +		// Load cached + check for uncached. +		func(load func(keyParts ...any) bool) { +			for _, id := range ids { +				if !load(id) { +					uncached = append(uncached, id) +				} +			} +		}, -		emojis = append(emojis, emoji) +		// Uncached emoji loader function. +		func() ([]*gtsmodel.Emoji, error) { +			// Preallocate expected length of uncached emojis. +			emojis := make([]*gtsmodel.Emoji, 0, len(uncached)) + +			// Perform database query scanning +			// the remaining (uncached) IDs. +			if err := e.db.NewSelect(). +				Model(&emojis). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). +				Scan(ctx); err != nil { +				return nil, err +			} + +			return emojis, nil +		}, +	) +	if err != nil { +		return nil, err +	} + +	// Reorder the emojis by their +	// IDs to ensure in correct order. +	getID := func(e *gtsmodel.Emoji) string { return e.ID } +	util.OrderBy(emojis, ids, getID) + +	if gtscontext.Barebones(ctx) { +		// no need to fully populate. +		return emojis, nil  	} +	// Populate all loaded emojis, removing those we fail to +	// populate (removes needing so many nil checks everywhere). +	emojis = slices.DeleteFunc(emojis, func(emoji *gtsmodel.Emoji) bool { +		if err := e.PopulateEmoji(ctx, emoji); err != nil { +			log.Errorf(ctx, "error populating emoji %s: %v", emoji.ID, err) +			return true +		} +		return false +	}) +  	return emojis, nil  }  func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, error) { -	return e.state.Caches.GTS.EmojiCategory().Load(lookup, func() (*gtsmodel.EmojiCategory, error) { +	return e.state.Caches.GTS.EmojiCategory.LoadOne(lookup, func() (*gtsmodel.EmojiCategory, error) {  		var category gtsmodel.EmojiCategory  		// Not cached! Perform database query @@ -601,36 +651,51 @@ func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery f  	}, keyParts...)  } -func (e *emojiDB) GetEmojiCategoriesByIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, error) { -	if len(emojiCategoryIDs) == 0 { +func (e *emojiDB) GetEmojiCategoriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.EmojiCategory, error) { +	if len(ids) == 0 {  		return nil, db.ErrNoEntries  	} -	emojiCategories := make([]*gtsmodel.EmojiCategory, 0, len(emojiCategoryIDs)) +	// Preallocate at-worst possible length. +	uncached := make([]string, 0, len(ids)) -	for _, id := range emojiCategoryIDs { -		emojiCategory, err := e.GetEmojiCategory(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error getting emoji category %q: %v", id, err) -			continue -		} +	// Load all category IDs via cache loader callbacks. +	categories, err := e.state.Caches.GTS.EmojiCategory.Load("ID", -		emojiCategories = append(emojiCategories, emojiCategory) -	} +		// Load cached + check for uncached. +		func(load func(keyParts ...any) bool) { +			for _, id := range ids { +				if !load(id) { +					uncached = append(uncached, id) +				} +			} +		}, -	return emojiCategories, nil -} +		// Uncached emoji loader function. +		func() ([]*gtsmodel.EmojiCategory, error) { +			// Preallocate expected length of uncached categories. +			categories := make([]*gtsmodel.EmojiCategory, 0, len(uncached)) + +			// Perform database query scanning +			// the remaining (uncached) IDs. +			if err := e.db.NewSelect(). +				Model(&categories). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). +				Scan(ctx); err != nil { +				return nil, err +			} -// dropIDs drops given ID string from IDs slice. -func dropID(ids []string, id string) []string { -	for i := 0; i < len(ids); { -		if ids[i] == id { -			// Remove this reference. -			copy(ids[i:], ids[i+1:]) -			ids = ids[:len(ids)-1] -			continue -		} -		i++ +			return categories, nil +		}, +	) +	if err != nil { +		return nil, err  	} -	return ids + +	// Reorder the categories by their +	// IDs to ensure in correct order. +	getID := func(c *gtsmodel.EmojiCategory) string { return c.ID } +	util.OrderBy(categories, ids, getID) + +	return categories, nil  } diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go index 567a44ee2..d506e0a31 100644 --- a/internal/db/bundb/instance.go +++ b/internal/db/bundb/instance.go @@ -143,7 +143,7 @@ func (i *instanceDB) GetInstanceByID(ctx context.Context, id string) (*gtsmodel.  func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Instance) error, keyParts ...any) (*gtsmodel.Instance, error) {  	// Fetch instance from database cache with loader callback -	instance, err := i.state.Caches.GTS.Instance().Load(lookup, func() (*gtsmodel.Instance, error) { +	instance, err := i.state.Caches.GTS.Instance.LoadOne(lookup, func() (*gtsmodel.Instance, error) {  		var instance gtsmodel.Instance  		// Not cached! Perform database query. @@ -219,7 +219,7 @@ func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instanc  		return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err)  	} -	return i.state.Caches.GTS.Instance().Store(instance, func() error { +	return i.state.Caches.GTS.Instance.Store(instance, func() error {  		_, err := i.db.NewInsert().Model(instance).Exec(ctx)  		return err  	}) @@ -239,7 +239,7 @@ func (i *instanceDB) UpdateInstance(ctx context.Context, instance *gtsmodel.Inst  		columns = append(columns, "updated_at")  	} -	return i.state.Caches.GTS.Instance().Store(instance, func() error { +	return i.state.Caches.GTS.Instance.Store(instance, func() error {  		_, err := i.db.  			NewUpdate().  			Model(instance). diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go index 7a117670a..5f95d3c24 100644 --- a/internal/db/bundb/list.go +++ b/internal/db/bundb/list.go @@ -21,6 +21,7 @@ import (  	"context"  	"errors"  	"fmt" +	"slices"  	"time"  	"github.com/superseriousbusiness/gotosocial/internal/db" @@ -29,6 +30,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/superseriousbusiness/gotosocial/internal/util"  	"github.com/uptrace/bun"  ) @@ -56,7 +58,7 @@ func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, er  }  func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmodel.List) error, keyParts ...any) (*gtsmodel.List, error) { -	list, err := l.state.Caches.GTS.List().Load(lookup, func() (*gtsmodel.List, error) { +	list, err := l.state.Caches.GTS.List.LoadOne(lookup, func() (*gtsmodel.List, error) {  		var list gtsmodel.List  		// Not cached! Perform database query. @@ -100,18 +102,8 @@ func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]  		return nil, nil  	} -	// Select each list using its ID to ensure cache used. -	lists := make([]*gtsmodel.List, 0, len(listIDs)) -	for _, id := range listIDs { -		list, err := l.state.DB.GetListByID(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error fetching list %q: %v", id, err) -			continue -		} -		lists = append(lists, list) -	} - -	return lists, nil +	// Return lists by their IDs. +	return l.GetListsByIDs(ctx, listIDs)  }  func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error { @@ -147,7 +139,7 @@ func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {  }  func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error { -	return l.state.Caches.GTS.List().Store(list, func() error { +	return l.state.Caches.GTS.List.Store(list, func() error {  		_, err := l.db.NewInsert().Model(list).Exec(ctx)  		return err  	}) @@ -162,7 +154,7 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ..  	defer func() {  		// Invalidate all entries for this list ID. -		l.state.Caches.GTS.ListEntry().Invalidate("ListID", list.ID) +		l.state.Caches.GTS.ListEntry.Invalidate("ListID", list.ID)  		// Invalidate this entire list's timeline.  		if err := l.state.Timelines.List.RemoveTimeline(ctx, list.ID); err != nil { @@ -170,7 +162,7 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ..  		}  	}() -	return l.state.Caches.GTS.List().Store(list, func() error { +	return l.state.Caches.GTS.List.Store(list, func() error {  		_, err := l.db.NewUpdate().  			Model(list).  			Where("? = ?", bun.Ident("list.id"), list.ID). @@ -198,7 +190,7 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error {  	defer func() {  		// Invalidate this list from cache. -		l.state.Caches.GTS.List().Invalidate("ID", id) +		l.state.Caches.GTS.List.Invalidate("ID", id)  		// Invalidate this entire list's timeline.  		if err := l.state.Timelines.List.RemoveTimeline(ctx, id); err != nil { @@ -243,7 +235,7 @@ func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.Lis  }  func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) { -	listEntry, err := l.state.Caches.GTS.ListEntry().Load(lookup, func() (*gtsmodel.ListEntry, error) { +	listEntry, err := l.state.Caches.GTS.ListEntry.LoadOne(lookup, func() (*gtsmodel.ListEntry, error) {  		var listEntry gtsmodel.ListEntry  		// Not cached! Perform database query. @@ -344,18 +336,128 @@ func (l *listDB) GetListEntries(ctx context.Context,  		}  	} -	// Select each list entry using its ID to ensure cache used. -	listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs)) -	for _, id := range entryIDs { -		listEntry, err := l.state.DB.GetListEntryByID(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error fetching list entry %q: %v", id, err) -			continue +	// Return list entries by their IDs. +	return l.GetListEntriesByIDs(ctx, entryIDs) +} + +func (l *listDB) GetListsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.List, error) { +	// Preallocate at-worst possible length. +	uncached := make([]string, 0, len(ids)) + +	// Load all list IDs via cache loader callbacks. +	lists, err := l.state.Caches.GTS.List.Load("ID", + +		// Load cached + check for uncached. +		func(load func(keyParts ...any) bool) { +			for _, id := range ids { +				if !load(id) { +					uncached = append(uncached, id) +				} +			} +		}, + +		// Uncached list loader function. +		func() ([]*gtsmodel.List, error) { +			// Preallocate expected length of uncached lists. +			lists := make([]*gtsmodel.List, 0, len(uncached)) + +			// Perform database query scanning +			// the remaining (uncached) IDs. +			if err := l.db.NewSelect(). +				Model(&lists). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). +				Scan(ctx); err != nil { +				return nil, err +			} + +			return lists, nil +		}, +	) +	if err != nil { +		return nil, err +	} + +	// Reorder the lists by their +	// IDs to ensure in correct order. +	getID := func(l *gtsmodel.List) string { return l.ID } +	util.OrderBy(lists, ids, getID) + +	if gtscontext.Barebones(ctx) { +		// no need to fully populate. +		return lists, nil +	} + +	// Populate all loaded lists, removing those we fail to +	// populate (removes needing so many nil checks everywhere). +	lists = slices.DeleteFunc(lists, func(list *gtsmodel.List) bool { +		if err := l.PopulateList(ctx, list); err != nil { +			log.Errorf(ctx, "error populating list %s: %v", list.ID, err) +			return true  		} -		listEntries = append(listEntries, listEntry) +		return false +	}) + +	return lists, nil +} + +func (l *listDB) GetListEntriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.ListEntry, error) { +	// Preallocate at-worst possible length. +	uncached := make([]string, 0, len(ids)) + +	// Load all entry IDs via cache loader callbacks. +	entries, err := l.state.Caches.GTS.ListEntry.Load("ID", + +		// Load cached + check for uncached. +		func(load func(keyParts ...any) bool) { +			for _, id := range ids { +				if !load(id) { +					uncached = append(uncached, id) +				} +			} +		}, + +		// Uncached entry loader function. +		func() ([]*gtsmodel.ListEntry, error) { +			// Preallocate expected length of uncached entries. +			entries := make([]*gtsmodel.ListEntry, 0, len(uncached)) + +			// Perform database query scanning +			// the remaining (uncached) IDs. +			if err := l.db.NewSelect(). +				Model(&entries). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). +				Scan(ctx); err != nil { +				return nil, err +			} + +			return entries, nil +		}, +	) +	if err != nil { +		return nil, err  	} -	return listEntries, nil +	// Reorder the entries by their +	// IDs to ensure in correct order. +	getID := func(e *gtsmodel.ListEntry) string { return e.ID } +	util.OrderBy(entries, ids, getID) + +	if gtscontext.Barebones(ctx) { +		// no need to fully populate. +		return entries, nil +	} + +	// Populate all loaded entries, removing those we fail to +	// populate (removes needing so many nil checks everywhere). +	entries = slices.DeleteFunc(entries, func(entry *gtsmodel.ListEntry) bool { +		if err := l.PopulateListEntry(ctx, entry); err != nil { +			log.Errorf(ctx, "error populating entry %s: %v", entry.ID, err) +			return true +		} +		return false +	}) + +	return entries, nil  }  func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) { @@ -376,18 +478,8 @@ func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string)  		return nil, nil  	} -	// Select each list entry using its ID to ensure cache used. -	listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs)) -	for _, id := range entryIDs { -		listEntry, err := l.state.DB.GetListEntryByID(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error fetching list entry %q: %v", id, err) -			continue -		} -		listEntries = append(listEntries, listEntry) -	} - -	return listEntries, nil +	// Return list entries by their IDs. +	return l.GetListEntriesByIDs(ctx, entryIDs)  }  func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error { @@ -409,10 +501,10 @@ func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.List  func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEntry) error {  	defer func() { -		// Collect unique list IDs from the entries. -		listIDs := collate(func(i int) string { -			return entries[i].ListID -		}, len(entries)) +		// Collect unique list IDs from the provided entries. +		listIDs := util.Collate(entries, func(e *gtsmodel.ListEntry) string { +			return e.ListID +		})  		for _, id := range listIDs {  			// Invalidate the timeline for the list this entry belongs to. @@ -426,7 +518,7 @@ func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEnt  	return l.db.RunInTx(ctx, func(tx Tx) error {  		for _, entry := range entries {  			entry := entry // rescope -			if err := l.state.Caches.GTS.ListEntry().Store(entry, func() error { +			if err := l.state.Caches.GTS.ListEntry.Store(entry, func() error {  				_, err := tx.  					NewInsert().  					Model(entry). @@ -459,7 +551,7 @@ func (l *listDB) DeleteListEntry(ctx context.Context, id string) error {  	defer func() {  		// Invalidate this list entry upon delete. -		l.state.Caches.GTS.ListEntry().Invalidate("ID", id) +		l.state.Caches.GTS.ListEntry.Invalidate("ID", id)  		// Invalidate the timeline for the list this entry belongs to.  		if err := l.state.Timelines.List.RemoveTimeline(ctx, entry.ListID); err != nil { @@ -514,24 +606,3 @@ func (l *listDB) ListIncludesAccount(ctx context.Context, listID string, account  	return exists, err  } - -// collate will collect the values of type T from an expected slice of length 'len', -// passing the expected index to each call of 'get' and deduplicating the end result. -func collate[T comparable](get func(int) T, len int) []T { -	ts := make([]T, 0, len) -	tm := make(map[T]struct{}, len) - -	for i := 0; i < len; i++ { -		// Get next. -		t := get(i) - -		if _, ok := tm[t]; !ok { -			// New value, add -			// to map + slice. -			ts = append(ts, t) -			tm[t] = struct{}{} -		} -	} - -	return ts -} diff --git a/internal/db/bundb/marker.go b/internal/db/bundb/marker.go index 5d365e08a..b1dedb4f1 100644 --- a/internal/db/bundb/marker.go +++ b/internal/db/bundb/marker.go @@ -39,8 +39,8 @@ type markerDB struct {  */  func (m *markerDB) GetMarker(ctx context.Context, accountID string, name gtsmodel.MarkerName) (*gtsmodel.Marker, error) { -	marker, err := m.state.Caches.GTS.Marker().Load( -		"AccountID.Name", +	marker, err := m.state.Caches.GTS.Marker.LoadOne( +		"AccountID,Name",  		func() (*gtsmodel.Marker, error) {  			var marker gtsmodel.Marker @@ -52,9 +52,7 @@ func (m *markerDB) GetMarker(ctx context.Context, accountID string, name gtsmode  			}  			return &marker, nil -		}, -		accountID, -		name, +		}, accountID, name,  	)  	if err != nil {  		return nil, err // already processed @@ -74,7 +72,7 @@ func (m *markerDB) UpdateMarker(ctx context.Context, marker *gtsmodel.Marker) er  		marker.Version = prevMarker.Version + 1  	} -	return m.state.Caches.GTS.Marker().Store(marker, func() error { +	return m.state.Caches.GTS.Marker.Store(marker, func() error {  		if prevMarker == nil {  			if _, err := m.db.NewInsert().  				Model(marker). diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go index a2603eacc..ce3c90083 100644 --- a/internal/db/bundb/media.go +++ b/internal/db/bundb/media.go @@ -20,14 +20,15 @@ package bundb  import (  	"context"  	"errors" +	"slices"  	"time"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/superseriousbusiness/gotosocial/internal/util"  	"github.com/uptrace/bun"  ) @@ -51,25 +52,52 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M  }  func (m *mediaDB) GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error) { -	attachments := make([]*gtsmodel.MediaAttachment, 0, len(ids)) - -	for _, id := range ids { -		// Attempt fetch from DB -		attachment, err := m.GetAttachmentByID(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error getting attachment %q: %v", id, err) -			continue -		} +	// Preallocate at-worst possible length. +	uncached := make([]string, 0, len(ids)) + +	// Load all media IDs via cache loader callbacks. +	media, err := m.state.Caches.GTS.Media.Load("ID", + +		// Load cached + check for uncached. +		func(load func(keyParts ...any) bool) { +			for _, id := range ids { +				if !load(id) { +					uncached = append(uncached, id) +				} +			} +		}, + +		// Uncached media loader function. +		func() ([]*gtsmodel.MediaAttachment, error) { +			// Preallocate expected length of uncached media attachments. +			media := make([]*gtsmodel.MediaAttachment, 0, len(uncached)) + +			// Perform database query scanning +			// the remaining (uncached) IDs. +			if err := m.db.NewSelect(). +				Model(&media). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). +				Scan(ctx); err != nil { +				return nil, err +			} -		// Append attachment -		attachments = append(attachments, attachment) +			return media, nil +		}, +	) +	if err != nil { +		return nil, err  	} -	return attachments, nil +	// Reorder the media by their +	// IDs to ensure in correct order. +	getID := func(m *gtsmodel.MediaAttachment) string { return m.ID } +	util.OrderBy(media, ids, getID) + +	return media, nil  }  func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, error) { -	return m.state.Caches.GTS.Media().Load(lookup, func() (*gtsmodel.MediaAttachment, error) { +	return m.state.Caches.GTS.Media.LoadOne(lookup, func() (*gtsmodel.MediaAttachment, error) {  		var attachment gtsmodel.MediaAttachment  		// Not cached! Perform database query @@ -82,7 +110,7 @@ func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func  }  func (m *mediaDB) PutAttachment(ctx context.Context, media *gtsmodel.MediaAttachment) error { -	return m.state.Caches.GTS.Media().Store(media, func() error { +	return m.state.Caches.GTS.Media.Store(media, func() error {  		_, err := m.db.NewInsert().Model(media).Exec(ctx)  		return err  	}) @@ -95,7 +123,7 @@ func (m *mediaDB) UpdateAttachment(ctx context.Context, media *gtsmodel.MediaAtt  		columns = append(columns, "updated_at")  	} -	return m.state.Caches.GTS.Media().Store(media, func() error { +	return m.state.Caches.GTS.Media.Store(media, func() error {  		_, err := m.db.NewUpdate().  			Model(media).  			Where("? = ?", bun.Ident("media_attachment.id"), media.ID). @@ -119,7 +147,7 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {  	}  	// On return, ensure that media with ID is invalidated. -	defer m.state.Caches.GTS.Media().Invalidate("ID", id) +	defer m.state.Caches.GTS.Media.Invalidate("ID", id)  	// Delete media attachment in new transaction.  	err = m.db.RunInTx(ctx, func(tx Tx) error { @@ -171,8 +199,12 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {  				return gtserror.Newf("error selecting status: %w", err)  			} -			if updatedIDs := dropID(status.AttachmentIDs, id); // nocollapse -			len(updatedIDs) != len(status.AttachmentIDs) { +			// Delete all instances of this deleted media ID from status attachments. +			updatedIDs := slices.DeleteFunc(status.AttachmentIDs, func(s string) bool { +				return s == id +			}) + +			if len(updatedIDs) != len(status.AttachmentIDs) {  				// Note: this handles not found.  				//  				// Attachments changed, update the status. diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go index 30a20b0c1..b069423bb 100644 --- a/internal/db/bundb/mention.go +++ b/internal/db/bundb/mention.go @@ -20,6 +20,7 @@ package bundb  import (  	"context"  	"errors" +	"slices"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" @@ -27,6 +28,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/superseriousbusiness/gotosocial/internal/util"  	"github.com/uptrace/bun"  ) @@ -36,7 +38,7 @@ type mentionDB struct {  }  func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, error) { -	mention, err := m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) { +	mention, err := m.state.Caches.GTS.Mention.LoadOne("ID", func() (*gtsmodel.Mention, error) {  		var mention gtsmodel.Mention  		q := m.db. @@ -63,21 +65,64 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio  }  func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, error) { -	mentions := make([]*gtsmodel.Mention, 0, len(ids)) +	// Preallocate at-worst possible length. +	uncached := make([]string, 0, len(ids)) + +	// Load all mention IDs via cache loader callbacks. +	mentions, err := m.state.Caches.GTS.Mention.Load("ID", + +		// Load cached + check for uncached. +		func(load func(keyParts ...any) bool) { +			for _, id := range ids { +				if !load(id) { +					uncached = append(uncached, id) +				} +			} +		}, + +		// Uncached mention loader function. +		func() ([]*gtsmodel.Mention, error) { +			// Preallocate expected length of uncached mentions. +			mentions := make([]*gtsmodel.Mention, 0, len(uncached)) + +			// Perform database query scanning +			// the remaining (uncached) IDs. +			if err := m.db.NewSelect(). +				Model(&mentions). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). +				Scan(ctx); err != nil { +				return nil, err +			} + +			return mentions, nil +		}, +	) +	if err != nil { +		return nil, err +	} -	for _, id := range ids { -		// Attempt fetch from DB -		mention, err := m.GetMention(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error getting mention %q: %v", id, err) -			continue -		} +	// Reorder the mentions by their +	// IDs to ensure in correct order. +	getID := func(m *gtsmodel.Mention) string { return m.ID } +	util.OrderBy(mentions, ids, getID) -		// Append mention -		mentions = append(mentions, mention) +	if gtscontext.Barebones(ctx) { +		// no need to fully populate. +		return mentions, nil  	} +	// Populate all loaded mentions, removing those we fail to +	// populate (removes needing so many nil checks everywhere). +	mentions = slices.DeleteFunc(mentions, func(mention *gtsmodel.Mention) bool { +		if err := m.PopulateMention(ctx, mention); err != nil { +			log.Errorf(ctx, "error populating mention %s: %v", mention.ID, err) +			return true +		} +		return false +	}) +  	return mentions, nil +  }  func (m *mentionDB) PopulateMention(ctx context.Context, mention *gtsmodel.Mention) (err error) { @@ -120,14 +165,14 @@ func (m *mentionDB) PopulateMention(ctx context.Context, mention *gtsmodel.Menti  }  func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error { -	return m.state.Caches.GTS.Mention().Store(mention, func() error { +	return m.state.Caches.GTS.Mention.Store(mention, func() error {  		_, err := m.db.NewInsert().Model(mention).Exec(ctx)  		return err  	})  }  func (m *mentionDB) DeleteMentionByID(ctx context.Context, id string) error { -	defer m.state.Caches.GTS.Mention().Invalidate("ID", id) +	defer m.state.Caches.GTS.Mention.Invalidate("ID", id)  	// Load mention into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index 7532b9993..ed34222fb 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -20,6 +20,7 @@ package bundb  import (  	"context"  	"errors" +	"slices"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" @@ -28,6 +29,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/id"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/superseriousbusiness/gotosocial/internal/util"  	"github.com/uptrace/bun"  ) @@ -37,18 +39,17 @@ type notificationDB struct {  }  func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error) { -	return n.state.Caches.GTS.Notification().Load("ID", func() (*gtsmodel.Notification, error) { -		var notif gtsmodel.Notification - -		q := n.db.NewSelect(). -			Model(¬if). -			Where("? = ?", bun.Ident("notification.id"), id) -		if err := q.Scan(ctx); err != nil { -			return nil, err -		} - -		return ¬if, nil -	}, id) +	return n.getNotification( +		ctx, +		"ID", +		func(notif *gtsmodel.Notification) error { +			return n.db.NewSelect(). +				Model(notif). +				Where("? = ?", bun.Ident("id"), id). +				Scan(ctx) +		}, +		id, +	)  }  func (n *notificationDB) GetNotification( @@ -58,42 +59,113 @@ func (n *notificationDB) GetNotification(  	originAccountID string,  	statusID string,  ) (*gtsmodel.Notification, error) { -	notif, err := n.state.Caches.GTS.Notification().Load("NotificationType.TargetAccountID.OriginAccountID.StatusID", func() (*gtsmodel.Notification, error) { -		var notif gtsmodel.Notification +	return n.getNotification( +		ctx, +		"NotificationType,TargetAccountID,OriginAccountID,StatusID", +		func(notif *gtsmodel.Notification) error { +			return n.db.NewSelect(). +				Model(notif). +				Where("? = ?", bun.Ident("notification_type"), notificationType). +				Where("? = ?", bun.Ident("target_account_id"), targetAccountID). +				Where("? = ?", bun.Ident("origin_account_id"), originAccountID). +				Where("? = ?", bun.Ident("status_id"), statusID). +				Scan(ctx) +		}, +		notificationType, targetAccountID, originAccountID, statusID, +	) +} -		q := n.db.NewSelect(). -			Model(¬if). -			Where("? = ?", bun.Ident("notification_type"), notificationType). -			Where("? = ?", bun.Ident("target_account_id"), targetAccountID). -			Where("? = ?", bun.Ident("origin_account_id"), originAccountID). -			Where("? = ?", bun.Ident("status_id"), statusID) +func (n *notificationDB) getNotification(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Notification) error, keyParts ...any) (*gtsmodel.Notification, error) { +	// Fetch notification from cache with loader callback +	notif, err := n.state.Caches.GTS.Notification.LoadOne(lookup, func() (*gtsmodel.Notification, error) { +		var notif gtsmodel.Notification -		if err := q.Scan(ctx); err != nil { +		// Not cached! Perform database query +		if err := dbQuery(¬if); err != nil {  			return nil, err  		}  		return ¬if, nil -	}, notificationType, targetAccountID, originAccountID, statusID) +	}, keyParts...)  	if err != nil {  		return nil, err  	}  	if gtscontext.Barebones(ctx) { -		// no need to fully populate. +		// Only a barebones model was requested.  		return notif, nil  	} -	// Further populate the notif fields where applicable. -	if err := n.PopulateNotification(ctx, notif); err != nil { +	if err := n.state.DB.PopulateNotification(ctx, notif); err != nil {  		return nil, err  	}  	return notif, nil  } +func (n *notificationDB) GetNotificationsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Notification, error) { +	// Preallocate at-worst possible length. +	uncached := make([]string, 0, len(ids)) + +	// Load all notif IDs via cache loader callbacks. +	notifs, err := n.state.Caches.GTS.Notification.Load("ID", + +		// Load cached + check for uncached. +		func(load func(keyParts ...any) bool) { +			for _, id := range ids { +				if !load(id) { +					uncached = append(uncached, id) +				} +			} +		}, + +		// Uncached notification loader function. +		func() ([]*gtsmodel.Notification, error) { +			// Preallocate expected length of uncached notifications. +			notifs := make([]*gtsmodel.Notification, 0, len(uncached)) + +			// Perform database query scanning +			// the remaining (uncached) IDs. +			if err := n.db.NewSelect(). +				Model(¬ifs). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). +				Scan(ctx); err != nil { +				return nil, err +			} + +			return notifs, nil +		}, +	) +	if err != nil { +		return nil, err +	} + +	// Reorder the notifs by their +	// IDs to ensure in correct order. +	getID := func(n *gtsmodel.Notification) string { return n.ID } +	util.OrderBy(notifs, ids, getID) + +	if gtscontext.Barebones(ctx) { +		// no need to fully populate. +		return notifs, nil +	} + +	// Populate all loaded notifs, removing those we fail to +	// populate (removes needing so many nil checks everywhere). +	notifs = slices.DeleteFunc(notifs, func(notif *gtsmodel.Notification) bool { +		if err := n.PopulateNotification(ctx, notif); err != nil { +			log.Errorf(ctx, "error populating notif %s: %v", notif.ID, err) +			return true +		} +		return false +	}) + +	return notifs, nil +} +  func (n *notificationDB) PopulateNotification(ctx context.Context, notif *gtsmodel.Notification) error {  	var ( -		errs = gtserror.NewMultiError(2) +		errs gtserror.MultiError  		err  error  	) @@ -211,31 +283,19 @@ func (n *notificationDB) GetAccountNotifications(  		}  	} -	notifs := make([]*gtsmodel.Notification, 0, len(notifIDs)) -	for _, id := range notifIDs { -		// Attempt fetch from DB -		notif, err := n.GetNotificationByID(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error fetching notification %q: %v", id, err) -			continue -		} - -		// Append notification -		notifs = append(notifs, notif) -	} - -	return notifs, nil +	// Fetch notification models by their IDs. +	return n.GetNotificationsByIDs(ctx, notifIDs)  }  func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error { -	return n.state.Caches.GTS.Notification().Store(notif, func() error { +	return n.state.Caches.GTS.Notification.Store(notif, func() error {  		_, err := n.db.NewInsert().Model(notif).Exec(ctx)  		return err  	})  }  func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) error { -	defer n.state.Caches.GTS.Notification().Invalidate("ID", id) +	defer n.state.Caches.GTS.Notification.Invalidate("ID", id)  	// Load notif into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate @@ -288,7 +348,7 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string  	defer func() {  		// Invalidate all IDs on return.  		for _, id := range notifIDs { -			n.state.Caches.GTS.Notification().Invalidate("ID", id) +			n.state.Caches.GTS.Notification.Invalidate("ID", id)  		}  	}() @@ -326,7 +386,7 @@ func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statu  	defer func() {  		// Invalidate all IDs on return.  		for _, id := range notifIDs { -			n.state.Caches.GTS.Notification().Invalidate("ID", id) +			n.state.Caches.GTS.Notification.Invalidate("ID", id)  		}  	}() diff --git a/internal/db/bundb/poll.go b/internal/db/bundb/poll.go index 3e77fb6c5..0dfb15621 100644 --- a/internal/db/bundb/poll.go +++ b/internal/db/bundb/poll.go @@ -20,6 +20,7 @@ package bundb  import (  	"context"  	"errors" +	"slices"  	"time"  	"github.com/superseriousbusiness/gotosocial/internal/db" @@ -28,6 +29,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/superseriousbusiness/gotosocial/internal/util"  	"github.com/uptrace/bun"  ) @@ -52,7 +54,7 @@ func (p *pollDB) GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, er  func (p *pollDB) getPoll(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Poll) error, keyParts ...any) (*gtsmodel.Poll, error) {  	// Fetch poll from database cache with loader callback -	poll, err := p.state.Caches.GTS.Poll().Load(lookup, func() (*gtsmodel.Poll, error) { +	poll, err := p.state.Caches.GTS.Poll.LoadOne(lookup, func() (*gtsmodel.Poll, error) {  		var poll gtsmodel.Poll  		// Not cached! Perform database query. @@ -140,7 +142,7 @@ func (p *pollDB) PutPoll(ctx context.Context, poll *gtsmodel.Poll) error {  	// is non nil and set.  	poll.CheckVotes() -	return p.state.Caches.GTS.Poll().Store(poll, func() error { +	return p.state.Caches.GTS.Poll.Store(poll, func() error {  		_, err := p.db.NewInsert().Model(poll).Exec(ctx)  		return err  	}) @@ -151,7 +153,7 @@ func (p *pollDB) UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...st  	// is non nil and set.  	poll.CheckVotes() -	return p.state.Caches.GTS.Poll().Store(poll, func() error { +	return p.state.Caches.GTS.Poll.Store(poll, func() error {  		return p.db.RunInTx(ctx, func(tx Tx) error {  			// Update the status' "updated_at" field.  			if _, err := tx.NewUpdate(). @@ -184,8 +186,8 @@ func (p *pollDB) DeletePollByID(ctx context.Context, id string) error {  	}  	// Invalidate poll by ID from cache. -	p.state.Caches.GTS.Poll().Invalidate("ID", id) -	p.state.Caches.GTS.PollVoteIDs().Invalidate(id) +	p.state.Caches.GTS.Poll.Invalidate("ID", id) +	p.state.Caches.GTS.PollVoteIDs.Invalidate(id)  	return nil  } @@ -207,7 +209,7 @@ func (p *pollDB) GetPollVoteByID(ctx context.Context, id string) (*gtsmodel.Poll  func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID string) (*gtsmodel.PollVote, error) {  	return p.getPollVote(  		ctx, -		"PollID.AccountID", +		"PollID,AccountID",  		func(vote *gtsmodel.PollVote) error {  			return p.db.NewSelect().  				Model(vote). @@ -222,7 +224,7 @@ func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID str  func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.PollVote) error, keyParts ...any) (*gtsmodel.PollVote, error) {  	// Fetch vote from database cache with loader callback -	vote, err := p.state.Caches.GTS.PollVote().Load(lookup, func() (*gtsmodel.PollVote, error) { +	vote, err := p.state.Caches.GTS.PollVote.LoadOne(lookup, func() (*gtsmodel.PollVote, error) {  		var vote gtsmodel.PollVote  		// Not cached! Perform database query. @@ -250,7 +252,9 @@ func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*g  }  func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.PollVote, error) { -	voteIDs, err := p.state.Caches.GTS.PollVoteIDs().Load(pollID, func() ([]string, error) { + +	// Load vote IDs known for given poll ID using loader callback. +	voteIDs, err := p.state.Caches.GTS.PollVoteIDs.Load(pollID, func() ([]string, error) {  		var voteIDs []string  		// Vote IDs not in cache, perform DB query! @@ -266,21 +270,62 @@ func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.P  		return nil, err  	} -	// Preallocate slice of expected length. -	votes := make([]*gtsmodel.PollVote, 0, len(voteIDs)) +	// Preallocate at-worst possible length. +	uncached := make([]string, 0, len(voteIDs)) -	for _, id := range voteIDs { -		// Fetch poll vote model for this ID. -		vote, err := p.GetPollVoteByID(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error getting poll vote %s: %v", id, err) -			continue -		} +	// Load all votes from IDs via cache loader callbacks. +	votes, err := p.state.Caches.GTS.PollVote.Load("ID", + +		// Load cached + check for uncached. +		func(load func(keyParts ...any) bool) { +			for _, id := range voteIDs { +				if !load(id) { +					uncached = append(uncached, id) +				} +			} +		}, + +		// Uncached poll vote loader function. +		func() ([]*gtsmodel.PollVote, error) { +			// Preallocate expected length of uncached votes. +			votes := make([]*gtsmodel.PollVote, 0, len(uncached)) + +			// Perform database query scanning +			// the remaining (uncached) IDs. +			if err := p.db.NewSelect(). +				Model(&votes). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). +				Scan(ctx); err != nil { +				return nil, err +			} + +			return votes, nil +		}, +	) +	if err != nil { +		return nil, err +	} + +	// Reorder the poll votes by their +	// IDs to ensure in correct order. +	getID := func(v *gtsmodel.PollVote) string { return v.ID } +	util.OrderBy(votes, voteIDs, getID) -		// Append to return slice. -		votes = append(votes, vote) +	if gtscontext.Barebones(ctx) { +		// no need to fully populate. +		return votes, nil  	} +	// Populate all loaded votes, removing those we fail to +	// populate (removes needing so many nil checks everywhere). +	votes = slices.DeleteFunc(votes, func(vote *gtsmodel.PollVote) bool { +		if err := p.PopulatePollVote(ctx, vote); err != nil { +			log.Errorf(ctx, "error populating vote %s: %v", vote.ID, err) +			return true +		} +		return false +	}) +  	return votes, nil  } @@ -316,7 +361,7 @@ func (p *pollDB) PopulatePollVote(ctx context.Context, vote *gtsmodel.PollVote)  }  func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error { -	return p.state.Caches.GTS.PollVote().Store(vote, func() error { +	return p.state.Caches.GTS.PollVote.Store(vote, func() error {  		return p.db.RunInTx(ctx, func(tx Tx) error {  			// Try insert vote into database.  			if _, err := tx.NewInsert(). @@ -416,9 +461,9 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {  	}  	// Invalidate poll vote and poll entry from caches. -	p.state.Caches.GTS.Poll().Invalidate("ID", pollID) -	p.state.Caches.GTS.PollVote().Invalidate("PollID", pollID) -	p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID) +	p.state.Caches.GTS.Poll.Invalidate("ID", pollID) +	p.state.Caches.GTS.PollVote.Invalidate("PollID", pollID) +	p.state.Caches.GTS.PollVoteIDs.Invalidate(pollID)  	return nil  } @@ -428,7 +473,7 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID  		// Slice should only ever be of length  		// 0 or 1; it's a slice of slices only  		// because we can't LIMIT deletes to 1. -		var choicesSl [][]int +		var choicesSlice [][]int  		// Delete vote in poll by account,  		// returning the ID + choices of the vote. @@ -437,17 +482,19 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID  			Where("? = ?", bun.Ident("poll_id"), pollID).  			Where("? = ?", bun.Ident("account_id"), accountID).  			Returning("?", bun.Ident("choices")). -			Scan(ctx, &choicesSl); err != nil { +			Scan(ctx, &choicesSlice); err != nil {  			// irrecoverable.  			return err  		} -		if len(choicesSl) != 1 { +		if len(choicesSlice) != 1 {  			// No poll votes by this  			// acct on this poll.  			return nil  		} -		choices := choicesSl[0] + +		// Extract the *actual* choices. +		choices := choicesSlice[0]  		// Select current poll counts from DB,  		// taking minimal columns needed to @@ -489,9 +536,9 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID  	}  	// Invalidate poll vote and poll entry from caches. -	p.state.Caches.GTS.Poll().Invalidate("ID", pollID) -	p.state.Caches.GTS.PollVote().Invalidate("PollID.AccountID", pollID, accountID) -	p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID) +	p.state.Caches.GTS.Poll.Invalidate("ID", pollID) +	p.state.Caches.GTS.PollVote.Invalidate("PollID,AccountID", pollID, accountID) +	p.state.Caches.GTS.PollVoteIDs.Invalidate(pollID)  	return nil  } diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index 138a5aa17..4c50862a1 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -194,7 +194,7 @@ func (r *relationshipDB) CountAccountBlocks(ctx context.Context, accountID strin  }  func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { -	return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), ">"+accountID, page, func() ([]string, error) { +	return loadPagedIDs(r.state.Caches.GTS.FollowIDs, ">"+accountID, page, func() ([]string, error) {  		var followIDs []string  		// Follow IDs not in cache, perform DB query! @@ -209,7 +209,7 @@ func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID stri  }  func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) { -	return r.state.Caches.GTS.FollowIDs().Load("l>"+accountID, func() ([]string, error) { +	return r.state.Caches.GTS.FollowIDs.Load("l>"+accountID, func() ([]string, error) {  		var followIDs []string  		// Follow IDs not in cache, perform DB query! @@ -224,7 +224,7 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID  }  func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { -	return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), "<"+accountID, page, func() ([]string, error) { +	return loadPagedIDs(r.state.Caches.GTS.FollowIDs, "<"+accountID, page, func() ([]string, error) {  		var followIDs []string  		// Follow IDs not in cache, perform DB query! @@ -239,7 +239,7 @@ func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID st  }  func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) { -	return r.state.Caches.GTS.FollowIDs().Load("l<"+accountID, func() ([]string, error) { +	return r.state.Caches.GTS.FollowIDs.Load("l<"+accountID, func() ([]string, error) {  		var followIDs []string  		// Follow IDs not in cache, perform DB query! @@ -254,7 +254,7 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account  }  func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { -	return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), ">"+accountID, page, func() ([]string, error) { +	return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs, ">"+accountID, page, func() ([]string, error) {  		var followReqIDs []string  		// Follow request IDs not in cache, perform DB query! @@ -269,7 +269,7 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account  }  func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { -	return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), "<"+accountID, page, func() ([]string, error) { +	return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs, "<"+accountID, page, func() ([]string, error) {  		var followReqIDs []string  		// Follow request IDs not in cache, perform DB query! @@ -284,7 +284,7 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco  }  func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { -	return loadPagedIDs(r.state.Caches.GTS.BlockIDs(), accountID, page, func() ([]string, error) { +	return loadPagedIDs(r.state.Caches.GTS.BlockIDs, accountID, page, func() ([]string, error) {  		var blockIDs []string  		// Block IDs not in cache, perform DB query! diff --git a/internal/db/bundb/relationship_block.go b/internal/db/bundb/relationship_block.go index efaa6d1a9..178de6aa7 100644 --- a/internal/db/bundb/relationship_block.go +++ b/internal/db/bundb/relationship_block.go @@ -20,12 +20,14 @@ package bundb  import (  	"context"  	"errors" +	"slices"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/superseriousbusiness/gotosocial/internal/util"  	"github.com/uptrace/bun"  ) @@ -86,7 +88,7 @@ func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmod  func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Block, error) {  	return r.getBlock(  		ctx, -		"AccountID.TargetAccountID", +		"AccountID,TargetAccountID",  		func(block *gtsmodel.Block) error {  			return r.db.NewSelect().Model(block).  				Where("? = ?", bun.Ident("block.account_id"), sourceAccountID). @@ -99,27 +101,68 @@ func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, t  }  func (r *relationshipDB) GetBlocksByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Block, error) { -	// Preallocate slice of expected length. -	blocks := make([]*gtsmodel.Block, 0, len(ids)) +	// Preallocate at-worst possible length. +	uncached := make([]string, 0, len(ids)) + +	// Load all blocks IDs via cache loader callbacks. +	blocks, err := r.state.Caches.GTS.Block.Load("ID", + +		// Load cached + check for uncached. +		func(load func(keyParts ...any) bool) { +			for _, id := range ids { +				if !load(id) { +					uncached = append(uncached, id) +				} +			} +		}, -	for _, id := range ids { -		// Fetch block model for this ID. -		block, err := r.GetBlockByID(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error getting block %q: %v", id, err) -			continue -		} +		// Uncached block loader function. +		func() ([]*gtsmodel.Block, error) { +			// Preallocate expected length of uncached blocks. +			blocks := make([]*gtsmodel.Block, 0, len(uncached)) + +			// Perform database query scanning +			// the remaining (uncached) IDs. +			if err := r.db.NewSelect(). +				Model(&blocks). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). +				Scan(ctx); err != nil { +				return nil, err +			} + +			return blocks, nil +		}, +	) +	if err != nil { +		return nil, err +	} + +	// Reorder the blocks by their +	// IDs to ensure in correct order. +	getID := func(b *gtsmodel.Block) string { return b.ID } +	util.OrderBy(blocks, ids, getID) -		// Append to return slice. -		blocks = append(blocks, block) +	if gtscontext.Barebones(ctx) { +		// no need to fully populate. +		return blocks, nil  	} +	// Populate all loaded blocks, removing those we fail to +	// populate (removes needing so many nil checks everywhere). +	blocks = slices.DeleteFunc(blocks, func(block *gtsmodel.Block) bool { +		if err := r.PopulateBlock(ctx, block); err != nil { +			log.Errorf(ctx, "error populating block %s: %v", block.ID, err) +			return true +		} +		return false +	}) +  	return blocks, nil  }  func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) {  	// Fetch block from cache with loader callback -	block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) { +	block, err := r.state.Caches.GTS.Block.LoadOne(lookup, func() (*gtsmodel.Block, error) {  		var block gtsmodel.Block  		// Not cached! Perform database query @@ -148,8 +191,8 @@ func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery fu  func (r *relationshipDB) PopulateBlock(ctx context.Context, block *gtsmodel.Block) error {  	var ( +		errs gtserror.MultiError  		err  error -		errs = gtserror.NewMultiError(2)  	)  	if block.Account == nil { @@ -178,7 +221,7 @@ func (r *relationshipDB) PopulateBlock(ctx context.Context, block *gtsmodel.Bloc  }  func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error { -	return r.state.Caches.GTS.Block().Store(block, func() error { +	return r.state.Caches.GTS.Block.Store(block, func() error {  		_, err := r.db.NewInsert().Model(block).Exec(ctx)  		return err  	}) @@ -198,7 +241,7 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {  	}  	// Drop this now-cached block on return after delete. -	defer r.state.Caches.GTS.Block().Invalidate("ID", id) +	defer r.state.Caches.GTS.Block.Invalidate("ID", id)  	// Finally delete block from DB.  	_, err = r.db.NewDelete(). @@ -222,7 +265,7 @@ func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error  	}  	// Drop this now-cached block on return after delete. -	defer r.state.Caches.GTS.Block().Invalidate("URI", uri) +	defer r.state.Caches.GTS.Block.Invalidate("URI", uri)  	// Finally delete block from DB.  	_, err = r.db.NewDelete(). @@ -251,22 +294,20 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri  	defer func() {  		// Invalidate all account's incoming / outoing blocks on return. -		r.state.Caches.GTS.Block().Invalidate("AccountID", accountID) -		r.state.Caches.GTS.Block().Invalidate("TargetAccountID", accountID) +		r.state.Caches.GTS.Block.Invalidate("AccountID", accountID) +		r.state.Caches.GTS.Block.Invalidate("TargetAccountID", accountID)  	}()  	// Load all blocks into cache, this *really* isn't great  	// but it is the only way we can ensure we invalidate all  	// related caches correctly (e.g. visibility). -	for _, id := range blockIDs { -		_, err := r.GetBlockByID(ctx, id) -		if err != nil && !errors.Is(err, db.ErrNoEntries) { -			return err -		} +	_, err := r.GetAccountBlocks(ctx, accountID, nil) +	if err != nil && !errors.Is(err, db.ErrNoEntries) { +		return err  	}  	// Finally delete all from DB. -	_, err := r.db.NewDelete(). +	_, err = r.db.NewDelete().  		Table("blocks").  		Where("? IN (?)", bun.Ident("id"), bun.In(blockIDs)).  		Exec(ctx) diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go index 6c5a75e4c..93ee69bd7 100644 --- a/internal/db/bundb/relationship_follow.go +++ b/internal/db/bundb/relationship_follow.go @@ -21,6 +21,7 @@ import (  	"context"  	"errors"  	"fmt" +	"slices"  	"time"  	"github.com/superseriousbusiness/gotosocial/internal/db" @@ -28,6 +29,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/superseriousbusiness/gotosocial/internal/util"  	"github.com/uptrace/bun"  ) @@ -62,7 +64,7 @@ func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmo  func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {  	return r.getFollow(  		ctx, -		"AccountID.TargetAccountID", +		"AccountID,TargetAccountID",  		func(follow *gtsmodel.Follow) error {  			return r.db.NewSelect().  				Model(follow). @@ -76,21 +78,62 @@ func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string,  }  func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Follow, error) { -	// Preallocate slice of expected length. -	follows := make([]*gtsmodel.Follow, 0, len(ids)) +	// Preallocate at-worst possible length. +	uncached := make([]string, 0, len(ids)) + +	// Load all follow IDs via cache loader callbacks. +	follows, err := r.state.Caches.GTS.Follow.Load("ID", + +		// Load cached + check for uncached. +		func(load func(keyParts ...any) bool) { +			for _, id := range ids { +				if !load(id) { +					uncached = append(uncached, id) +				} +			} +		}, -	for _, id := range ids { -		// Fetch follow model for this ID. -		follow, err := r.GetFollowByID(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error getting follow %q: %v", id, err) -			continue -		} +		// Uncached follow loader function. +		func() ([]*gtsmodel.Follow, error) { +			// Preallocate expected length of uncached follows. +			follows := make([]*gtsmodel.Follow, 0, len(uncached)) + +			// Perform database query scanning +			// the remaining (uncached) IDs. +			if err := r.db.NewSelect(). +				Model(&follows). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). +				Scan(ctx); err != nil { +				return nil, err +			} + +			return follows, nil +		}, +	) +	if err != nil { +		return nil, err +	} + +	// Reorder the follows by their +	// IDs to ensure in correct order. +	getID := func(f *gtsmodel.Follow) string { return f.ID } +	util.OrderBy(follows, ids, getID) -		// Append to return slice. -		follows = append(follows, follow) +	if gtscontext.Barebones(ctx) { +		// no need to fully populate. +		return follows, nil  	} +	// Populate all loaded follows, removing those we fail to +	// populate (removes needing so many nil checks everywhere). +	follows = slices.DeleteFunc(follows, func(follow *gtsmodel.Follow) bool { +		if err := r.PopulateFollow(ctx, follow); err != nil { +			log.Errorf(ctx, "error populating follow %s: %v", follow.ID, err) +			return true +		} +		return false +	}) +  	return follows, nil  } @@ -130,7 +173,7 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 strin  func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Follow) error, keyParts ...any) (*gtsmodel.Follow, error) {  	// Fetch follow from database cache with loader callback -	follow, err := r.state.Caches.GTS.Follow().Load(lookup, func() (*gtsmodel.Follow, error) { +	follow, err := r.state.Caches.GTS.Follow.LoadOne(lookup, func() (*gtsmodel.Follow, error) {  		var follow gtsmodel.Follow  		// Not cached! Perform database query @@ -189,7 +232,7 @@ func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Fo  }  func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error { -	return r.state.Caches.GTS.Follow().Store(follow, func() error { +	return r.state.Caches.GTS.Follow.Store(follow, func() error {  		_, err := r.db.NewInsert().Model(follow).Exec(ctx)  		return err  	}) @@ -202,7 +245,7 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll  		columns = append(columns, "updated_at")  	} -	return r.state.Caches.GTS.Follow().Store(follow, func() error { +	return r.state.Caches.GTS.Follow.Store(follow, func() error {  		if _, err := r.db.NewUpdate().  			Model(follow).  			Where("? = ?", bun.Ident("follow.id"), follow.ID). @@ -250,7 +293,7 @@ func (r *relationshipDB) DeleteFollow(ctx context.Context, sourceAccountID strin  	}  	// Drop this now-cached follow on return after delete. -	defer r.state.Caches.GTS.Follow().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) +	defer r.state.Caches.GTS.Follow.Invalidate("AccountID,TargetAccountID", sourceAccountID, targetAccountID)  	// Finally delete follow from DB.  	return r.deleteFollow(ctx, follow.ID) @@ -270,7 +313,7 @@ func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error  	}  	// Drop this now-cached follow on return after delete. -	defer r.state.Caches.GTS.Follow().Invalidate("ID", id) +	defer r.state.Caches.GTS.Follow.Invalidate("ID", id)  	// Finally delete follow from DB.  	return r.deleteFollow(ctx, follow.ID) @@ -290,7 +333,7 @@ func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) erro  	}  	// Drop this now-cached follow on return after delete. -	defer r.state.Caches.GTS.Follow().Invalidate("URI", uri) +	defer r.state.Caches.GTS.Follow.Invalidate("URI", uri)  	// Finally delete follow from DB.  	return r.deleteFollow(ctx, follow.ID) @@ -316,22 +359,30 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str  	defer func() {  		// Invalidate all account's incoming / outoing follows on return. -		r.state.Caches.GTS.Follow().Invalidate("AccountID", accountID) -		r.state.Caches.GTS.Follow().Invalidate("TargetAccountID", accountID) +		r.state.Caches.GTS.Follow.Invalidate("AccountID", accountID) +		r.state.Caches.GTS.Follow.Invalidate("TargetAccountID", accountID)  	}()  	// Load all follows into cache, this *really* isn't great  	// but it is the only way we can ensure we invalidate all  	// related caches correctly (e.g. visibility). -	for _, id := range followIDs { -		follow, err := r.GetFollowByID(ctx, id) -		if err != nil && !errors.Is(err, db.ErrNoEntries) { -			return err -		} +	_, err := r.GetAccountFollows(ctx, accountID, nil) +	if err != nil && !errors.Is(err, db.ErrNoEntries) { +		return err +	} -		// Delete each follow from DB. -		if err := r.deleteFollow(ctx, follow.ID); err != nil && -			!errors.Is(err, db.ErrNoEntries) { +	// Delete all follows from DB. +	_, err = r.db.NewDelete(). +		Table("follows"). +		Where("? IN (?)", bun.Ident("id"), bun.In(followIDs)). +		Exec(ctx) +	if err != nil { +		return err +	} + +	for _, id := range followIDs { +		// Finally, delete all list entries associated with each follow ID. +		if err := r.state.DB.DeleteListEntriesForFollowID(ctx, id); err != nil {  			return err  		}  	} diff --git a/internal/db/bundb/relationship_follow_req.go b/internal/db/bundb/relationship_follow_req.go index 51aceafe1..690b97cf0 100644 --- a/internal/db/bundb/relationship_follow_req.go +++ b/internal/db/bundb/relationship_follow_req.go @@ -20,6 +20,7 @@ package bundb  import (  	"context"  	"errors" +	"slices"  	"time"  	"github.com/superseriousbusiness/gotosocial/internal/db" @@ -27,6 +28,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/superseriousbusiness/gotosocial/internal/util"  	"github.com/uptrace/bun"  ) @@ -61,7 +63,7 @@ func (r *relationshipDB) GetFollowRequestByURI(ctx context.Context, uri string)  func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) {  	return r.getFollowRequest(  		ctx, -		"AccountID.TargetAccountID", +		"AccountID,TargetAccountID",  		func(followReq *gtsmodel.FollowRequest) error {  			return r.db.NewSelect().  				Model(followReq). @@ -75,22 +77,63 @@ func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID s  }  func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.FollowRequest, error) { -	// Preallocate slice of expected length. -	followReqs := make([]*gtsmodel.FollowRequest, 0, len(ids)) +	// Preallocate at-worst possible length. +	uncached := make([]string, 0, len(ids)) + +	// Load all follow IDs via cache loader callbacks. +	follows, err := r.state.Caches.GTS.FollowRequest.Load("ID", + +		// Load cached + check for uncached. +		func(load func(keyParts ...any) bool) { +			for _, id := range ids { +				if !load(id) { +					uncached = append(uncached, id) +				} +			} +		}, -	for _, id := range ids { -		// Fetch follow request model for this ID. -		followReq, err := r.GetFollowRequestByID(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error getting follow request %q: %v", id, err) -			continue -		} +		// Uncached follow req loader function. +		func() ([]*gtsmodel.FollowRequest, error) { +			// Preallocate expected length of uncached followReqs. +			follows := make([]*gtsmodel.FollowRequest, 0, len(uncached)) + +			// Perform database query scanning +			// the remaining (uncached) IDs. +			if err := r.db.NewSelect(). +				Model(&follows). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). +				Scan(ctx); err != nil { +				return nil, err +			} + +			return follows, nil +		}, +	) +	if err != nil { +		return nil, err +	} + +	// Reorder the requests by their +	// IDs to ensure in correct order. +	getID := func(f *gtsmodel.FollowRequest) string { return f.ID } +	util.OrderBy(follows, ids, getID) -		// Append to return slice. -		followReqs = append(followReqs, followReq) +	if gtscontext.Barebones(ctx) { +		// no need to fully populate. +		return follows, nil  	} -	return followReqs, nil +	// Populate all loaded followreqs, removing those we fail to +	// populate (removes needing so many nil checks everywhere). +	follows = slices.DeleteFunc(follows, func(follow *gtsmodel.FollowRequest) bool { +		if err := r.PopulateFollowRequest(ctx, follow); err != nil { +			log.Errorf(ctx, "error populating follow request %s: %v", follow.ID, err) +			return true +		} +		return false +	}) + +	return follows, nil  }  func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) { @@ -107,7 +150,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID  func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, dbQuery func(*gtsmodel.FollowRequest) error, keyParts ...any) (*gtsmodel.FollowRequest, error) {  	// Fetch follow request from database cache with loader callback -	followReq, err := r.state.Caches.GTS.FollowRequest().Load(lookup, func() (*gtsmodel.FollowRequest, error) { +	followReq, err := r.state.Caches.GTS.FollowRequest.LoadOne(lookup, func() (*gtsmodel.FollowRequest, error) {  		var followReq gtsmodel.FollowRequest  		// Not cached! Perform database query @@ -166,7 +209,7 @@ func (r *relationshipDB) PopulateFollowRequest(ctx context.Context, follow *gtsm  }  func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error { -	return r.state.Caches.GTS.FollowRequest().Store(follow, func() error { +	return r.state.Caches.GTS.FollowRequest.Store(follow, func() error {  		_, err := r.db.NewInsert().Model(follow).Exec(ctx)  		return err  	}) @@ -179,7 +222,7 @@ func (r *relationshipDB) UpdateFollowRequest(ctx context.Context, followRequest  		columns = append(columns, "updated_at")  	} -	return r.state.Caches.GTS.FollowRequest().Store(followRequest, func() error { +	return r.state.Caches.GTS.FollowRequest.Store(followRequest, func() error {  		if _, err := r.db.NewUpdate().  			Model(followRequest).  			Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). @@ -212,7 +255,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI  		Notify:          followReq.Notify,  	} -	if err := r.state.Caches.GTS.Follow().Store(follow, func() error { +	if err := r.state.Caches.GTS.Follow.Store(follow, func() error {  		// If the follow already exists, just  		// replace the URI with the new one.  		_, err := r.db. @@ -274,7 +317,7 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI  	}  	// Drop this now-cached follow request on return after delete. -	defer r.state.Caches.GTS.FollowRequest().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) +	defer r.state.Caches.GTS.FollowRequest.Invalidate("AccountID,TargetAccountID", sourceAccountID, targetAccountID)  	// Finally delete followreq from DB.  	_, err = r.db.NewDelete(). @@ -298,7 +341,7 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string)  	}  	// Drop this now-cached follow request on return after delete. -	defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", id) +	defer r.state.Caches.GTS.FollowRequest.Invalidate("ID", id)  	// Finally delete followreq from DB.  	_, err = r.db.NewDelete(). @@ -322,7 +365,7 @@ func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri strin  	}  	// Drop this now-cached follow request on return after delete. -	defer r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri) +	defer r.state.Caches.GTS.FollowRequest.Invalidate("URI", uri)  	// Finally delete followreq from DB.  	_, err = r.db.NewDelete(). @@ -352,22 +395,20 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun  	defer func() {  		// Invalidate all account's incoming / outoing follow requests on return. -		r.state.Caches.GTS.FollowRequest().Invalidate("AccountID", accountID) -		r.state.Caches.GTS.FollowRequest().Invalidate("TargetAccountID", accountID) +		r.state.Caches.GTS.FollowRequest.Invalidate("AccountID", accountID) +		r.state.Caches.GTS.FollowRequest.Invalidate("TargetAccountID", accountID)  	}()  	// Load all followreqs into cache, this *really* isn't  	// great but it is the only way we can ensure we invalidate  	// all related caches correctly (e.g. visibility). -	for _, id := range followReqIDs { -		_, err := r.GetFollowRequestByID(ctx, id) -		if err != nil && !errors.Is(err, db.ErrNoEntries) { -			return err -		} +	_, err := r.GetAccountFollowRequests(ctx, accountID, nil) +	if err != nil && !errors.Is(err, db.ErrNoEntries) { +		return err  	}  	// Finally delete all from DB. -	_, err := r.db.NewDelete(). +	_, err = r.db.NewDelete().  		Table("follow_requests").  		Where("? IN (?)", bun.Ident("id"), bun.In(followReqIDs)).  		Exec(ctx) diff --git a/internal/db/bundb/relationship_note.go b/internal/db/bundb/relationship_note.go index f7d15f8b7..126ea0cd1 100644 --- a/internal/db/bundb/relationship_note.go +++ b/internal/db/bundb/relationship_note.go @@ -30,7 +30,7 @@ import (  func (r *relationshipDB) GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error) {  	return r.getNote(  		ctx, -		"AccountID.TargetAccountID", +		"AccountID,TargetAccountID",  		func(note *gtsmodel.AccountNote) error {  			return r.db.NewSelect().Model(note).  				Where("? = ?", bun.Ident("account_id"), sourceAccountID). @@ -44,7 +44,7 @@ func (r *relationshipDB) GetNote(ctx context.Context, sourceAccountID string, ta  func (r *relationshipDB) getNote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.AccountNote) error, keyParts ...any) (*gtsmodel.AccountNote, error) {  	// Fetch note from cache with loader callback -	note, err := r.state.Caches.GTS.AccountNote().Load(lookup, func() (*gtsmodel.AccountNote, error) { +	note, err := r.state.Caches.GTS.AccountNote.LoadOne(lookup, func() (*gtsmodel.AccountNote, error) {  		var note gtsmodel.AccountNote  		// Not cached! Perform database query @@ -105,7 +105,7 @@ func (r *relationshipDB) PopulateNote(ctx context.Context, note *gtsmodel.Accoun  func (r *relationshipDB) PutNote(ctx context.Context, note *gtsmodel.AccountNote) error {  	note.UpdatedAt = time.Now() -	return r.state.Caches.GTS.AccountNote().Store(note, func() error { +	return r.state.Caches.GTS.AccountNote.Store(note, func() error {  		_, err := r.db.  			NewInsert().  			Model(note). diff --git a/internal/db/bundb/report.go b/internal/db/bundb/report.go index 9e4ba5b29..5b0ae17f3 100644 --- a/internal/db/bundb/report.go +++ b/internal/db/bundb/report.go @@ -120,7 +120,7 @@ func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID str  func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Report) error, keyParts ...any) (*gtsmodel.Report, error) {  	// Fetch report from database cache with loader callback -	report, err := r.state.Caches.GTS.Report().Load(lookup, func() (*gtsmodel.Report, error) { +	report, err := r.state.Caches.GTS.Report.LoadOne(lookup, func() (*gtsmodel.Report, error) {  		var report gtsmodel.Report  		// Not cached! Perform database query @@ -215,7 +215,7 @@ func (r *reportDB) PopulateReport(ctx context.Context, report *gtsmodel.Report)  }  func (r *reportDB) PutReport(ctx context.Context, report *gtsmodel.Report) error { -	return r.state.Caches.GTS.Report().Store(report, func() error { +	return r.state.Caches.GTS.Report.Store(report, func() error {  		_, err := r.db.NewInsert().Model(report).Exec(ctx)  		return err  	}) @@ -237,12 +237,12 @@ func (r *reportDB) UpdateReport(ctx context.Context, report *gtsmodel.Report, co  		return nil, err  	} -	r.state.Caches.GTS.Report().Invalidate("ID", report.ID) +	r.state.Caches.GTS.Report.Invalidate("ID", report.ID)  	return report, nil  }  func (r *reportDB) DeleteReportByID(ctx context.Context, id string) error { -	defer r.state.Caches.GTS.Report().Invalidate("ID", id) +	defer r.state.Caches.GTS.Report.Invalidate("ID", id)  	// Load status into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate diff --git a/internal/db/bundb/rule.go b/internal/db/bundb/rule.go index 79825923b..ebfa89d15 100644 --- a/internal/db/bundb/rule.go +++ b/internal/db/bundb/rule.go @@ -125,7 +125,7 @@ func (r *ruleDB) PutRule(ctx context.Context, rule *gtsmodel.Rule) error {  	}  	// invalidate cached local instance response, so it gets updated with the new rules -	r.state.Caches.GTS.Instance().Invalidate("Domain", config.GetHost()) +	r.state.Caches.GTS.Instance.Invalidate("Domain", config.GetHost())  	return nil  } @@ -143,7 +143,7 @@ func (r *ruleDB) UpdateRule(ctx context.Context, rule *gtsmodel.Rule) (*gtsmodel  	}  	// invalidate cached local instance response, so it gets updated with the new rules -	r.state.Caches.GTS.Instance().Invalidate("Domain", config.GetHost()) +	r.state.Caches.GTS.Instance.Invalidate("Domain", config.GetHost())  	return rule, nil  } diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index da252c7f7..07a09050a 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -20,6 +20,7 @@ package bundb  import (  	"context"  	"errors" +	"slices"  	"time"  	"github.com/superseriousbusiness/gotosocial/internal/db" @@ -28,6 +29,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/superseriousbusiness/gotosocial/internal/util"  	"github.com/uptrace/bun"  ) @@ -48,20 +50,62 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat  }  func (s *statusDB) GetStatusesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Status, error) { -	statuses := make([]*gtsmodel.Status, 0, len(ids)) +	// Preallocate at-worst possible length. +	uncached := make([]string, 0, len(ids)) -	for _, id := range ids { -		// Attempt to fetch status from DB. -		status, err := s.GetStatusByID(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error getting status %q: %v", id, err) -			continue -		} +	// Load all status IDs via cache loader callbacks. +	statuses, err := s.state.Caches.GTS.Status.Load("ID", + +		// Load cached + check for uncached. +		func(load func(keyParts ...any) bool) { +			for _, id := range ids { +				if !load(id) { +					uncached = append(uncached, id) +				} +			} +		}, + +		// Uncached statuses loader function. +		func() ([]*gtsmodel.Status, error) { +			// Preallocate expected length of uncached statuses. +			statuses := make([]*gtsmodel.Status, 0, len(uncached)) + +			// Perform database query scanning +			// the remaining (uncached) status IDs. +			if err := s.db.NewSelect(). +				Model(&statuses). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). +				Scan(ctx); err != nil { +				return nil, err +			} + +			return statuses, nil +		}, +	) +	if err != nil { +		return nil, err +	} + +	// Reorder the statuses by their +	// IDs to ensure in correct order. +	getID := func(s *gtsmodel.Status) string { return s.ID } +	util.OrderBy(statuses, ids, getID) -		// Append status to return slice. -		statuses = append(statuses, status) +	if gtscontext.Barebones(ctx) { +		// no need to fully populate. +		return statuses, nil  	} +	// Populate all loaded statuses, removing those we fail to +	// populate (removes needing so many nil checks everywhere). +	statuses = slices.DeleteFunc(statuses, func(status *gtsmodel.Status) bool { +		if err := s.PopulateStatus(ctx, status); err != nil { +			log.Errorf(ctx, "error populating status %s: %v", status.ID, err) +			return true +		} +		return false +	}) +  	return statuses, nil  } @@ -101,7 +145,7 @@ func (s *statusDB) GetStatusByPollID(ctx context.Context, pollID string) (*gtsmo  func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccountID string) (*gtsmodel.Status, error) {  	return s.getStatus(  		ctx, -		"BoostOfID.AccountID", +		"BoostOfID,AccountID",  		func(status *gtsmodel.Status) error {  			return s.db.NewSelect().Model(status).  				Where("status.boost_of_id = ?", boostOfID). @@ -120,7 +164,7 @@ func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccou  func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, error) {  	// Fetch status from database cache with loader callback -	status, err := s.state.Caches.GTS.Status().Load(lookup, func() (*gtsmodel.Status, error) { +	status, err := s.state.Caches.GTS.Status.LoadOne(lookup, func() (*gtsmodel.Status, error) {  		var status gtsmodel.Status  		// Not cached! Perform database query. @@ -282,7 +326,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)  }  func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error { -	return s.state.Caches.GTS.Status().Store(status, func() error { +	return s.state.Caches.GTS.Status.Store(status, func() error {  		// It is safe to run this database transaction within cache.Store  		// as the cache does not attempt a mutex lock until AFTER hook.  		// @@ -366,7 +410,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co  		columns = append(columns, "updated_at")  	} -	return s.state.Caches.GTS.Status().Store(status, func() error { +	return s.state.Caches.GTS.Status.Store(status, func() error {  		// It is safe to run this database transaction within cache.Store  		// as the cache does not attempt a mutex lock until AFTER hook.  		// @@ -463,7 +507,7 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error {  	}  	// On return ensure status invalidated from cache. -	defer s.state.Caches.GTS.Status().Invalidate("ID", id) +	defer s.state.Caches.GTS.Status.Invalidate("ID", id)  	return s.db.RunInTx(ctx, func(tx Tx) error {  		// delete links between this status and any emojis it uses @@ -585,7 +629,7 @@ func (s *statusDB) CountStatusReplies(ctx context.Context, statusID string) (int  }  func (s *statusDB) getStatusReplyIDs(ctx context.Context, statusID string) ([]string, error) { -	return s.state.Caches.GTS.InReplyToIDs().Load(statusID, func() ([]string, error) { +	return s.state.Caches.GTS.InReplyToIDs.Load(statusID, func() ([]string, error) {  		var statusIDs []string  		// Status reply IDs not in cache, perform DB query! @@ -629,7 +673,7 @@ func (s *statusDB) CountStatusBoosts(ctx context.Context, statusID string) (int,  }  func (s *statusDB) getStatusBoostIDs(ctx context.Context, statusID string) ([]string, error) { -	return s.state.Caches.GTS.BoostOfIDs().Load(statusID, func() ([]string, error) { +	return s.state.Caches.GTS.BoostOfIDs.Load(statusID, func() ([]string, error) {  		var statusIDs []string  		// Status boost IDs not in cache, perform DB query! diff --git a/internal/db/bundb/statusfave.go b/internal/db/bundb/statusfave.go index 73ac62fe7..e0f018b68 100644 --- a/internal/db/bundb/statusfave.go +++ b/internal/db/bundb/statusfave.go @@ -22,6 +22,7 @@ import (  	"database/sql"  	"errors"  	"fmt" +	"slices"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" @@ -29,6 +30,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/superseriousbusiness/gotosocial/internal/util"  	"github.com/uptrace/bun"  ) @@ -40,7 +42,7 @@ type statusFaveDB struct {  func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, error) {  	return s.getStatusFave(  		ctx, -		"AccountID.StatusID", +		"AccountID,StatusID",  		func(fave *gtsmodel.StatusFave) error {  			return s.db.  				NewSelect(). @@ -77,7 +79,7 @@ func (s *statusFaveDB) GetStatusFaveByID(ctx context.Context, id string) (*gtsmo  func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery func(*gtsmodel.StatusFave) error, keyParts ...any) (*gtsmodel.StatusFave, error) {  	// Fetch status fave from database cache with loader callback -	fave, err := s.state.Caches.GTS.StatusFave().Load(lookup, func() (*gtsmodel.StatusFave, error) { +	fave, err := s.state.Caches.GTS.StatusFave.LoadOne(lookup, func() (*gtsmodel.StatusFave, error) {  		var fave gtsmodel.StatusFave  		// Not cached! Perform database query. @@ -111,19 +113,62 @@ func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*  		return nil, err  	} -	// Preallocate a slice of expected status fave capacity. -	faves := make([]*gtsmodel.StatusFave, 0, len(faveIDs)) +	// Preallocate at-worst possible length. +	uncached := make([]string, 0, len(faveIDs)) -	for _, id := range faveIDs { -		// Fetch status fave model for each ID. -		fave, err := s.GetStatusFaveByID(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error getting status fave %q: %v", id, err) -			continue -		} -		faves = append(faves, fave) +	// Load all fave IDs via cache loader callbacks. +	faves, err := s.state.Caches.GTS.StatusFave.Load("ID", + +		// Load cached + check for uncached. +		func(load func(keyParts ...any) bool) { +			for _, id := range faveIDs { +				if !load(id) { +					uncached = append(uncached, id) +				} +			} +		}, + +		// Uncached status faves loader function. +		func() ([]*gtsmodel.StatusFave, error) { +			// Preallocate expected length of uncached faves. +			faves := make([]*gtsmodel.StatusFave, 0, len(uncached)) + +			// Perform database query scanning +			// the remaining (uncached) fave IDs. +			if err := s.db.NewSelect(). +				Model(&faves). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). +				Scan(ctx); err != nil { +				return nil, err +			} + +			return faves, nil +		}, +	) +	if err != nil { +		return nil, err  	} +	// Reorder the statuses by their +	// IDs to ensure in correct order. +	getID := func(f *gtsmodel.StatusFave) string { return f.ID } +	util.OrderBy(faves, faveIDs, getID) + +	if gtscontext.Barebones(ctx) { +		// no need to fully populate. +		return faves, nil +	} + +	// Populate all loaded faves, removing those we fail to +	// populate (removes needing so many nil checks everywhere). +	faves = slices.DeleteFunc(faves, func(fave *gtsmodel.StatusFave) bool { +		if err := s.PopulateStatusFave(ctx, fave); err != nil { +			log.Errorf(ctx, "error populating fave %s: %v", fave.ID, err) +			return true +		} +		return false +	}) +  	return faves, nil  } @@ -141,7 +186,7 @@ func (s *statusFaveDB) CountStatusFaves(ctx context.Context, statusID string) (i  }  func (s *statusFaveDB) getStatusFaveIDs(ctx context.Context, statusID string) ([]string, error) { -	return s.state.Caches.GTS.StatusFaveIDs().Load(statusID, func() ([]string, error) { +	return s.state.Caches.GTS.StatusFaveIDs.Load(statusID, func() ([]string, error) {  		var faveIDs []string  		// Status fave IDs not in cache, perform DB query! @@ -201,7 +246,7 @@ func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmo  }  func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) error { -	return s.state.Caches.GTS.StatusFave().Store(fave, func() error { +	return s.state.Caches.GTS.StatusFave.Store(fave, func() error {  		_, err := s.db.  			NewInsert().  			Model(fave). @@ -230,10 +275,10 @@ func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) erro  	if statusID != "" {  		// Invalidate any cached status faves for this status. -		s.state.Caches.GTS.StatusFave().Invalidate("ID", id) +		s.state.Caches.GTS.StatusFave.Invalidate("ID", id)  		// Invalidate any cached status fave IDs for this status. -		s.state.Caches.GTS.StatusFaveIDs().Invalidate(statusID) +		s.state.Caches.GTS.StatusFaveIDs.Invalidate(statusID)  	}  	return nil @@ -270,17 +315,15 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st  		return err  	} -	// Collate (deduplicating) status IDs. -	statusIDs = collate(func(i int) string { -		return statusIDs[i] -	}, len(statusIDs)) +	// Deduplicate determined status IDs. +	statusIDs = util.Deduplicate(statusIDs)  	for _, id := range statusIDs {  		// Invalidate any cached status faves for this status. -		s.state.Caches.GTS.StatusFave().Invalidate("ID", id) +		s.state.Caches.GTS.StatusFave.Invalidate("ID", id)  		// Invalidate any cached status fave IDs for this status. -		s.state.Caches.GTS.StatusFaveIDs().Invalidate(id) +		s.state.Caches.GTS.StatusFaveIDs.Invalidate(id)  	}  	return nil @@ -296,10 +339,10 @@ func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID  	}  	// Invalidate any cached status faves for this status. -	s.state.Caches.GTS.StatusFave().Invalidate("ID", statusID) +	s.state.Caches.GTS.StatusFave.Invalidate("ID", statusID)  	// Invalidate any cached status fave IDs for this status. -	s.state.Caches.GTS.StatusFaveIDs().Invalidate(statusID) +	s.state.Caches.GTS.StatusFaveIDs.Invalidate(statusID)  	return nil  } diff --git a/internal/db/bundb/tag.go b/internal/db/bundb/tag.go index fac621f0a..66ee8cb3a 100644 --- a/internal/db/bundb/tag.go +++ b/internal/db/bundb/tag.go @@ -22,21 +22,21 @@ import (  	"strings"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/superseriousbusiness/gotosocial/internal/util"  	"github.com/uptrace/bun"  )  type tagDB struct { -	conn  *DB +	db    *DB  	state *state.State  } -func (m *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) { -	return m.state.Caches.GTS.Tag().Load("ID", func() (*gtsmodel.Tag, error) { +func (t *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) { +	return t.state.Caches.GTS.Tag.LoadOne("ID", func() (*gtsmodel.Tag, error) {  		var tag gtsmodel.Tag -		q := m.conn. +		q := t.db.  			NewSelect().  			Model(&tag).  			Where("? = ?", bun.Ident("tag.id"), id) @@ -49,15 +49,15 @@ func (m *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) {  	}, id)  } -func (m *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, error) { +func (t *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, error) {  	// Normalize 'name' string.  	name = strings.TrimSpace(name)  	name = strings.ToLower(name) -	return m.state.Caches.GTS.Tag().Load("Name", func() (*gtsmodel.Tag, error) { +	return t.state.Caches.GTS.Tag.LoadOne("Name", func() (*gtsmodel.Tag, error) {  		var tag gtsmodel.Tag -		q := m.conn. +		q := t.db.  			NewSelect().  			Model(&tag).  			Where("? = ?", bun.Ident("tag.name"), name) @@ -70,25 +70,52 @@ func (m *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, e  	}, name)  } -func (m *tagDB) GetTags(ctx context.Context, ids []string) ([]*gtsmodel.Tag, error) { -	tags := make([]*gtsmodel.Tag, 0, len(ids)) - -	for _, id := range ids { -		// Attempt fetch from DB -		tag, err := m.GetTag(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error getting tag %q: %v", id, err) -			continue -		} - -		// Append tag -		tags = append(tags, tag) +func (t *tagDB) GetTags(ctx context.Context, ids []string) ([]*gtsmodel.Tag, error) { +	// Preallocate at-worst possible length. +	uncached := make([]string, 0, len(ids)) + +	// Load all tag IDs via cache loader callbacks. +	tags, err := t.state.Caches.GTS.Tag.Load("ID", + +		// Load cached + check for uncached. +		func(load func(keyParts ...any) bool) { +			for _, id := range ids { +				if !load(id) { +					uncached = append(uncached, id) +				} +			} +		}, + +		// Uncached tag loader function. +		func() ([]*gtsmodel.Tag, error) { +			// Preallocate expected length of uncached tags. +			tags := make([]*gtsmodel.Tag, 0, len(uncached)) + +			// Perform database query scanning +			// the remaining (uncached) IDs. +			if err := t.db.NewSelect(). +				Model(&tags). +				Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). +				Scan(ctx); err != nil { +				return nil, err +			} + +			return tags, nil +		}, +	) +	if err != nil { +		return nil, err  	} +	// Reorder the tags by their +	// IDs to ensure in correct order. +	getID := func(t *gtsmodel.Tag) string { return t.ID } +	util.OrderBy(tags, ids, getID) +  	return tags, nil  } -func (m *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error { +func (t *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error {  	// Normalize 'name' string before it enters  	// the db, without changing tag we were given.  	// @@ -101,8 +128,8 @@ func (m *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error {  	t2.Name = strings.ToLower(t2.Name)  	// Insert the copy. -	if err := m.state.Caches.GTS.Tag().Store(t2, func() error { -		_, err := m.conn.NewInsert().Model(t2).Exec(ctx) +	if err := t.state.Caches.GTS.Tag.Store(t2, func() error { +		_, err := t.db.NewInsert().Model(t2).Exec(ctx)  		return err  	}); err != nil {  		return err // err already processed diff --git a/internal/db/bundb/thread.go b/internal/db/bundb/thread.go index e6d6154d4..34c5f783a 100644 --- a/internal/db/bundb/thread.go +++ b/internal/db/bundb/thread.go @@ -42,7 +42,7 @@ func (t *threadDB) PutThread(ctx context.Context, thread *gtsmodel.Thread) error  }  func (t *threadDB) GetThreadMute(ctx context.Context, id string) (*gtsmodel.ThreadMute, error) { -	return t.state.Caches.GTS.ThreadMute().Load("ID", func() (*gtsmodel.ThreadMute, error) { +	return t.state.Caches.GTS.ThreadMute.LoadOne("ID", func() (*gtsmodel.ThreadMute, error) {  		var threadMute gtsmodel.ThreadMute  		q := t.db. @@ -63,7 +63,7 @@ func (t *threadDB) GetThreadMutedByAccount(  	threadID string,  	accountID string,  ) (*gtsmodel.ThreadMute, error) { -	return t.state.Caches.GTS.ThreadMute().Load("ThreadID.AccountID", func() (*gtsmodel.ThreadMute, error) { +	return t.state.Caches.GTS.ThreadMute.LoadOne("ThreadID,AccountID", func() (*gtsmodel.ThreadMute, error) {  		var threadMute gtsmodel.ThreadMute  		q := t.db. @@ -98,7 +98,7 @@ func (t *threadDB) IsThreadMutedByAccount(  }  func (t *threadDB) PutThreadMute(ctx context.Context, threadMute *gtsmodel.ThreadMute) error { -	return t.state.Caches.GTS.ThreadMute().Store(threadMute, func() error { +	return t.state.Caches.GTS.ThreadMute.Store(threadMute, func() error {  		_, err := t.db.NewInsert().Model(threadMute).Exec(ctx)  		return err  	}) @@ -112,6 +112,6 @@ func (t *threadDB) DeleteThreadMute(ctx context.Context, id string) error {  		return err  	} -	t.state.Caches.GTS.ThreadMute().Invalidate("ID", id) +	t.state.Caches.GTS.ThreadMute.Invalidate("ID", id)  	return nil  } diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go index 4af17fb7f..f2ba2a9d1 100644 --- a/internal/db/bundb/timeline.go +++ b/internal/db/bundb/timeline.go @@ -29,7 +29,6 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/id" -	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/uptrace/bun"  ) @@ -155,20 +154,8 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI  		}  	} -	statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) -	for _, id := range statusIDs { -		// Fetch status from db for ID -		status, err := t.state.DB.GetStatusByID(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error fetching status %q: %v", id, err) -			continue -		} - -		// Append status to slice -		statuses = append(statuses, status) -	} - -	return statuses, nil +	// Return status IDs loaded from cache + db. +	return t.state.DB.GetStatusesByIDs(ctx, statusIDs)  }  func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) { @@ -256,20 +243,8 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI  		}  	} -	statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) -	for _, id := range statusIDs { -		// Fetch status from db for ID -		status, err := t.state.DB.GetStatusByID(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error fetching status %q: %v", id, err) -			continue -		} - -		// Append status to slice -		statuses = append(statuses, status) -	} - -	return statuses, nil +	// Return status IDs loaded from cache + db. +	return t.state.DB.GetStatusesByIDs(ctx, statusIDs)  }  // TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20! @@ -323,18 +298,15 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max  		}  	}) -	statuses := make([]*gtsmodel.Status, 0, len(faves)) - -	for _, fave := range faves { -		// Fetch status from db for corresponding favourite -		status, err := t.state.DB.GetStatusByID(ctx, fave.StatusID) -		if err != nil { -			log.Errorf(ctx, "error fetching status for fave %q: %v", fave.ID, err) -			continue -		} +	// Convert fave IDs to status IDs. +	statusIDs := make([]string, len(faves)) +	for i, fave := range faves { +		statusIDs[i] = fave.StatusID +	} -		// Append status to slice -		statuses = append(statuses, status) +	statuses, err := t.state.DB.GetStatusesByIDs(ctx, statusIDs) +	if err != nil { +		return nil, "", "", err  	}  	nextMaxID := faves[len(faves)-1].ID @@ -453,20 +425,8 @@ func (t *timelineDB) GetListTimeline(  		}  	} -	statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) -	for _, id := range statusIDs { -		// Fetch status from db for ID -		status, err := t.state.DB.GetStatusByID(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error fetching status %q: %v", id, err) -			continue -		} - -		// Append status to slice -		statuses = append(statuses, status) -	} - -	return statuses, nil +	// Return status IDs loaded from cache + db. +	return t.state.DB.GetStatusesByIDs(ctx, statusIDs)  }  func (t *timelineDB) GetTagTimeline( @@ -561,18 +521,6 @@ func (t *timelineDB) GetTagTimeline(  		}  	} -	statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) -	for _, id := range statusIDs { -		// Fetch status from db for ID -		status, err := t.state.DB.GetStatusByID(ctx, id) -		if err != nil { -			log.Errorf(ctx, "error fetching status %q: %v", id, err) -			continue -		} - -		// Append status to slice -		statuses = append(statuses, status) -	} - -	return statuses, nil +	// Return status IDs loaded from cache + db. +	return t.state.DB.GetStatusesByIDs(ctx, statusIDs)  } diff --git a/internal/db/bundb/tombstone.go b/internal/db/bundb/tombstone.go index f9882d1c6..c0e439720 100644 --- a/internal/db/bundb/tombstone.go +++ b/internal/db/bundb/tombstone.go @@ -32,7 +32,7 @@ type tombstoneDB struct {  }  func (t *tombstoneDB) GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, error) { -	return t.state.Caches.GTS.Tombstone().Load("URI", func() (*gtsmodel.Tombstone, error) { +	return t.state.Caches.GTS.Tombstone.LoadOne("URI", func() (*gtsmodel.Tombstone, error) {  		var tomb gtsmodel.Tombstone  		q := t.db. @@ -57,7 +57,7 @@ func (t *tombstoneDB) TombstoneExistsWithURI(ctx context.Context, uri string) (b  }  func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) error { -	return t.state.Caches.GTS.Tombstone().Store(tombstone, func() error { +	return t.state.Caches.GTS.Tombstone.Store(tombstone, func() error {  		_, err := t.db.  			NewInsert().  			Model(tombstone). @@ -67,7 +67,7 @@ func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tomb  }  func (t *tombstoneDB) DeleteTombstone(ctx context.Context, id string) error { -	defer t.state.Caches.GTS.Tombstone().Invalidate("ID", id) +	defer t.state.Caches.GTS.Tombstone.Invalidate("ID", id)  	// Delete tombstone from DB.  	_, err := t.db.NewDelete(). diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go index 46b3c568f..a6fa142f2 100644 --- a/internal/db/bundb/user.go +++ b/internal/db/bundb/user.go @@ -116,7 +116,7 @@ func (u *userDB) GetUserByConfirmationToken(ctx context.Context, token string) (  func (u *userDB) getUser(ctx context.Context, lookup string, dbQuery func(*gtsmodel.User) error, keyParts ...any) (*gtsmodel.User, error) {  	// Fetch user from database cache with loader callback. -	user, err := u.state.Caches.GTS.User().Load(lookup, func() (*gtsmodel.User, error) { +	user, err := u.state.Caches.GTS.User.LoadOne(lookup, func() (*gtsmodel.User, error) {  		var user gtsmodel.User  		// Not cached! perform database query. @@ -179,7 +179,7 @@ func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) {  }  func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) error { -	return u.state.Caches.GTS.User().Store(user, func() error { +	return u.state.Caches.GTS.User.Store(user, func() error {  		_, err := u.db.  			NewInsert().  			Model(user). @@ -197,7 +197,7 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ..  		columns = append(columns, "updated_at")  	} -	return u.state.Caches.GTS.User().Store(user, func() error { +	return u.state.Caches.GTS.User.Store(user, func() error {  		_, err := u.db.  			NewUpdate().  			Model(user). @@ -209,7 +209,7 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ..  }  func (u *userDB) DeleteUserByID(ctx context.Context, userID string) error { -	defer u.state.Caches.GTS.User().Invalidate("ID", userID) +	defer u.state.Caches.GTS.User.Invalidate("ID", userID)  	// Load user into cache before attempting a delete,  	// as we need it cached in order to trigger the invalidate | 
