summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/bundb/account.go44
-rw-r--r--internal/db/bundb/application.go6
-rw-r--r--internal/db/bundb/bundb.go2
-rw-r--r--internal/db/bundb/domain.go12
-rw-r--r--internal/db/bundb/emoji.go179
-rw-r--r--internal/db/bundb/instance.go6
-rw-r--r--internal/db/bundb/list.go203
-rw-r--r--internal/db/bundb/marker.go10
-rw-r--r--internal/db/bundb/media.go70
-rw-r--r--internal/db/bundb/mention.go71
-rw-r--r--internal/db/bundb/notification.go148
-rw-r--r--internal/db/bundb/poll.go107
-rw-r--r--internal/db/bundb/relationship.go14
-rw-r--r--internal/db/bundb/relationship_block.go91
-rw-r--r--internal/db/bundb/relationship_follow.go107
-rw-r--r--internal/db/bundb/relationship_follow_req.go97
-rw-r--r--internal/db/bundb/relationship_note.go6
-rw-r--r--internal/db/bundb/report.go8
-rw-r--r--internal/db/bundb/rule.go4
-rw-r--r--internal/db/bundb/status.go78
-rw-r--r--internal/db/bundb/statusfave.go91
-rw-r--r--internal/db/bundb/tag.go75
-rw-r--r--internal/db/bundb/thread.go8
-rw-r--r--internal/db/bundb/timeline.go84
-rw-r--r--internal/db/bundb/tombstone.go6
-rw-r--r--internal/db/bundb/user.go8
-rw-r--r--internal/db/list.go6
-rw-r--r--internal/db/notification.go3
28 files changed, 1027 insertions, 517 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index fdee8cb76..cdb949efa 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -116,7 +116,7 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
return a.getAccount(
ctx,
- "Username.Domain",
+ "Username,Domain",
func(account *gtsmodel.Account) error {
q := a.db.NewSelect().
Model(account)
@@ -224,7 +224,7 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, error) {
// Fetch account from database cache with loader callback
- account, err := a.state.Caches.GTS.Account().Load(lookup, func() (*gtsmodel.Account, error) {
+ account, err := a.state.Caches.GTS.Account.LoadOne(lookup, func() (*gtsmodel.Account, error) {
var account gtsmodel.Account
// Not cached! Perform database query
@@ -325,7 +325,7 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou
}
func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) error {
- return a.state.Caches.GTS.Account().Store(account, func() error {
+ return a.state.Caches.GTS.Account.Store(account, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
@@ -354,7 +354,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
columns = append(columns, "updated_at")
}
- return a.state.Caches.GTS.Account().Store(account, func() error {
+ return a.state.Caches.GTS.Account.Store(account, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
@@ -393,7 +393,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
}
func (a *accountDB) DeleteAccount(ctx context.Context, id string) error {
- defer a.state.Caches.GTS.Account().Invalidate("ID", id)
+ defer a.state.Caches.GTS.Account.Invalidate("ID", id)
// Load account into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
@@ -635,6 +635,10 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
return nil, err
}
+ if len(statusIDs) == 0 {
+ return nil, db.ErrNoEntries
+ }
+
// If we're paging up, we still want statuses
// to be sorted by ID desc, so reverse ids slice.
// https://zchee.github.io/golang-wiki/SliceTricks/#reversing
@@ -644,7 +648,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
}
}
- return a.statusesFromIDs(ctx, statusIDs)
+ return a.state.DB.GetStatusesByIDs(ctx, statusIDs)
}
func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, error) {
@@ -662,7 +666,11 @@ func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID stri
return nil, err
}
- return a.statusesFromIDs(ctx, statusIDs)
+ if len(statusIDs) == 0 {
+ return nil, db.ErrNoEntries
+ }
+
+ return a.state.DB.GetStatusesByIDs(ctx, statusIDs)
}
func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) {
@@ -710,29 +718,9 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,
return nil, err
}
- return a.statusesFromIDs(ctx, statusIDs)
-}
-
-func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, error) {
- // Catch case of no statuses early
if len(statusIDs) == 0 {
return nil, db.ErrNoEntries
}
- // Allocate return slice (will be at most len statusIDS)
- statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
-
- for _, id := range statusIDs {
- // Fetch from status from database by ID
- status, err := a.state.DB.GetStatusByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting status %q: %v", id, err)
- continue
- }
-
- // Append to return slice
- statuses = append(statuses, status)
- }
-
- return statuses, nil
+ return a.state.DB.GetStatusesByIDs(ctx, statusIDs)
}
diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go
index f7328e275..2e17a0e94 100644
--- a/internal/db/bundb/application.go
+++ b/internal/db/bundb/application.go
@@ -53,7 +53,7 @@ func (a *applicationDB) GetApplicationByClientID(ctx context.Context, clientID s
}
func (a *applicationDB) getApplication(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Application) error, keyParts ...any) (*gtsmodel.Application, error) {
- return a.state.Caches.GTS.Application().Load(lookup, func() (*gtsmodel.Application, error) {
+ return a.state.Caches.GTS.Application.LoadOne(lookup, func() (*gtsmodel.Application, error) {
var app gtsmodel.Application
// Not cached! Perform database query.
@@ -66,7 +66,7 @@ func (a *applicationDB) getApplication(ctx context.Context, lookup string, dbQue
}
func (a *applicationDB) PutApplication(ctx context.Context, app *gtsmodel.Application) error {
- return a.state.Caches.GTS.Application().Store(app, func() error {
+ return a.state.Caches.GTS.Application.Store(app, func() error {
_, err := a.db.NewInsert().Model(app).Exec(ctx)
return err
})
@@ -91,7 +91,7 @@ func (a *applicationDB) DeleteApplicationByClientID(ctx context.Context, clientI
//
// Clear application from the cache.
- a.state.Caches.GTS.Application().Invalidate("ClientID", clientID)
+ a.state.Caches.GTS.Application.Invalidate("ClientID", clientID)
return nil
}
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index d9415eff4..048474782 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -258,7 +258,7 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
state: state,
},
Tag: &tagDB{
- conn: db,
+ db: db,
state: state,
},
Thread: &threadDB{
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go
index dd626bc0a..2398e52c2 100644
--- a/internal/db/bundb/domain.go
+++ b/internal/db/bundb/domain.go
@@ -51,7 +51,7 @@ func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.Domain
}
// Clear the domain allow cache (for later reload)
- d.state.Caches.GTS.DomainAllow().Clear()
+ d.state.Caches.GTS.DomainAllow.Clear()
return nil
}
@@ -126,7 +126,7 @@ func (d *domainDB) DeleteDomainAllow(ctx context.Context, domain string) error {
}
// Clear the domain allow cache (for later reload)
- d.state.Caches.GTS.DomainAllow().Clear()
+ d.state.Caches.GTS.DomainAllow.Clear()
return nil
}
@@ -147,7 +147,7 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain
}
// Clear the domain block cache (for later reload)
- d.state.Caches.GTS.DomainBlock().Clear()
+ d.state.Caches.GTS.DomainBlock.Clear()
return nil
}
@@ -222,7 +222,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) error {
}
// Clear the domain block cache (for later reload)
- d.state.Caches.GTS.DomainBlock().Clear()
+ d.state.Caches.GTS.DomainBlock.Clear()
return nil
}
@@ -241,7 +241,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er
}
// Check the cache for an explicit domain allow (hydrating the cache with callback if necessary).
- explicitAllow, err := d.state.Caches.GTS.DomainAllow().Matches(domain, func() ([]string, error) {
+ explicitAllow, err := d.state.Caches.GTS.DomainAllow.Matches(domain, func() ([]string, error) {
var domains []string
// Scan list of all explicitly allowed domains from DB
@@ -259,7 +259,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er
}
// Check the cache for a domain block (hydrating the cache with callback if necessary)
- explicitBlock, err := d.state.Caches.GTS.DomainBlock().Matches(domain, func() ([]string, error) {
+ explicitBlock, err := d.state.Caches.GTS.DomainBlock.Matches(domain, func() ([]string, error) {
var domains []string
// Scan list of all blocked domains from DB
diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go
index 34a08b694..31092d0d2 100644
--- a/internal/db/bundb/emoji.go
+++ b/internal/db/bundb/emoji.go
@@ -21,6 +21,7 @@ import (
"context"
"database/sql"
"errors"
+ "slices"
"strings"
"time"
@@ -30,6 +31,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect"
)
@@ -40,7 +42,7 @@ type emojiDB struct {
}
func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) error {
- return e.state.Caches.GTS.Emoji().Store(emoji, func() error {
+ return e.state.Caches.GTS.Emoji.Store(emoji, func() error {
_, err := e.db.NewInsert().Model(emoji).Exec(ctx)
return err
})
@@ -54,7 +56,7 @@ func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, column
}
// Update the emoji model in the database.
- return e.state.Caches.GTS.Emoji().Store(emoji, func() error {
+ return e.state.Caches.GTS.Emoji.Store(emoji, func() error {
_, err := e.db.
NewUpdate().
Model(emoji).
@@ -74,21 +76,21 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {
defer func() {
// Invalidate cached emoji.
e.state.Caches.GTS.
- Emoji().
+ Emoji.
Invalidate("ID", id)
- for _, id := range accountIDs {
+ for _, accountID := range accountIDs {
// Invalidate cached account.
e.state.Caches.GTS.
- Account().
- Invalidate("ID", id)
+ Account.
+ Invalidate("ID", accountID)
}
- for _, id := range statusIDs {
+ for _, statusID := range statusIDs {
// Invalidate cached account.
e.state.Caches.GTS.
- Status().
- Invalidate("ID", id)
+ Status.
+ Invalidate("ID", statusID)
}
}()
@@ -129,26 +131,28 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {
return err
}
- for _, id := range statusIDs {
+ for _, statusID := range statusIDs {
var emojiIDs []string
// Select statuses with ID.
if _, err := tx.NewSelect().
Table("statuses").
Column("emojis").
- Where("? = ?", bun.Ident("id"), id).
+ Where("? = ?", bun.Ident("id"), statusID).
Exec(ctx); err != nil &&
err != sql.ErrNoRows {
return err
}
- // Drop ID from account emojis.
- emojiIDs = dropID(emojiIDs, id)
+ // Delete all instances of this emoji ID from status emojis.
+ emojiIDs = slices.DeleteFunc(emojiIDs, func(emojiID string) bool {
+ return emojiID == id
+ })
// Update status emoji IDs.
if _, err := tx.NewUpdate().
Table("statuses").
- Where("? = ?", bun.Ident("id"), id).
+ Where("? = ?", bun.Ident("id"), statusID).
Set("emojis = ?", emojiIDs).
Exec(ctx); err != nil &&
err != sql.ErrNoRows {
@@ -156,26 +160,28 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {
}
}
- for _, id := range accountIDs {
+ for _, accountID := range accountIDs {
var emojiIDs []string
// Select account with ID.
if _, err := tx.NewSelect().
Table("accounts").
Column("emojis").
- Where("? = ?", bun.Ident("id"), id).
+ Where("? = ?", bun.Ident("id"), accountID).
Exec(ctx); err != nil &&
err != sql.ErrNoRows {
return err
}
- // Drop ID from account emojis.
- emojiIDs = dropID(emojiIDs, id)
+ // Delete all instances of this emoji ID from account emojis.
+ emojiIDs = slices.DeleteFunc(emojiIDs, func(emojiID string) bool {
+ return emojiID == id
+ })
// Update account emoji IDs.
if _, err := tx.NewUpdate().
Table("accounts").
- Where("? = ?", bun.Ident("id"), id).
+ Where("? = ?", bun.Ident("id"), accountID).
Set("emojis = ?", emojiIDs).
Exec(ctx); err != nil &&
err != sql.ErrNoRows {
@@ -431,7 +437,7 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj
func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, error) {
return e.getEmoji(
ctx,
- "Shortcode.Domain",
+ "Shortcode,Domain",
func(emoji *gtsmodel.Emoji) error {
q := e.db.
NewSelect().
@@ -468,7 +474,7 @@ func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string
}
func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) error {
- return e.state.Caches.GTS.EmojiCategory().Store(emojiCategory, func() error {
+ return e.state.Caches.GTS.EmojiCategory.Store(emojiCategory, func() error {
_, err := e.db.NewInsert().Model(emojiCategory).Exec(ctx)
return err
})
@@ -520,7 +526,7 @@ func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gts
func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, error) {
// Fetch emoji from database cache with loader callback
- emoji, err := e.state.Caches.GTS.Emoji().Load(lookup, func() (*gtsmodel.Emoji, error) {
+ emoji, err := e.state.Caches.GTS.Emoji.LoadOne(lookup, func() (*gtsmodel.Emoji, error) {
var emoji gtsmodel.Emoji
// Not cached! Perform database query
@@ -568,28 +574,72 @@ func (e *emojiDB) PopulateEmoji(ctx context.Context, emoji *gtsmodel.Emoji) erro
return errs.Combine()
}
-func (e *emojiDB) GetEmojisByIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, error) {
- if len(emojiIDs) == 0 {
+func (e *emojiDB) GetEmojisByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Emoji, error) {
+ if len(ids) == 0 {
return nil, db.ErrNoEntries
}
- emojis := make([]*gtsmodel.Emoji, 0, len(emojiIDs))
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(ids))
- for _, id := range emojiIDs {
- emoji, err := e.GetEmojiByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "emojisFromIDs: error getting emoji %q: %v", id, err)
- continue
- }
+ // Load all emoji IDs via cache loader callbacks.
+ emojis, err := e.state.Caches.GTS.Emoji.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range ids {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
- emojis = append(emojis, emoji)
+ // Uncached emoji loader function.
+ func() ([]*gtsmodel.Emoji, error) {
+ // Preallocate expected length of uncached emojis.
+ emojis := make([]*gtsmodel.Emoji, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := e.db.NewSelect().
+ Model(&emojis).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return emojis, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Reorder the emojis by their
+ // IDs to ensure in correct order.
+ getID := func(e *gtsmodel.Emoji) string { return e.ID }
+ util.OrderBy(emojis, ids, getID)
+
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return emojis, nil
}
+ // Populate all loaded emojis, removing those we fail to
+ // populate (removes needing so many nil checks everywhere).
+ emojis = slices.DeleteFunc(emojis, func(emoji *gtsmodel.Emoji) bool {
+ if err := e.PopulateEmoji(ctx, emoji); err != nil {
+ log.Errorf(ctx, "error populating emoji %s: %v", emoji.ID, err)
+ return true
+ }
+ return false
+ })
+
return emojis, nil
}
func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, error) {
- return e.state.Caches.GTS.EmojiCategory().Load(lookup, func() (*gtsmodel.EmojiCategory, error) {
+ return e.state.Caches.GTS.EmojiCategory.LoadOne(lookup, func() (*gtsmodel.EmojiCategory, error) {
var category gtsmodel.EmojiCategory
// Not cached! Perform database query
@@ -601,36 +651,51 @@ func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery f
}, keyParts...)
}
-func (e *emojiDB) GetEmojiCategoriesByIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, error) {
- if len(emojiCategoryIDs) == 0 {
+func (e *emojiDB) GetEmojiCategoriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.EmojiCategory, error) {
+ if len(ids) == 0 {
return nil, db.ErrNoEntries
}
- emojiCategories := make([]*gtsmodel.EmojiCategory, 0, len(emojiCategoryIDs))
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(ids))
- for _, id := range emojiCategoryIDs {
- emojiCategory, err := e.GetEmojiCategory(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting emoji category %q: %v", id, err)
- continue
- }
+ // Load all category IDs via cache loader callbacks.
+ categories, err := e.state.Caches.GTS.EmojiCategory.Load("ID",
- emojiCategories = append(emojiCategories, emojiCategory)
- }
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range ids {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
- return emojiCategories, nil
-}
+ // Uncached emoji loader function.
+ func() ([]*gtsmodel.EmojiCategory, error) {
+ // Preallocate expected length of uncached categories.
+ categories := make([]*gtsmodel.EmojiCategory, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := e.db.NewSelect().
+ Model(&categories).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
-// dropIDs drops given ID string from IDs slice.
-func dropID(ids []string, id string) []string {
- for i := 0; i < len(ids); {
- if ids[i] == id {
- // Remove this reference.
- copy(ids[i:], ids[i+1:])
- ids = ids[:len(ids)-1]
- continue
- }
- i++
+ return categories, nil
+ },
+ )
+ if err != nil {
+ return nil, err
}
- return ids
+
+ // Reorder the categories by their
+ // IDs to ensure in correct order.
+ getID := func(c *gtsmodel.EmojiCategory) string { return c.ID }
+ util.OrderBy(categories, ids, getID)
+
+ return categories, nil
}
diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go
index 567a44ee2..d506e0a31 100644
--- a/internal/db/bundb/instance.go
+++ b/internal/db/bundb/instance.go
@@ -143,7 +143,7 @@ func (i *instanceDB) GetInstanceByID(ctx context.Context, id string) (*gtsmodel.
func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Instance) error, keyParts ...any) (*gtsmodel.Instance, error) {
// Fetch instance from database cache with loader callback
- instance, err := i.state.Caches.GTS.Instance().Load(lookup, func() (*gtsmodel.Instance, error) {
+ instance, err := i.state.Caches.GTS.Instance.LoadOne(lookup, func() (*gtsmodel.Instance, error) {
var instance gtsmodel.Instance
// Not cached! Perform database query.
@@ -219,7 +219,7 @@ func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instanc
return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err)
}
- return i.state.Caches.GTS.Instance().Store(instance, func() error {
+ return i.state.Caches.GTS.Instance.Store(instance, func() error {
_, err := i.db.NewInsert().Model(instance).Exec(ctx)
return err
})
@@ -239,7 +239,7 @@ func (i *instanceDB) UpdateInstance(ctx context.Context, instance *gtsmodel.Inst
columns = append(columns, "updated_at")
}
- return i.state.Caches.GTS.Instance().Store(instance, func() error {
+ return i.state.Caches.GTS.Instance.Store(instance, func() error {
_, err := i.db.
NewUpdate().
Model(instance).
diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go
index 7a117670a..5f95d3c24 100644
--- a/internal/db/bundb/list.go
+++ b/internal/db/bundb/list.go
@@ -21,6 +21,7 @@ import (
"context"
"errors"
"fmt"
+ "slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
@@ -29,6 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@@ -56,7 +58,7 @@ func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, er
}
func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmodel.List) error, keyParts ...any) (*gtsmodel.List, error) {
- list, err := l.state.Caches.GTS.List().Load(lookup, func() (*gtsmodel.List, error) {
+ list, err := l.state.Caches.GTS.List.LoadOne(lookup, func() (*gtsmodel.List, error) {
var list gtsmodel.List
// Not cached! Perform database query.
@@ -100,18 +102,8 @@ func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]
return nil, nil
}
- // Select each list using its ID to ensure cache used.
- lists := make([]*gtsmodel.List, 0, len(listIDs))
- for _, id := range listIDs {
- list, err := l.state.DB.GetListByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error fetching list %q: %v", id, err)
- continue
- }
- lists = append(lists, list)
- }
-
- return lists, nil
+ // Return lists by their IDs.
+ return l.GetListsByIDs(ctx, listIDs)
}
func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {
@@ -147,7 +139,7 @@ func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {
}
func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error {
- return l.state.Caches.GTS.List().Store(list, func() error {
+ return l.state.Caches.GTS.List.Store(list, func() error {
_, err := l.db.NewInsert().Model(list).Exec(ctx)
return err
})
@@ -162,7 +154,7 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ..
defer func() {
// Invalidate all entries for this list ID.
- l.state.Caches.GTS.ListEntry().Invalidate("ListID", list.ID)
+ l.state.Caches.GTS.ListEntry.Invalidate("ListID", list.ID)
// Invalidate this entire list's timeline.
if err := l.state.Timelines.List.RemoveTimeline(ctx, list.ID); err != nil {
@@ -170,7 +162,7 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ..
}
}()
- return l.state.Caches.GTS.List().Store(list, func() error {
+ return l.state.Caches.GTS.List.Store(list, func() error {
_, err := l.db.NewUpdate().
Model(list).
Where("? = ?", bun.Ident("list.id"), list.ID).
@@ -198,7 +190,7 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error {
defer func() {
// Invalidate this list from cache.
- l.state.Caches.GTS.List().Invalidate("ID", id)
+ l.state.Caches.GTS.List.Invalidate("ID", id)
// Invalidate this entire list's timeline.
if err := l.state.Timelines.List.RemoveTimeline(ctx, id); err != nil {
@@ -243,7 +235,7 @@ func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.Lis
}
func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) {
- listEntry, err := l.state.Caches.GTS.ListEntry().Load(lookup, func() (*gtsmodel.ListEntry, error) {
+ listEntry, err := l.state.Caches.GTS.ListEntry.LoadOne(lookup, func() (*gtsmodel.ListEntry, error) {
var listEntry gtsmodel.ListEntry
// Not cached! Perform database query.
@@ -344,18 +336,128 @@ func (l *listDB) GetListEntries(ctx context.Context,
}
}
- // Select each list entry using its ID to ensure cache used.
- listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs))
- for _, id := range entryIDs {
- listEntry, err := l.state.DB.GetListEntryByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error fetching list entry %q: %v", id, err)
- continue
+ // Return list entries by their IDs.
+ return l.GetListEntriesByIDs(ctx, entryIDs)
+}
+
+func (l *listDB) GetListsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.List, error) {
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(ids))
+
+ // Load all list IDs via cache loader callbacks.
+ lists, err := l.state.Caches.GTS.List.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range ids {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
+
+ // Uncached list loader function.
+ func() ([]*gtsmodel.List, error) {
+ // Preallocate expected length of uncached lists.
+ lists := make([]*gtsmodel.List, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := l.db.NewSelect().
+ Model(&lists).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return lists, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Reorder the lists by their
+ // IDs to ensure in correct order.
+ getID := func(l *gtsmodel.List) string { return l.ID }
+ util.OrderBy(lists, ids, getID)
+
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return lists, nil
+ }
+
+ // Populate all loaded lists, removing those we fail to
+ // populate (removes needing so many nil checks everywhere).
+ lists = slices.DeleteFunc(lists, func(list *gtsmodel.List) bool {
+ if err := l.PopulateList(ctx, list); err != nil {
+ log.Errorf(ctx, "error populating list %s: %v", list.ID, err)
+ return true
}
- listEntries = append(listEntries, listEntry)
+ return false
+ })
+
+ return lists, nil
+}
+
+func (l *listDB) GetListEntriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.ListEntry, error) {
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(ids))
+
+ // Load all entry IDs via cache loader callbacks.
+ entries, err := l.state.Caches.GTS.ListEntry.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range ids {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
+
+ // Uncached entry loader function.
+ func() ([]*gtsmodel.ListEntry, error) {
+ // Preallocate expected length of uncached entries.
+ entries := make([]*gtsmodel.ListEntry, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := l.db.NewSelect().
+ Model(&entries).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return entries, nil
+ },
+ )
+ if err != nil {
+ return nil, err
}
- return listEntries, nil
+ // Reorder the entries by their
+ // IDs to ensure in correct order.
+ getID := func(e *gtsmodel.ListEntry) string { return e.ID }
+ util.OrderBy(entries, ids, getID)
+
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return entries, nil
+ }
+
+ // Populate all loaded entries, removing those we fail to
+ // populate (removes needing so many nil checks everywhere).
+ entries = slices.DeleteFunc(entries, func(entry *gtsmodel.ListEntry) bool {
+ if err := l.PopulateListEntry(ctx, entry); err != nil {
+ log.Errorf(ctx, "error populating entry %s: %v", entry.ID, err)
+ return true
+ }
+ return false
+ })
+
+ return entries, nil
}
func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) {
@@ -376,18 +478,8 @@ func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string)
return nil, nil
}
- // Select each list entry using its ID to ensure cache used.
- listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs))
- for _, id := range entryIDs {
- listEntry, err := l.state.DB.GetListEntryByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error fetching list entry %q: %v", id, err)
- continue
- }
- listEntries = append(listEntries, listEntry)
- }
-
- return listEntries, nil
+ // Return list entries by their IDs.
+ return l.GetListEntriesByIDs(ctx, entryIDs)
}
func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error {
@@ -409,10 +501,10 @@ func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.List
func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEntry) error {
defer func() {
- // Collect unique list IDs from the entries.
- listIDs := collate(func(i int) string {
- return entries[i].ListID
- }, len(entries))
+ // Collect unique list IDs from the provided entries.
+ listIDs := util.Collate(entries, func(e *gtsmodel.ListEntry) string {
+ return e.ListID
+ })
for _, id := range listIDs {
// Invalidate the timeline for the list this entry belongs to.
@@ -426,7 +518,7 @@ func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEnt
return l.db.RunInTx(ctx, func(tx Tx) error {
for _, entry := range entries {
entry := entry // rescope
- if err := l.state.Caches.GTS.ListEntry().Store(entry, func() error {
+ if err := l.state.Caches.GTS.ListEntry.Store(entry, func() error {
_, err := tx.
NewInsert().
Model(entry).
@@ -459,7 +551,7 @@ func (l *listDB) DeleteListEntry(ctx context.Context, id string) error {
defer func() {
// Invalidate this list entry upon delete.
- l.state.Caches.GTS.ListEntry().Invalidate("ID", id)
+ l.state.Caches.GTS.ListEntry.Invalidate("ID", id)
// Invalidate the timeline for the list this entry belongs to.
if err := l.state.Timelines.List.RemoveTimeline(ctx, entry.ListID); err != nil {
@@ -514,24 +606,3 @@ func (l *listDB) ListIncludesAccount(ctx context.Context, listID string, account
return exists, err
}
-
-// collate will collect the values of type T from an expected slice of length 'len',
-// passing the expected index to each call of 'get' and deduplicating the end result.
-func collate[T comparable](get func(int) T, len int) []T {
- ts := make([]T, 0, len)
- tm := make(map[T]struct{}, len)
-
- for i := 0; i < len; i++ {
- // Get next.
- t := get(i)
-
- if _, ok := tm[t]; !ok {
- // New value, add
- // to map + slice.
- ts = append(ts, t)
- tm[t] = struct{}{}
- }
- }
-
- return ts
-}
diff --git a/internal/db/bundb/marker.go b/internal/db/bundb/marker.go
index 5d365e08a..b1dedb4f1 100644
--- a/internal/db/bundb/marker.go
+++ b/internal/db/bundb/marker.go
@@ -39,8 +39,8 @@ type markerDB struct {
*/
func (m *markerDB) GetMarker(ctx context.Context, accountID string, name gtsmodel.MarkerName) (*gtsmodel.Marker, error) {
- marker, err := m.state.Caches.GTS.Marker().Load(
- "AccountID.Name",
+ marker, err := m.state.Caches.GTS.Marker.LoadOne(
+ "AccountID,Name",
func() (*gtsmodel.Marker, error) {
var marker gtsmodel.Marker
@@ -52,9 +52,7 @@ func (m *markerDB) GetMarker(ctx context.Context, accountID string, name gtsmode
}
return &marker, nil
- },
- accountID,
- name,
+ }, accountID, name,
)
if err != nil {
return nil, err // already processed
@@ -74,7 +72,7 @@ func (m *markerDB) UpdateMarker(ctx context.Context, marker *gtsmodel.Marker) er
marker.Version = prevMarker.Version + 1
}
- return m.state.Caches.GTS.Marker().Store(marker, func() error {
+ return m.state.Caches.GTS.Marker.Store(marker, func() error {
if prevMarker == nil {
if _, err := m.db.NewInsert().
Model(marker).
diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go
index a2603eacc..ce3c90083 100644
--- a/internal/db/bundb/media.go
+++ b/internal/db/bundb/media.go
@@ -20,14 +20,15 @@ package bundb
import (
"context"
"errors"
+ "slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@@ -51,25 +52,52 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
}
func (m *mediaDB) GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error) {
- attachments := make([]*gtsmodel.MediaAttachment, 0, len(ids))
-
- for _, id := range ids {
- // Attempt fetch from DB
- attachment, err := m.GetAttachmentByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting attachment %q: %v", id, err)
- continue
- }
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(ids))
+
+ // Load all media IDs via cache loader callbacks.
+ media, err := m.state.Caches.GTS.Media.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range ids {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
+
+ // Uncached media loader function.
+ func() ([]*gtsmodel.MediaAttachment, error) {
+ // Preallocate expected length of uncached media attachments.
+ media := make([]*gtsmodel.MediaAttachment, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := m.db.NewSelect().
+ Model(&media).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
- // Append attachment
- attachments = append(attachments, attachment)
+ return media, nil
+ },
+ )
+ if err != nil {
+ return nil, err
}
- return attachments, nil
+ // Reorder the media by their
+ // IDs to ensure in correct order.
+ getID := func(m *gtsmodel.MediaAttachment) string { return m.ID }
+ util.OrderBy(media, ids, getID)
+
+ return media, nil
}
func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, error) {
- return m.state.Caches.GTS.Media().Load(lookup, func() (*gtsmodel.MediaAttachment, error) {
+ return m.state.Caches.GTS.Media.LoadOne(lookup, func() (*gtsmodel.MediaAttachment, error) {
var attachment gtsmodel.MediaAttachment
// Not cached! Perform database query
@@ -82,7 +110,7 @@ func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func
}
func (m *mediaDB) PutAttachment(ctx context.Context, media *gtsmodel.MediaAttachment) error {
- return m.state.Caches.GTS.Media().Store(media, func() error {
+ return m.state.Caches.GTS.Media.Store(media, func() error {
_, err := m.db.NewInsert().Model(media).Exec(ctx)
return err
})
@@ -95,7 +123,7 @@ func (m *mediaDB) UpdateAttachment(ctx context.Context, media *gtsmodel.MediaAtt
columns = append(columns, "updated_at")
}
- return m.state.Caches.GTS.Media().Store(media, func() error {
+ return m.state.Caches.GTS.Media.Store(media, func() error {
_, err := m.db.NewUpdate().
Model(media).
Where("? = ?", bun.Ident("media_attachment.id"), media.ID).
@@ -119,7 +147,7 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
}
// On return, ensure that media with ID is invalidated.
- defer m.state.Caches.GTS.Media().Invalidate("ID", id)
+ defer m.state.Caches.GTS.Media.Invalidate("ID", id)
// Delete media attachment in new transaction.
err = m.db.RunInTx(ctx, func(tx Tx) error {
@@ -171,8 +199,12 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
return gtserror.Newf("error selecting status: %w", err)
}
- if updatedIDs := dropID(status.AttachmentIDs, id); // nocollapse
- len(updatedIDs) != len(status.AttachmentIDs) {
+ // Delete all instances of this deleted media ID from status attachments.
+ updatedIDs := slices.DeleteFunc(status.AttachmentIDs, func(s string) bool {
+ return s == id
+ })
+
+ if len(updatedIDs) != len(status.AttachmentIDs) {
// Note: this handles not found.
//
// Attachments changed, update the status.
diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go
index 30a20b0c1..b069423bb 100644
--- a/internal/db/bundb/mention.go
+++ b/internal/db/bundb/mention.go
@@ -20,6 +20,7 @@ package bundb
import (
"context"
"errors"
+ "slices"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
@@ -27,6 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@@ -36,7 +38,7 @@ type mentionDB struct {
}
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, error) {
- mention, err := m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) {
+ mention, err := m.state.Caches.GTS.Mention.LoadOne("ID", func() (*gtsmodel.Mention, error) {
var mention gtsmodel.Mention
q := m.db.
@@ -63,21 +65,64 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio
}
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, error) {
- mentions := make([]*gtsmodel.Mention, 0, len(ids))
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(ids))
+
+ // Load all mention IDs via cache loader callbacks.
+ mentions, err := m.state.Caches.GTS.Mention.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range ids {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
+
+ // Uncached mention loader function.
+ func() ([]*gtsmodel.Mention, error) {
+ // Preallocate expected length of uncached mentions.
+ mentions := make([]*gtsmodel.Mention, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := m.db.NewSelect().
+ Model(&mentions).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return mentions, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
- for _, id := range ids {
- // Attempt fetch from DB
- mention, err := m.GetMention(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting mention %q: %v", id, err)
- continue
- }
+ // Reorder the mentions by their
+ // IDs to ensure in correct order.
+ getID := func(m *gtsmodel.Mention) string { return m.ID }
+ util.OrderBy(mentions, ids, getID)
- // Append mention
- mentions = append(mentions, mention)
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return mentions, nil
}
+ // Populate all loaded mentions, removing those we fail to
+ // populate (removes needing so many nil checks everywhere).
+ mentions = slices.DeleteFunc(mentions, func(mention *gtsmodel.Mention) bool {
+ if err := m.PopulateMention(ctx, mention); err != nil {
+ log.Errorf(ctx, "error populating mention %s: %v", mention.ID, err)
+ return true
+ }
+ return false
+ })
+
return mentions, nil
+
}
func (m *mentionDB) PopulateMention(ctx context.Context, mention *gtsmodel.Mention) (err error) {
@@ -120,14 +165,14 @@ func (m *mentionDB) PopulateMention(ctx context.Context, mention *gtsmodel.Menti
}
func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error {
- return m.state.Caches.GTS.Mention().Store(mention, func() error {
+ return m.state.Caches.GTS.Mention.Store(mention, func() error {
_, err := m.db.NewInsert().Model(mention).Exec(ctx)
return err
})
}
func (m *mentionDB) DeleteMentionByID(ctx context.Context, id string) error {
- defer m.state.Caches.GTS.Mention().Invalidate("ID", id)
+ defer m.state.Caches.GTS.Mention.Invalidate("ID", id)
// Load mention into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go
index 7532b9993..ed34222fb 100644
--- a/internal/db/bundb/notification.go
+++ b/internal/db/bundb/notification.go
@@ -20,6 +20,7 @@ package bundb
import (
"context"
"errors"
+ "slices"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
@@ -28,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@@ -37,18 +39,17 @@ type notificationDB struct {
}
func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error) {
- return n.state.Caches.GTS.Notification().Load("ID", func() (*gtsmodel.Notification, error) {
- var notif gtsmodel.Notification
-
- q := n.db.NewSelect().
- Model(&notif).
- Where("? = ?", bun.Ident("notification.id"), id)
- if err := q.Scan(ctx); err != nil {
- return nil, err
- }
-
- return &notif, nil
- }, id)
+ return n.getNotification(
+ ctx,
+ "ID",
+ func(notif *gtsmodel.Notification) error {
+ return n.db.NewSelect().
+ Model(notif).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx)
+ },
+ id,
+ )
}
func (n *notificationDB) GetNotification(
@@ -58,42 +59,113 @@ func (n *notificationDB) GetNotification(
originAccountID string,
statusID string,
) (*gtsmodel.Notification, error) {
- notif, err := n.state.Caches.GTS.Notification().Load("NotificationType.TargetAccountID.OriginAccountID.StatusID", func() (*gtsmodel.Notification, error) {
- var notif gtsmodel.Notification
+ return n.getNotification(
+ ctx,
+ "NotificationType,TargetAccountID,OriginAccountID,StatusID",
+ func(notif *gtsmodel.Notification) error {
+ return n.db.NewSelect().
+ Model(notif).
+ Where("? = ?", bun.Ident("notification_type"), notificationType).
+ Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
+ Where("? = ?", bun.Ident("origin_account_id"), originAccountID).
+ Where("? = ?", bun.Ident("status_id"), statusID).
+ Scan(ctx)
+ },
+ notificationType, targetAccountID, originAccountID, statusID,
+ )
+}
- q := n.db.NewSelect().
- Model(&notif).
- Where("? = ?", bun.Ident("notification_type"), notificationType).
- Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
- Where("? = ?", bun.Ident("origin_account_id"), originAccountID).
- Where("? = ?", bun.Ident("status_id"), statusID)
+func (n *notificationDB) getNotification(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Notification) error, keyParts ...any) (*gtsmodel.Notification, error) {
+ // Fetch notification from cache with loader callback
+ notif, err := n.state.Caches.GTS.Notification.LoadOne(lookup, func() (*gtsmodel.Notification, error) {
+ var notif gtsmodel.Notification
- if err := q.Scan(ctx); err != nil {
+ // Not cached! Perform database query
+ if err := dbQuery(&notif); err != nil {
return nil, err
}
return &notif, nil
- }, notificationType, targetAccountID, originAccountID, statusID)
+ }, keyParts...)
if err != nil {
return nil, err
}
if gtscontext.Barebones(ctx) {
- // no need to fully populate.
+ // Only a barebones model was requested.
return notif, nil
}
- // Further populate the notif fields where applicable.
- if err := n.PopulateNotification(ctx, notif); err != nil {
+ if err := n.state.DB.PopulateNotification(ctx, notif); err != nil {
return nil, err
}
return notif, nil
}
+func (n *notificationDB) GetNotificationsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Notification, error) {
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(ids))
+
+ // Load all notif IDs via cache loader callbacks.
+ notifs, err := n.state.Caches.GTS.Notification.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range ids {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
+
+ // Uncached notification loader function.
+ func() ([]*gtsmodel.Notification, error) {
+ // Preallocate expected length of uncached notifications.
+ notifs := make([]*gtsmodel.Notification, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := n.db.NewSelect().
+ Model(&notifs).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return notifs, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Reorder the notifs by their
+ // IDs to ensure in correct order.
+ getID := func(n *gtsmodel.Notification) string { return n.ID }
+ util.OrderBy(notifs, ids, getID)
+
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return notifs, nil
+ }
+
+ // Populate all loaded notifs, removing those we fail to
+ // populate (removes needing so many nil checks everywhere).
+ notifs = slices.DeleteFunc(notifs, func(notif *gtsmodel.Notification) bool {
+ if err := n.PopulateNotification(ctx, notif); err != nil {
+ log.Errorf(ctx, "error populating notif %s: %v", notif.ID, err)
+ return true
+ }
+ return false
+ })
+
+ return notifs, nil
+}
+
func (n *notificationDB) PopulateNotification(ctx context.Context, notif *gtsmodel.Notification) error {
var (
- errs = gtserror.NewMultiError(2)
+ errs gtserror.MultiError
err error
)
@@ -211,31 +283,19 @@ func (n *notificationDB) GetAccountNotifications(
}
}
- notifs := make([]*gtsmodel.Notification, 0, len(notifIDs))
- for _, id := range notifIDs {
- // Attempt fetch from DB
- notif, err := n.GetNotificationByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error fetching notification %q: %v", id, err)
- continue
- }
-
- // Append notification
- notifs = append(notifs, notif)
- }
-
- return notifs, nil
+ // Fetch notification models by their IDs.
+ return n.GetNotificationsByIDs(ctx, notifIDs)
}
func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error {
- return n.state.Caches.GTS.Notification().Store(notif, func() error {
+ return n.state.Caches.GTS.Notification.Store(notif, func() error {
_, err := n.db.NewInsert().Model(notif).Exec(ctx)
return err
})
}
func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) error {
- defer n.state.Caches.GTS.Notification().Invalidate("ID", id)
+ defer n.state.Caches.GTS.Notification.Invalidate("ID", id)
// Load notif into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
@@ -288,7 +348,7 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string
defer func() {
// Invalidate all IDs on return.
for _, id := range notifIDs {
- n.state.Caches.GTS.Notification().Invalidate("ID", id)
+ n.state.Caches.GTS.Notification.Invalidate("ID", id)
}
}()
@@ -326,7 +386,7 @@ func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statu
defer func() {
// Invalidate all IDs on return.
for _, id := range notifIDs {
- n.state.Caches.GTS.Notification().Invalidate("ID", id)
+ n.state.Caches.GTS.Notification.Invalidate("ID", id)
}
}()
diff --git a/internal/db/bundb/poll.go b/internal/db/bundb/poll.go
index 3e77fb6c5..0dfb15621 100644
--- a/internal/db/bundb/poll.go
+++ b/internal/db/bundb/poll.go
@@ -20,6 +20,7 @@ package bundb
import (
"context"
"errors"
+ "slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
@@ -28,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@@ -52,7 +54,7 @@ func (p *pollDB) GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, er
func (p *pollDB) getPoll(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Poll) error, keyParts ...any) (*gtsmodel.Poll, error) {
// Fetch poll from database cache with loader callback
- poll, err := p.state.Caches.GTS.Poll().Load(lookup, func() (*gtsmodel.Poll, error) {
+ poll, err := p.state.Caches.GTS.Poll.LoadOne(lookup, func() (*gtsmodel.Poll, error) {
var poll gtsmodel.Poll
// Not cached! Perform database query.
@@ -140,7 +142,7 @@ func (p *pollDB) PutPoll(ctx context.Context, poll *gtsmodel.Poll) error {
// is non nil and set.
poll.CheckVotes()
- return p.state.Caches.GTS.Poll().Store(poll, func() error {
+ return p.state.Caches.GTS.Poll.Store(poll, func() error {
_, err := p.db.NewInsert().Model(poll).Exec(ctx)
return err
})
@@ -151,7 +153,7 @@ func (p *pollDB) UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...st
// is non nil and set.
poll.CheckVotes()
- return p.state.Caches.GTS.Poll().Store(poll, func() error {
+ return p.state.Caches.GTS.Poll.Store(poll, func() error {
return p.db.RunInTx(ctx, func(tx Tx) error {
// Update the status' "updated_at" field.
if _, err := tx.NewUpdate().
@@ -184,8 +186,8 @@ func (p *pollDB) DeletePollByID(ctx context.Context, id string) error {
}
// Invalidate poll by ID from cache.
- p.state.Caches.GTS.Poll().Invalidate("ID", id)
- p.state.Caches.GTS.PollVoteIDs().Invalidate(id)
+ p.state.Caches.GTS.Poll.Invalidate("ID", id)
+ p.state.Caches.GTS.PollVoteIDs.Invalidate(id)
return nil
}
@@ -207,7 +209,7 @@ func (p *pollDB) GetPollVoteByID(ctx context.Context, id string) (*gtsmodel.Poll
func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID string) (*gtsmodel.PollVote, error) {
return p.getPollVote(
ctx,
- "PollID.AccountID",
+ "PollID,AccountID",
func(vote *gtsmodel.PollVote) error {
return p.db.NewSelect().
Model(vote).
@@ -222,7 +224,7 @@ func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID str
func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.PollVote) error, keyParts ...any) (*gtsmodel.PollVote, error) {
// Fetch vote from database cache with loader callback
- vote, err := p.state.Caches.GTS.PollVote().Load(lookup, func() (*gtsmodel.PollVote, error) {
+ vote, err := p.state.Caches.GTS.PollVote.LoadOne(lookup, func() (*gtsmodel.PollVote, error) {
var vote gtsmodel.PollVote
// Not cached! Perform database query.
@@ -250,7 +252,9 @@ func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*g
}
func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.PollVote, error) {
- voteIDs, err := p.state.Caches.GTS.PollVoteIDs().Load(pollID, func() ([]string, error) {
+
+ // Load vote IDs known for given poll ID using loader callback.
+ voteIDs, err := p.state.Caches.GTS.PollVoteIDs.Load(pollID, func() ([]string, error) {
var voteIDs []string
// Vote IDs not in cache, perform DB query!
@@ -266,21 +270,62 @@ func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.P
return nil, err
}
- // Preallocate slice of expected length.
- votes := make([]*gtsmodel.PollVote, 0, len(voteIDs))
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(voteIDs))
- for _, id := range voteIDs {
- // Fetch poll vote model for this ID.
- vote, err := p.GetPollVoteByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting poll vote %s: %v", id, err)
- continue
- }
+ // Load all votes from IDs via cache loader callbacks.
+ votes, err := p.state.Caches.GTS.PollVote.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range voteIDs {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
+
+ // Uncached poll vote loader function.
+ func() ([]*gtsmodel.PollVote, error) {
+ // Preallocate expected length of uncached votes.
+ votes := make([]*gtsmodel.PollVote, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := p.db.NewSelect().
+ Model(&votes).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return votes, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Reorder the poll votes by their
+ // IDs to ensure in correct order.
+ getID := func(v *gtsmodel.PollVote) string { return v.ID }
+ util.OrderBy(votes, voteIDs, getID)
- // Append to return slice.
- votes = append(votes, vote)
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return votes, nil
}
+ // Populate all loaded votes, removing those we fail to
+ // populate (removes needing so many nil checks everywhere).
+ votes = slices.DeleteFunc(votes, func(vote *gtsmodel.PollVote) bool {
+ if err := p.PopulatePollVote(ctx, vote); err != nil {
+ log.Errorf(ctx, "error populating vote %s: %v", vote.ID, err)
+ return true
+ }
+ return false
+ })
+
return votes, nil
}
@@ -316,7 +361,7 @@ func (p *pollDB) PopulatePollVote(ctx context.Context, vote *gtsmodel.PollVote)
}
func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error {
- return p.state.Caches.GTS.PollVote().Store(vote, func() error {
+ return p.state.Caches.GTS.PollVote.Store(vote, func() error {
return p.db.RunInTx(ctx, func(tx Tx) error {
// Try insert vote into database.
if _, err := tx.NewInsert().
@@ -416,9 +461,9 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
}
// Invalidate poll vote and poll entry from caches.
- p.state.Caches.GTS.Poll().Invalidate("ID", pollID)
- p.state.Caches.GTS.PollVote().Invalidate("PollID", pollID)
- p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID)
+ p.state.Caches.GTS.Poll.Invalidate("ID", pollID)
+ p.state.Caches.GTS.PollVote.Invalidate("PollID", pollID)
+ p.state.Caches.GTS.PollVoteIDs.Invalidate(pollID)
return nil
}
@@ -428,7 +473,7 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID
// Slice should only ever be of length
// 0 or 1; it's a slice of slices only
// because we can't LIMIT deletes to 1.
- var choicesSl [][]int
+ var choicesSlice [][]int
// Delete vote in poll by account,
// returning the ID + choices of the vote.
@@ -437,17 +482,19 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID
Where("? = ?", bun.Ident("poll_id"), pollID).
Where("? = ?", bun.Ident("account_id"), accountID).
Returning("?", bun.Ident("choices")).
- Scan(ctx, &choicesSl); err != nil {
+ Scan(ctx, &choicesSlice); err != nil {
// irrecoverable.
return err
}
- if len(choicesSl) != 1 {
+ if len(choicesSlice) != 1 {
// No poll votes by this
// acct on this poll.
return nil
}
- choices := choicesSl[0]
+
+ // Extract the *actual* choices.
+ choices := choicesSlice[0]
// Select current poll counts from DB,
// taking minimal columns needed to
@@ -489,9 +536,9 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID
}
// Invalidate poll vote and poll entry from caches.
- p.state.Caches.GTS.Poll().Invalidate("ID", pollID)
- p.state.Caches.GTS.PollVote().Invalidate("PollID.AccountID", pollID, accountID)
- p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID)
+ p.state.Caches.GTS.Poll.Invalidate("ID", pollID)
+ p.state.Caches.GTS.PollVote.Invalidate("PollID,AccountID", pollID, accountID)
+ p.state.Caches.GTS.PollVoteIDs.Invalidate(pollID)
return nil
}
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go
index 138a5aa17..4c50862a1 100644
--- a/internal/db/bundb/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -194,7 +194,7 @@ func (r *relationshipDB) CountAccountBlocks(ctx context.Context, accountID strin
}
func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
- return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), ">"+accountID, page, func() ([]string, error) {
+ return loadPagedIDs(r.state.Caches.GTS.FollowIDs, ">"+accountID, page, func() ([]string, error) {
var followIDs []string
// Follow IDs not in cache, perform DB query!
@@ -209,7 +209,7 @@ func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID stri
}
func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) {
- return r.state.Caches.GTS.FollowIDs().Load("l>"+accountID, func() ([]string, error) {
+ return r.state.Caches.GTS.FollowIDs.Load("l>"+accountID, func() ([]string, error) {
var followIDs []string
// Follow IDs not in cache, perform DB query!
@@ -224,7 +224,7 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID
}
func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
- return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), "<"+accountID, page, func() ([]string, error) {
+ return loadPagedIDs(r.state.Caches.GTS.FollowIDs, "<"+accountID, page, func() ([]string, error) {
var followIDs []string
// Follow IDs not in cache, perform DB query!
@@ -239,7 +239,7 @@ func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID st
}
func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) {
- return r.state.Caches.GTS.FollowIDs().Load("l<"+accountID, func() ([]string, error) {
+ return r.state.Caches.GTS.FollowIDs.Load("l<"+accountID, func() ([]string, error) {
var followIDs []string
// Follow IDs not in cache, perform DB query!
@@ -254,7 +254,7 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account
}
func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
- return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), ">"+accountID, page, func() ([]string, error) {
+ return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs, ">"+accountID, page, func() ([]string, error) {
var followReqIDs []string
// Follow request IDs not in cache, perform DB query!
@@ -269,7 +269,7 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account
}
func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
- return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), "<"+accountID, page, func() ([]string, error) {
+ return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs, "<"+accountID, page, func() ([]string, error) {
var followReqIDs []string
// Follow request IDs not in cache, perform DB query!
@@ -284,7 +284,7 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco
}
func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
- return loadPagedIDs(r.state.Caches.GTS.BlockIDs(), accountID, page, func() ([]string, error) {
+ return loadPagedIDs(r.state.Caches.GTS.BlockIDs, accountID, page, func() ([]string, error) {
var blockIDs []string
// Block IDs not in cache, perform DB query!
diff --git a/internal/db/bundb/relationship_block.go b/internal/db/bundb/relationship_block.go
index efaa6d1a9..178de6aa7 100644
--- a/internal/db/bundb/relationship_block.go
+++ b/internal/db/bundb/relationship_block.go
@@ -20,12 +20,14 @@ package bundb
import (
"context"
"errors"
+ "slices"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@@ -86,7 +88,7 @@ func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmod
func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Block, error) {
return r.getBlock(
ctx,
- "AccountID.TargetAccountID",
+ "AccountID,TargetAccountID",
func(block *gtsmodel.Block) error {
return r.db.NewSelect().Model(block).
Where("? = ?", bun.Ident("block.account_id"), sourceAccountID).
@@ -99,27 +101,68 @@ func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, t
}
func (r *relationshipDB) GetBlocksByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Block, error) {
- // Preallocate slice of expected length.
- blocks := make([]*gtsmodel.Block, 0, len(ids))
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(ids))
+
+ // Load all blocks IDs via cache loader callbacks.
+ blocks, err := r.state.Caches.GTS.Block.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range ids {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
- for _, id := range ids {
- // Fetch block model for this ID.
- block, err := r.GetBlockByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting block %q: %v", id, err)
- continue
- }
+ // Uncached block loader function.
+ func() ([]*gtsmodel.Block, error) {
+ // Preallocate expected length of uncached blocks.
+ blocks := make([]*gtsmodel.Block, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := r.db.NewSelect().
+ Model(&blocks).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return blocks, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Reorder the blocks by their
+ // IDs to ensure in correct order.
+ getID := func(b *gtsmodel.Block) string { return b.ID }
+ util.OrderBy(blocks, ids, getID)
- // Append to return slice.
- blocks = append(blocks, block)
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return blocks, nil
}
+ // Populate all loaded blocks, removing those we fail to
+ // populate (removes needing so many nil checks everywhere).
+ blocks = slices.DeleteFunc(blocks, func(block *gtsmodel.Block) bool {
+ if err := r.PopulateBlock(ctx, block); err != nil {
+ log.Errorf(ctx, "error populating block %s: %v", block.ID, err)
+ return true
+ }
+ return false
+ })
+
return blocks, nil
}
func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) {
// Fetch block from cache with loader callback
- block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) {
+ block, err := r.state.Caches.GTS.Block.LoadOne(lookup, func() (*gtsmodel.Block, error) {
var block gtsmodel.Block
// Not cached! Perform database query
@@ -148,8 +191,8 @@ func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery fu
func (r *relationshipDB) PopulateBlock(ctx context.Context, block *gtsmodel.Block) error {
var (
+ errs gtserror.MultiError
err error
- errs = gtserror.NewMultiError(2)
)
if block.Account == nil {
@@ -178,7 +221,7 @@ func (r *relationshipDB) PopulateBlock(ctx context.Context, block *gtsmodel.Bloc
}
func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error {
- return r.state.Caches.GTS.Block().Store(block, func() error {
+ return r.state.Caches.GTS.Block.Store(block, func() error {
_, err := r.db.NewInsert().Model(block).Exec(ctx)
return err
})
@@ -198,7 +241,7 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {
}
// Drop this now-cached block on return after delete.
- defer r.state.Caches.GTS.Block().Invalidate("ID", id)
+ defer r.state.Caches.GTS.Block.Invalidate("ID", id)
// Finally delete block from DB.
_, err = r.db.NewDelete().
@@ -222,7 +265,7 @@ func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error
}
// Drop this now-cached block on return after delete.
- defer r.state.Caches.GTS.Block().Invalidate("URI", uri)
+ defer r.state.Caches.GTS.Block.Invalidate("URI", uri)
// Finally delete block from DB.
_, err = r.db.NewDelete().
@@ -251,22 +294,20 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri
defer func() {
// Invalidate all account's incoming / outoing blocks on return.
- r.state.Caches.GTS.Block().Invalidate("AccountID", accountID)
- r.state.Caches.GTS.Block().Invalidate("TargetAccountID", accountID)
+ r.state.Caches.GTS.Block.Invalidate("AccountID", accountID)
+ r.state.Caches.GTS.Block.Invalidate("TargetAccountID", accountID)
}()
// Load all blocks into cache, this *really* isn't great
// but it is the only way we can ensure we invalidate all
// related caches correctly (e.g. visibility).
- for _, id := range blockIDs {
- _, err := r.GetBlockByID(ctx, id)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return err
- }
+ _, err := r.GetAccountBlocks(ctx, accountID, nil)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return err
}
// Finally delete all from DB.
- _, err := r.db.NewDelete().
+ _, err = r.db.NewDelete().
Table("blocks").
Where("? IN (?)", bun.Ident("id"), bun.In(blockIDs)).
Exec(ctx)
diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go
index 6c5a75e4c..93ee69bd7 100644
--- a/internal/db/bundb/relationship_follow.go
+++ b/internal/db/bundb/relationship_follow.go
@@ -21,6 +21,7 @@ import (
"context"
"errors"
"fmt"
+ "slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
@@ -28,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@@ -62,7 +64,7 @@ func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmo
func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
return r.getFollow(
ctx,
- "AccountID.TargetAccountID",
+ "AccountID,TargetAccountID",
func(follow *gtsmodel.Follow) error {
return r.db.NewSelect().
Model(follow).
@@ -76,21 +78,62 @@ func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string,
}
func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Follow, error) {
- // Preallocate slice of expected length.
- follows := make([]*gtsmodel.Follow, 0, len(ids))
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(ids))
+
+ // Load all follow IDs via cache loader callbacks.
+ follows, err := r.state.Caches.GTS.Follow.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range ids {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
- for _, id := range ids {
- // Fetch follow model for this ID.
- follow, err := r.GetFollowByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting follow %q: %v", id, err)
- continue
- }
+ // Uncached follow loader function.
+ func() ([]*gtsmodel.Follow, error) {
+ // Preallocate expected length of uncached follows.
+ follows := make([]*gtsmodel.Follow, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := r.db.NewSelect().
+ Model(&follows).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return follows, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Reorder the follows by their
+ // IDs to ensure in correct order.
+ getID := func(f *gtsmodel.Follow) string { return f.ID }
+ util.OrderBy(follows, ids, getID)
- // Append to return slice.
- follows = append(follows, follow)
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return follows, nil
}
+ // Populate all loaded follows, removing those we fail to
+ // populate (removes needing so many nil checks everywhere).
+ follows = slices.DeleteFunc(follows, func(follow *gtsmodel.Follow) bool {
+ if err := r.PopulateFollow(ctx, follow); err != nil {
+ log.Errorf(ctx, "error populating follow %s: %v", follow.ID, err)
+ return true
+ }
+ return false
+ })
+
return follows, nil
}
@@ -130,7 +173,7 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 strin
func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Follow) error, keyParts ...any) (*gtsmodel.Follow, error) {
// Fetch follow from database cache with loader callback
- follow, err := r.state.Caches.GTS.Follow().Load(lookup, func() (*gtsmodel.Follow, error) {
+ follow, err := r.state.Caches.GTS.Follow.LoadOne(lookup, func() (*gtsmodel.Follow, error) {
var follow gtsmodel.Follow
// Not cached! Perform database query
@@ -189,7 +232,7 @@ func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Fo
}
func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error {
- return r.state.Caches.GTS.Follow().Store(follow, func() error {
+ return r.state.Caches.GTS.Follow.Store(follow, func() error {
_, err := r.db.NewInsert().Model(follow).Exec(ctx)
return err
})
@@ -202,7 +245,7 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll
columns = append(columns, "updated_at")
}
- return r.state.Caches.GTS.Follow().Store(follow, func() error {
+ return r.state.Caches.GTS.Follow.Store(follow, func() error {
if _, err := r.db.NewUpdate().
Model(follow).
Where("? = ?", bun.Ident("follow.id"), follow.ID).
@@ -250,7 +293,7 @@ func (r *relationshipDB) DeleteFollow(ctx context.Context, sourceAccountID strin
}
// Drop this now-cached follow on return after delete.
- defer r.state.Caches.GTS.Follow().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID)
+ defer r.state.Caches.GTS.Follow.Invalidate("AccountID,TargetAccountID", sourceAccountID, targetAccountID)
// Finally delete follow from DB.
return r.deleteFollow(ctx, follow.ID)
@@ -270,7 +313,7 @@ func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error
}
// Drop this now-cached follow on return after delete.
- defer r.state.Caches.GTS.Follow().Invalidate("ID", id)
+ defer r.state.Caches.GTS.Follow.Invalidate("ID", id)
// Finally delete follow from DB.
return r.deleteFollow(ctx, follow.ID)
@@ -290,7 +333,7 @@ func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) erro
}
// Drop this now-cached follow on return after delete.
- defer r.state.Caches.GTS.Follow().Invalidate("URI", uri)
+ defer r.state.Caches.GTS.Follow.Invalidate("URI", uri)
// Finally delete follow from DB.
return r.deleteFollow(ctx, follow.ID)
@@ -316,22 +359,30 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str
defer func() {
// Invalidate all account's incoming / outoing follows on return.
- r.state.Caches.GTS.Follow().Invalidate("AccountID", accountID)
- r.state.Caches.GTS.Follow().Invalidate("TargetAccountID", accountID)
+ r.state.Caches.GTS.Follow.Invalidate("AccountID", accountID)
+ r.state.Caches.GTS.Follow.Invalidate("TargetAccountID", accountID)
}()
// Load all follows into cache, this *really* isn't great
// but it is the only way we can ensure we invalidate all
// related caches correctly (e.g. visibility).
- for _, id := range followIDs {
- follow, err := r.GetFollowByID(ctx, id)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return err
- }
+ _, err := r.GetAccountFollows(ctx, accountID, nil)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return err
+ }
- // Delete each follow from DB.
- if err := r.deleteFollow(ctx, follow.ID); err != nil &&
- !errors.Is(err, db.ErrNoEntries) {
+ // Delete all follows from DB.
+ _, err = r.db.NewDelete().
+ Table("follows").
+ Where("? IN (?)", bun.Ident("id"), bun.In(followIDs)).
+ Exec(ctx)
+ if err != nil {
+ return err
+ }
+
+ for _, id := range followIDs {
+ // Finally, delete all list entries associated with each follow ID.
+ if err := r.state.DB.DeleteListEntriesForFollowID(ctx, id); err != nil {
return err
}
}
diff --git a/internal/db/bundb/relationship_follow_req.go b/internal/db/bundb/relationship_follow_req.go
index 51aceafe1..690b97cf0 100644
--- a/internal/db/bundb/relationship_follow_req.go
+++ b/internal/db/bundb/relationship_follow_req.go
@@ -20,6 +20,7 @@ package bundb
import (
"context"
"errors"
+ "slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
@@ -27,6 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@@ -61,7 +63,7 @@ func (r *relationshipDB) GetFollowRequestByURI(ctx context.Context, uri string)
func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) {
return r.getFollowRequest(
ctx,
- "AccountID.TargetAccountID",
+ "AccountID,TargetAccountID",
func(followReq *gtsmodel.FollowRequest) error {
return r.db.NewSelect().
Model(followReq).
@@ -75,22 +77,63 @@ func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID s
}
func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.FollowRequest, error) {
- // Preallocate slice of expected length.
- followReqs := make([]*gtsmodel.FollowRequest, 0, len(ids))
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(ids))
+
+ // Load all follow IDs via cache loader callbacks.
+ follows, err := r.state.Caches.GTS.FollowRequest.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range ids {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
- for _, id := range ids {
- // Fetch follow request model for this ID.
- followReq, err := r.GetFollowRequestByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting follow request %q: %v", id, err)
- continue
- }
+ // Uncached follow req loader function.
+ func() ([]*gtsmodel.FollowRequest, error) {
+ // Preallocate expected length of uncached followReqs.
+ follows := make([]*gtsmodel.FollowRequest, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := r.db.NewSelect().
+ Model(&follows).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return follows, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Reorder the requests by their
+ // IDs to ensure in correct order.
+ getID := func(f *gtsmodel.FollowRequest) string { return f.ID }
+ util.OrderBy(follows, ids, getID)
- // Append to return slice.
- followReqs = append(followReqs, followReq)
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return follows, nil
}
- return followReqs, nil
+ // Populate all loaded followreqs, removing those we fail to
+ // populate (removes needing so many nil checks everywhere).
+ follows = slices.DeleteFunc(follows, func(follow *gtsmodel.FollowRequest) bool {
+ if err := r.PopulateFollowRequest(ctx, follow); err != nil {
+ log.Errorf(ctx, "error populating follow request %s: %v", follow.ID, err)
+ return true
+ }
+ return false
+ })
+
+ return follows, nil
}
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) {
@@ -107,7 +150,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID
func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, dbQuery func(*gtsmodel.FollowRequest) error, keyParts ...any) (*gtsmodel.FollowRequest, error) {
// Fetch follow request from database cache with loader callback
- followReq, err := r.state.Caches.GTS.FollowRequest().Load(lookup, func() (*gtsmodel.FollowRequest, error) {
+ followReq, err := r.state.Caches.GTS.FollowRequest.LoadOne(lookup, func() (*gtsmodel.FollowRequest, error) {
var followReq gtsmodel.FollowRequest
// Not cached! Perform database query
@@ -166,7 +209,7 @@ func (r *relationshipDB) PopulateFollowRequest(ctx context.Context, follow *gtsm
}
func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error {
- return r.state.Caches.GTS.FollowRequest().Store(follow, func() error {
+ return r.state.Caches.GTS.FollowRequest.Store(follow, func() error {
_, err := r.db.NewInsert().Model(follow).Exec(ctx)
return err
})
@@ -179,7 +222,7 @@ func (r *relationshipDB) UpdateFollowRequest(ctx context.Context, followRequest
columns = append(columns, "updated_at")
}
- return r.state.Caches.GTS.FollowRequest().Store(followRequest, func() error {
+ return r.state.Caches.GTS.FollowRequest.Store(followRequest, func() error {
if _, err := r.db.NewUpdate().
Model(followRequest).
Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
@@ -212,7 +255,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI
Notify: followReq.Notify,
}
- if err := r.state.Caches.GTS.Follow().Store(follow, func() error {
+ if err := r.state.Caches.GTS.Follow.Store(follow, func() error {
// If the follow already exists, just
// replace the URI with the new one.
_, err := r.db.
@@ -274,7 +317,7 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI
}
// Drop this now-cached follow request on return after delete.
- defer r.state.Caches.GTS.FollowRequest().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID)
+ defer r.state.Caches.GTS.FollowRequest.Invalidate("AccountID,TargetAccountID", sourceAccountID, targetAccountID)
// Finally delete followreq from DB.
_, err = r.db.NewDelete().
@@ -298,7 +341,7 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string)
}
// Drop this now-cached follow request on return after delete.
- defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", id)
+ defer r.state.Caches.GTS.FollowRequest.Invalidate("ID", id)
// Finally delete followreq from DB.
_, err = r.db.NewDelete().
@@ -322,7 +365,7 @@ func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri strin
}
// Drop this now-cached follow request on return after delete.
- defer r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri)
+ defer r.state.Caches.GTS.FollowRequest.Invalidate("URI", uri)
// Finally delete followreq from DB.
_, err = r.db.NewDelete().
@@ -352,22 +395,20 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun
defer func() {
// Invalidate all account's incoming / outoing follow requests on return.
- r.state.Caches.GTS.FollowRequest().Invalidate("AccountID", accountID)
- r.state.Caches.GTS.FollowRequest().Invalidate("TargetAccountID", accountID)
+ r.state.Caches.GTS.FollowRequest.Invalidate("AccountID", accountID)
+ r.state.Caches.GTS.FollowRequest.Invalidate("TargetAccountID", accountID)
}()
// Load all followreqs into cache, this *really* isn't
// great but it is the only way we can ensure we invalidate
// all related caches correctly (e.g. visibility).
- for _, id := range followReqIDs {
- _, err := r.GetFollowRequestByID(ctx, id)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return err
- }
+ _, err := r.GetAccountFollowRequests(ctx, accountID, nil)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return err
}
// Finally delete all from DB.
- _, err := r.db.NewDelete().
+ _, err = r.db.NewDelete().
Table("follow_requests").
Where("? IN (?)", bun.Ident("id"), bun.In(followReqIDs)).
Exec(ctx)
diff --git a/internal/db/bundb/relationship_note.go b/internal/db/bundb/relationship_note.go
index f7d15f8b7..126ea0cd1 100644
--- a/internal/db/bundb/relationship_note.go
+++ b/internal/db/bundb/relationship_note.go
@@ -30,7 +30,7 @@ import (
func (r *relationshipDB) GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error) {
return r.getNote(
ctx,
- "AccountID.TargetAccountID",
+ "AccountID,TargetAccountID",
func(note *gtsmodel.AccountNote) error {
return r.db.NewSelect().Model(note).
Where("? = ?", bun.Ident("account_id"), sourceAccountID).
@@ -44,7 +44,7 @@ func (r *relationshipDB) GetNote(ctx context.Context, sourceAccountID string, ta
func (r *relationshipDB) getNote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.AccountNote) error, keyParts ...any) (*gtsmodel.AccountNote, error) {
// Fetch note from cache with loader callback
- note, err := r.state.Caches.GTS.AccountNote().Load(lookup, func() (*gtsmodel.AccountNote, error) {
+ note, err := r.state.Caches.GTS.AccountNote.LoadOne(lookup, func() (*gtsmodel.AccountNote, error) {
var note gtsmodel.AccountNote
// Not cached! Perform database query
@@ -105,7 +105,7 @@ func (r *relationshipDB) PopulateNote(ctx context.Context, note *gtsmodel.Accoun
func (r *relationshipDB) PutNote(ctx context.Context, note *gtsmodel.AccountNote) error {
note.UpdatedAt = time.Now()
- return r.state.Caches.GTS.AccountNote().Store(note, func() error {
+ return r.state.Caches.GTS.AccountNote.Store(note, func() error {
_, err := r.db.
NewInsert().
Model(note).
diff --git a/internal/db/bundb/report.go b/internal/db/bundb/report.go
index 9e4ba5b29..5b0ae17f3 100644
--- a/internal/db/bundb/report.go
+++ b/internal/db/bundb/report.go
@@ -120,7 +120,7 @@ func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID str
func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Report) error, keyParts ...any) (*gtsmodel.Report, error) {
// Fetch report from database cache with loader callback
- report, err := r.state.Caches.GTS.Report().Load(lookup, func() (*gtsmodel.Report, error) {
+ report, err := r.state.Caches.GTS.Report.LoadOne(lookup, func() (*gtsmodel.Report, error) {
var report gtsmodel.Report
// Not cached! Perform database query
@@ -215,7 +215,7 @@ func (r *reportDB) PopulateReport(ctx context.Context, report *gtsmodel.Report)
}
func (r *reportDB) PutReport(ctx context.Context, report *gtsmodel.Report) error {
- return r.state.Caches.GTS.Report().Store(report, func() error {
+ return r.state.Caches.GTS.Report.Store(report, func() error {
_, err := r.db.NewInsert().Model(report).Exec(ctx)
return err
})
@@ -237,12 +237,12 @@ func (r *reportDB) UpdateReport(ctx context.Context, report *gtsmodel.Report, co
return nil, err
}
- r.state.Caches.GTS.Report().Invalidate("ID", report.ID)
+ r.state.Caches.GTS.Report.Invalidate("ID", report.ID)
return report, nil
}
func (r *reportDB) DeleteReportByID(ctx context.Context, id string) error {
- defer r.state.Caches.GTS.Report().Invalidate("ID", id)
+ defer r.state.Caches.GTS.Report.Invalidate("ID", id)
// Load status into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
diff --git a/internal/db/bundb/rule.go b/internal/db/bundb/rule.go
index 79825923b..ebfa89d15 100644
--- a/internal/db/bundb/rule.go
+++ b/internal/db/bundb/rule.go
@@ -125,7 +125,7 @@ func (r *ruleDB) PutRule(ctx context.Context, rule *gtsmodel.Rule) error {
}
// invalidate cached local instance response, so it gets updated with the new rules
- r.state.Caches.GTS.Instance().Invalidate("Domain", config.GetHost())
+ r.state.Caches.GTS.Instance.Invalidate("Domain", config.GetHost())
return nil
}
@@ -143,7 +143,7 @@ func (r *ruleDB) UpdateRule(ctx context.Context, rule *gtsmodel.Rule) (*gtsmodel
}
// invalidate cached local instance response, so it gets updated with the new rules
- r.state.Caches.GTS.Instance().Invalidate("Domain", config.GetHost())
+ r.state.Caches.GTS.Instance.Invalidate("Domain", config.GetHost())
return rule, nil
}
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
index da252c7f7..07a09050a 100644
--- a/internal/db/bundb/status.go
+++ b/internal/db/bundb/status.go
@@ -20,6 +20,7 @@ package bundb
import (
"context"
"errors"
+ "slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
@@ -28,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@@ -48,20 +50,62 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat
}
func (s *statusDB) GetStatusesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Status, error) {
- statuses := make([]*gtsmodel.Status, 0, len(ids))
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(ids))
- for _, id := range ids {
- // Attempt to fetch status from DB.
- status, err := s.GetStatusByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting status %q: %v", id, err)
- continue
- }
+ // Load all status IDs via cache loader callbacks.
+ statuses, err := s.state.Caches.GTS.Status.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range ids {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
+
+ // Uncached statuses loader function.
+ func() ([]*gtsmodel.Status, error) {
+ // Preallocate expected length of uncached statuses.
+ statuses := make([]*gtsmodel.Status, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) status IDs.
+ if err := s.db.NewSelect().
+ Model(&statuses).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return statuses, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Reorder the statuses by their
+ // IDs to ensure in correct order.
+ getID := func(s *gtsmodel.Status) string { return s.ID }
+ util.OrderBy(statuses, ids, getID)
- // Append status to return slice.
- statuses = append(statuses, status)
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return statuses, nil
}
+ // Populate all loaded statuses, removing those we fail to
+ // populate (removes needing so many nil checks everywhere).
+ statuses = slices.DeleteFunc(statuses, func(status *gtsmodel.Status) bool {
+ if err := s.PopulateStatus(ctx, status); err != nil {
+ log.Errorf(ctx, "error populating status %s: %v", status.ID, err)
+ return true
+ }
+ return false
+ })
+
return statuses, nil
}
@@ -101,7 +145,7 @@ func (s *statusDB) GetStatusByPollID(ctx context.Context, pollID string) (*gtsmo
func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccountID string) (*gtsmodel.Status, error) {
return s.getStatus(
ctx,
- "BoostOfID.AccountID",
+ "BoostOfID,AccountID",
func(status *gtsmodel.Status) error {
return s.db.NewSelect().Model(status).
Where("status.boost_of_id = ?", boostOfID).
@@ -120,7 +164,7 @@ func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccou
func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, error) {
// Fetch status from database cache with loader callback
- status, err := s.state.Caches.GTS.Status().Load(lookup, func() (*gtsmodel.Status, error) {
+ status, err := s.state.Caches.GTS.Status.LoadOne(lookup, func() (*gtsmodel.Status, error) {
var status gtsmodel.Status
// Not cached! Perform database query.
@@ -282,7 +326,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
}
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error {
- return s.state.Caches.GTS.Status().Store(status, func() error {
+ return s.state.Caches.GTS.Status.Store(status, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
@@ -366,7 +410,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
columns = append(columns, "updated_at")
}
- return s.state.Caches.GTS.Status().Store(status, func() error {
+ return s.state.Caches.GTS.Status.Store(status, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
@@ -463,7 +507,7 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error {
}
// On return ensure status invalidated from cache.
- defer s.state.Caches.GTS.Status().Invalidate("ID", id)
+ defer s.state.Caches.GTS.Status.Invalidate("ID", id)
return s.db.RunInTx(ctx, func(tx Tx) error {
// delete links between this status and any emojis it uses
@@ -585,7 +629,7 @@ func (s *statusDB) CountStatusReplies(ctx context.Context, statusID string) (int
}
func (s *statusDB) getStatusReplyIDs(ctx context.Context, statusID string) ([]string, error) {
- return s.state.Caches.GTS.InReplyToIDs().Load(statusID, func() ([]string, error) {
+ return s.state.Caches.GTS.InReplyToIDs.Load(statusID, func() ([]string, error) {
var statusIDs []string
// Status reply IDs not in cache, perform DB query!
@@ -629,7 +673,7 @@ func (s *statusDB) CountStatusBoosts(ctx context.Context, statusID string) (int,
}
func (s *statusDB) getStatusBoostIDs(ctx context.Context, statusID string) ([]string, error) {
- return s.state.Caches.GTS.BoostOfIDs().Load(statusID, func() ([]string, error) {
+ return s.state.Caches.GTS.BoostOfIDs.Load(statusID, func() ([]string, error) {
var statusIDs []string
// Status boost IDs not in cache, perform DB query!
diff --git a/internal/db/bundb/statusfave.go b/internal/db/bundb/statusfave.go
index 73ac62fe7..e0f018b68 100644
--- a/internal/db/bundb/statusfave.go
+++ b/internal/db/bundb/statusfave.go
@@ -22,6 +22,7 @@ import (
"database/sql"
"errors"
"fmt"
+ "slices"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
@@ -29,6 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@@ -40,7 +42,7 @@ type statusFaveDB struct {
func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, error) {
return s.getStatusFave(
ctx,
- "AccountID.StatusID",
+ "AccountID,StatusID",
func(fave *gtsmodel.StatusFave) error {
return s.db.
NewSelect().
@@ -77,7 +79,7 @@ func (s *statusFaveDB) GetStatusFaveByID(ctx context.Context, id string) (*gtsmo
func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery func(*gtsmodel.StatusFave) error, keyParts ...any) (*gtsmodel.StatusFave, error) {
// Fetch status fave from database cache with loader callback
- fave, err := s.state.Caches.GTS.StatusFave().Load(lookup, func() (*gtsmodel.StatusFave, error) {
+ fave, err := s.state.Caches.GTS.StatusFave.LoadOne(lookup, func() (*gtsmodel.StatusFave, error) {
var fave gtsmodel.StatusFave
// Not cached! Perform database query.
@@ -111,19 +113,62 @@ func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*
return nil, err
}
- // Preallocate a slice of expected status fave capacity.
- faves := make([]*gtsmodel.StatusFave, 0, len(faveIDs))
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(faveIDs))
- for _, id := range faveIDs {
- // Fetch status fave model for each ID.
- fave, err := s.GetStatusFaveByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting status fave %q: %v", id, err)
- continue
- }
- faves = append(faves, fave)
+ // Load all fave IDs via cache loader callbacks.
+ faves, err := s.state.Caches.GTS.StatusFave.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range faveIDs {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
+
+ // Uncached status faves loader function.
+ func() ([]*gtsmodel.StatusFave, error) {
+ // Preallocate expected length of uncached faves.
+ faves := make([]*gtsmodel.StatusFave, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) fave IDs.
+ if err := s.db.NewSelect().
+ Model(&faves).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return faves, nil
+ },
+ )
+ if err != nil {
+ return nil, err
}
+ // Reorder the statuses by their
+ // IDs to ensure in correct order.
+ getID := func(f *gtsmodel.StatusFave) string { return f.ID }
+ util.OrderBy(faves, faveIDs, getID)
+
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return faves, nil
+ }
+
+ // Populate all loaded faves, removing those we fail to
+ // populate (removes needing so many nil checks everywhere).
+ faves = slices.DeleteFunc(faves, func(fave *gtsmodel.StatusFave) bool {
+ if err := s.PopulateStatusFave(ctx, fave); err != nil {
+ log.Errorf(ctx, "error populating fave %s: %v", fave.ID, err)
+ return true
+ }
+ return false
+ })
+
return faves, nil
}
@@ -141,7 +186,7 @@ func (s *statusFaveDB) CountStatusFaves(ctx context.Context, statusID string) (i
}
func (s *statusFaveDB) getStatusFaveIDs(ctx context.Context, statusID string) ([]string, error) {
- return s.state.Caches.GTS.StatusFaveIDs().Load(statusID, func() ([]string, error) {
+ return s.state.Caches.GTS.StatusFaveIDs.Load(statusID, func() ([]string, error) {
var faveIDs []string
// Status fave IDs not in cache, perform DB query!
@@ -201,7 +246,7 @@ func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmo
}
func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) error {
- return s.state.Caches.GTS.StatusFave().Store(fave, func() error {
+ return s.state.Caches.GTS.StatusFave.Store(fave, func() error {
_, err := s.db.
NewInsert().
Model(fave).
@@ -230,10 +275,10 @@ func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) erro
if statusID != "" {
// Invalidate any cached status faves for this status.
- s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
+ s.state.Caches.GTS.StatusFave.Invalidate("ID", id)
// Invalidate any cached status fave IDs for this status.
- s.state.Caches.GTS.StatusFaveIDs().Invalidate(statusID)
+ s.state.Caches.GTS.StatusFaveIDs.Invalidate(statusID)
}
return nil
@@ -270,17 +315,15 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st
return err
}
- // Collate (deduplicating) status IDs.
- statusIDs = collate(func(i int) string {
- return statusIDs[i]
- }, len(statusIDs))
+ // Deduplicate determined status IDs.
+ statusIDs = util.Deduplicate(statusIDs)
for _, id := range statusIDs {
// Invalidate any cached status faves for this status.
- s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
+ s.state.Caches.GTS.StatusFave.Invalidate("ID", id)
// Invalidate any cached status fave IDs for this status.
- s.state.Caches.GTS.StatusFaveIDs().Invalidate(id)
+ s.state.Caches.GTS.StatusFaveIDs.Invalidate(id)
}
return nil
@@ -296,10 +339,10 @@ func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID
}
// Invalidate any cached status faves for this status.
- s.state.Caches.GTS.StatusFave().Invalidate("ID", statusID)
+ s.state.Caches.GTS.StatusFave.Invalidate("ID", statusID)
// Invalidate any cached status fave IDs for this status.
- s.state.Caches.GTS.StatusFaveIDs().Invalidate(statusID)
+ s.state.Caches.GTS.StatusFaveIDs.Invalidate(statusID)
return nil
}
diff --git a/internal/db/bundb/tag.go b/internal/db/bundb/tag.go
index fac621f0a..66ee8cb3a 100644
--- a/internal/db/bundb/tag.go
+++ b/internal/db/bundb/tag.go
@@ -22,21 +22,21 @@ import (
"strings"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
type tagDB struct {
- conn *DB
+ db *DB
state *state.State
}
-func (m *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) {
- return m.state.Caches.GTS.Tag().Load("ID", func() (*gtsmodel.Tag, error) {
+func (t *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) {
+ return t.state.Caches.GTS.Tag.LoadOne("ID", func() (*gtsmodel.Tag, error) {
var tag gtsmodel.Tag
- q := m.conn.
+ q := t.db.
NewSelect().
Model(&tag).
Where("? = ?", bun.Ident("tag.id"), id)
@@ -49,15 +49,15 @@ func (m *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) {
}, id)
}
-func (m *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, error) {
+func (t *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, error) {
// Normalize 'name' string.
name = strings.TrimSpace(name)
name = strings.ToLower(name)
- return m.state.Caches.GTS.Tag().Load("Name", func() (*gtsmodel.Tag, error) {
+ return t.state.Caches.GTS.Tag.LoadOne("Name", func() (*gtsmodel.Tag, error) {
var tag gtsmodel.Tag
- q := m.conn.
+ q := t.db.
NewSelect().
Model(&tag).
Where("? = ?", bun.Ident("tag.name"), name)
@@ -70,25 +70,52 @@ func (m *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, e
}, name)
}
-func (m *tagDB) GetTags(ctx context.Context, ids []string) ([]*gtsmodel.Tag, error) {
- tags := make([]*gtsmodel.Tag, 0, len(ids))
-
- for _, id := range ids {
- // Attempt fetch from DB
- tag, err := m.GetTag(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting tag %q: %v", id, err)
- continue
- }
-
- // Append tag
- tags = append(tags, tag)
+func (t *tagDB) GetTags(ctx context.Context, ids []string) ([]*gtsmodel.Tag, error) {
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(ids))
+
+ // Load all tag IDs via cache loader callbacks.
+ tags, err := t.state.Caches.GTS.Tag.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range ids {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
+
+ // Uncached tag loader function.
+ func() ([]*gtsmodel.Tag, error) {
+ // Preallocate expected length of uncached tags.
+ tags := make([]*gtsmodel.Tag, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := t.db.NewSelect().
+ Model(&tags).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return tags, nil
+ },
+ )
+ if err != nil {
+ return nil, err
}
+ // Reorder the tags by their
+ // IDs to ensure in correct order.
+ getID := func(t *gtsmodel.Tag) string { return t.ID }
+ util.OrderBy(tags, ids, getID)
+
return tags, nil
}
-func (m *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error {
+func (t *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error {
// Normalize 'name' string before it enters
// the db, without changing tag we were given.
//
@@ -101,8 +128,8 @@ func (m *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error {
t2.Name = strings.ToLower(t2.Name)
// Insert the copy.
- if err := m.state.Caches.GTS.Tag().Store(t2, func() error {
- _, err := m.conn.NewInsert().Model(t2).Exec(ctx)
+ if err := t.state.Caches.GTS.Tag.Store(t2, func() error {
+ _, err := t.db.NewInsert().Model(t2).Exec(ctx)
return err
}); err != nil {
return err // err already processed
diff --git a/internal/db/bundb/thread.go b/internal/db/bundb/thread.go
index e6d6154d4..34c5f783a 100644
--- a/internal/db/bundb/thread.go
+++ b/internal/db/bundb/thread.go
@@ -42,7 +42,7 @@ func (t *threadDB) PutThread(ctx context.Context, thread *gtsmodel.Thread) error
}
func (t *threadDB) GetThreadMute(ctx context.Context, id string) (*gtsmodel.ThreadMute, error) {
- return t.state.Caches.GTS.ThreadMute().Load("ID", func() (*gtsmodel.ThreadMute, error) {
+ return t.state.Caches.GTS.ThreadMute.LoadOne("ID", func() (*gtsmodel.ThreadMute, error) {
var threadMute gtsmodel.ThreadMute
q := t.db.
@@ -63,7 +63,7 @@ func (t *threadDB) GetThreadMutedByAccount(
threadID string,
accountID string,
) (*gtsmodel.ThreadMute, error) {
- return t.state.Caches.GTS.ThreadMute().Load("ThreadID.AccountID", func() (*gtsmodel.ThreadMute, error) {
+ return t.state.Caches.GTS.ThreadMute.LoadOne("ThreadID,AccountID", func() (*gtsmodel.ThreadMute, error) {
var threadMute gtsmodel.ThreadMute
q := t.db.
@@ -98,7 +98,7 @@ func (t *threadDB) IsThreadMutedByAccount(
}
func (t *threadDB) PutThreadMute(ctx context.Context, threadMute *gtsmodel.ThreadMute) error {
- return t.state.Caches.GTS.ThreadMute().Store(threadMute, func() error {
+ return t.state.Caches.GTS.ThreadMute.Store(threadMute, func() error {
_, err := t.db.NewInsert().Model(threadMute).Exec(ctx)
return err
})
@@ -112,6 +112,6 @@ func (t *threadDB) DeleteThreadMute(ctx context.Context, id string) error {
return err
}
- t.state.Caches.GTS.ThreadMute().Invalidate("ID", id)
+ t.state.Caches.GTS.ThreadMute.Invalidate("ID", id)
return nil
}
diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go
index 4af17fb7f..f2ba2a9d1 100644
--- a/internal/db/bundb/timeline.go
+++ b/internal/db/bundb/timeline.go
@@ -29,7 +29,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
- "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/uptrace/bun"
)
@@ -155,20 +154,8 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
}
}
- statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
- for _, id := range statusIDs {
- // Fetch status from db for ID
- status, err := t.state.DB.GetStatusByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error fetching status %q: %v", id, err)
- continue
- }
-
- // Append status to slice
- statuses = append(statuses, status)
- }
-
- return statuses, nil
+ // Return status IDs loaded from cache + db.
+ return t.state.DB.GetStatusesByIDs(ctx, statusIDs)
}
func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
@@ -256,20 +243,8 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI
}
}
- statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
- for _, id := range statusIDs {
- // Fetch status from db for ID
- status, err := t.state.DB.GetStatusByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error fetching status %q: %v", id, err)
- continue
- }
-
- // Append status to slice
- statuses = append(statuses, status)
- }
-
- return statuses, nil
+ // Return status IDs loaded from cache + db.
+ return t.state.DB.GetStatusesByIDs(ctx, statusIDs)
}
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
@@ -323,18 +298,15 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max
}
})
- statuses := make([]*gtsmodel.Status, 0, len(faves))
-
- for _, fave := range faves {
- // Fetch status from db for corresponding favourite
- status, err := t.state.DB.GetStatusByID(ctx, fave.StatusID)
- if err != nil {
- log.Errorf(ctx, "error fetching status for fave %q: %v", fave.ID, err)
- continue
- }
+ // Convert fave IDs to status IDs.
+ statusIDs := make([]string, len(faves))
+ for i, fave := range faves {
+ statusIDs[i] = fave.StatusID
+ }
- // Append status to slice
- statuses = append(statuses, status)
+ statuses, err := t.state.DB.GetStatusesByIDs(ctx, statusIDs)
+ if err != nil {
+ return nil, "", "", err
}
nextMaxID := faves[len(faves)-1].ID
@@ -453,20 +425,8 @@ func (t *timelineDB) GetListTimeline(
}
}
- statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
- for _, id := range statusIDs {
- // Fetch status from db for ID
- status, err := t.state.DB.GetStatusByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error fetching status %q: %v", id, err)
- continue
- }
-
- // Append status to slice
- statuses = append(statuses, status)
- }
-
- return statuses, nil
+ // Return status IDs loaded from cache + db.
+ return t.state.DB.GetStatusesByIDs(ctx, statusIDs)
}
func (t *timelineDB) GetTagTimeline(
@@ -561,18 +521,6 @@ func (t *timelineDB) GetTagTimeline(
}
}
- statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
- for _, id := range statusIDs {
- // Fetch status from db for ID
- status, err := t.state.DB.GetStatusByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error fetching status %q: %v", id, err)
- continue
- }
-
- // Append status to slice
- statuses = append(statuses, status)
- }
-
- return statuses, nil
+ // Return status IDs loaded from cache + db.
+ return t.state.DB.GetStatusesByIDs(ctx, statusIDs)
}
diff --git a/internal/db/bundb/tombstone.go b/internal/db/bundb/tombstone.go
index f9882d1c6..c0e439720 100644
--- a/internal/db/bundb/tombstone.go
+++ b/internal/db/bundb/tombstone.go
@@ -32,7 +32,7 @@ type tombstoneDB struct {
}
func (t *tombstoneDB) GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, error) {
- return t.state.Caches.GTS.Tombstone().Load("URI", func() (*gtsmodel.Tombstone, error) {
+ return t.state.Caches.GTS.Tombstone.LoadOne("URI", func() (*gtsmodel.Tombstone, error) {
var tomb gtsmodel.Tombstone
q := t.db.
@@ -57,7 +57,7 @@ func (t *tombstoneDB) TombstoneExistsWithURI(ctx context.Context, uri string) (b
}
func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) error {
- return t.state.Caches.GTS.Tombstone().Store(tombstone, func() error {
+ return t.state.Caches.GTS.Tombstone.Store(tombstone, func() error {
_, err := t.db.
NewInsert().
Model(tombstone).
@@ -67,7 +67,7 @@ func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tomb
}
func (t *tombstoneDB) DeleteTombstone(ctx context.Context, id string) error {
- defer t.state.Caches.GTS.Tombstone().Invalidate("ID", id)
+ defer t.state.Caches.GTS.Tombstone.Invalidate("ID", id)
// Delete tombstone from DB.
_, err := t.db.NewDelete().
diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go
index 46b3c568f..a6fa142f2 100644
--- a/internal/db/bundb/user.go
+++ b/internal/db/bundb/user.go
@@ -116,7 +116,7 @@ func (u *userDB) GetUserByConfirmationToken(ctx context.Context, token string) (
func (u *userDB) getUser(ctx context.Context, lookup string, dbQuery func(*gtsmodel.User) error, keyParts ...any) (*gtsmodel.User, error) {
// Fetch user from database cache with loader callback.
- user, err := u.state.Caches.GTS.User().Load(lookup, func() (*gtsmodel.User, error) {
+ user, err := u.state.Caches.GTS.User.LoadOne(lookup, func() (*gtsmodel.User, error) {
var user gtsmodel.User
// Not cached! perform database query.
@@ -179,7 +179,7 @@ func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) {
}
func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) error {
- return u.state.Caches.GTS.User().Store(user, func() error {
+ return u.state.Caches.GTS.User.Store(user, func() error {
_, err := u.db.
NewInsert().
Model(user).
@@ -197,7 +197,7 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ..
columns = append(columns, "updated_at")
}
- return u.state.Caches.GTS.User().Store(user, func() error {
+ return u.state.Caches.GTS.User.Store(user, func() error {
_, err := u.db.
NewUpdate().
Model(user).
@@ -209,7 +209,7 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ..
}
func (u *userDB) DeleteUserByID(ctx context.Context, userID string) error {
- defer u.state.Caches.GTS.User().Invalidate("ID", userID)
+ defer u.state.Caches.GTS.User.Invalidate("ID", userID)
// Load user into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
diff --git a/internal/db/list.go b/internal/db/list.go
index 91a540486..16a0207de 100644
--- a/internal/db/list.go
+++ b/internal/db/list.go
@@ -27,6 +27,9 @@ type List interface {
// GetListByID gets one list with the given id.
GetListByID(ctx context.Context, id string) (*gtsmodel.List, error)
+ // GetListsByIDs fetches all lists with the provided IDs.
+ GetListsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.List, error)
+
// GetListsForAccountID gets all lists owned by the given accountID.
GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error)
@@ -46,6 +49,9 @@ type List interface {
// GetListEntryByID gets one list entry with the given ID.
GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error)
+ // GetListEntriesyIDs fetches all list entries with the provided IDs.
+ GetListEntriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.ListEntry, error)
+
// GetListEntries gets list entries from the given listID, using the given parameters.
GetListEntries(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.ListEntry, error)
diff --git a/internal/db/notification.go b/internal/db/notification.go
index ab8b5cc6d..9ff459b9c 100644
--- a/internal/db/notification.go
+++ b/internal/db/notification.go
@@ -33,6 +33,9 @@ type Notification interface {
// GetNotification returns one notification according to its id.
GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error)
+ // GetNotificationsByIDs returns a slice of notifications of the the provided IDs.
+ GetNotificationsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Notification, error)
+
// GetNotification gets one notification according to the provided parameters, if it exists.
// Since not all notifications are about a status, statusID can be an empty string.
GetNotification(ctx context.Context, notificationType gtsmodel.NotificationType, targetAccountID string, originAccountID string, statusID string) (*gtsmodel.Notification, error)