diff options
Diffstat (limited to 'internal/db/bundb/relationship_block.go')
-rw-r--r-- | internal/db/bundb/relationship_block.go | 91 |
1 files changed, 66 insertions, 25 deletions
diff --git a/internal/db/bundb/relationship_block.go b/internal/db/bundb/relationship_block.go index efaa6d1a9..178de6aa7 100644 --- a/internal/db/bundb/relationship_block.go +++ b/internal/db/bundb/relationship_block.go @@ -20,12 +20,14 @@ package bundb import ( "context" "errors" + "slices" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -86,7 +88,7 @@ func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmod func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Block, error) { return r.getBlock( ctx, - "AccountID.TargetAccountID", + "AccountID,TargetAccountID", func(block *gtsmodel.Block) error { return r.db.NewSelect().Model(block). Where("? = ?", bun.Ident("block.account_id"), sourceAccountID). @@ -99,27 +101,68 @@ func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, t } func (r *relationshipDB) GetBlocksByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Block, error) { - // Preallocate slice of expected length. - blocks := make([]*gtsmodel.Block, 0, len(ids)) + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) + + // Load all blocks IDs via cache loader callbacks. + blocks, err := r.state.Caches.GTS.Block.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, - for _, id := range ids { - // Fetch block model for this ID. - block, err := r.GetBlockByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting block %q: %v", id, err) - continue - } + // Uncached block loader function. + func() ([]*gtsmodel.Block, error) { + // Preallocate expected length of uncached blocks. + blocks := make([]*gtsmodel.Block, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := r.db.NewSelect(). + Model(&blocks). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return blocks, nil + }, + ) + if err != nil { + return nil, err + } + + // Reorder the blocks by their + // IDs to ensure in correct order. + getID := func(b *gtsmodel.Block) string { return b.ID } + util.OrderBy(blocks, ids, getID) - // Append to return slice. - blocks = append(blocks, block) + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return blocks, nil } + // Populate all loaded blocks, removing those we fail to + // populate (removes needing so many nil checks everywhere). + blocks = slices.DeleteFunc(blocks, func(block *gtsmodel.Block) bool { + if err := r.PopulateBlock(ctx, block); err != nil { + log.Errorf(ctx, "error populating block %s: %v", block.ID, err) + return true + } + return false + }) + return blocks, nil } func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) { // Fetch block from cache with loader callback - block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) { + block, err := r.state.Caches.GTS.Block.LoadOne(lookup, func() (*gtsmodel.Block, error) { var block gtsmodel.Block // Not cached! Perform database query @@ -148,8 +191,8 @@ func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery fu func (r *relationshipDB) PopulateBlock(ctx context.Context, block *gtsmodel.Block) error { var ( + errs gtserror.MultiError err error - errs = gtserror.NewMultiError(2) ) if block.Account == nil { @@ -178,7 +221,7 @@ func (r *relationshipDB) PopulateBlock(ctx context.Context, block *gtsmodel.Bloc } func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error { - return r.state.Caches.GTS.Block().Store(block, func() error { + return r.state.Caches.GTS.Block.Store(block, func() error { _, err := r.db.NewInsert().Model(block).Exec(ctx) return err }) @@ -198,7 +241,7 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error { } // Drop this now-cached block on return after delete. - defer r.state.Caches.GTS.Block().Invalidate("ID", id) + defer r.state.Caches.GTS.Block.Invalidate("ID", id) // Finally delete block from DB. _, err = r.db.NewDelete(). @@ -222,7 +265,7 @@ func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error } // Drop this now-cached block on return after delete. - defer r.state.Caches.GTS.Block().Invalidate("URI", uri) + defer r.state.Caches.GTS.Block.Invalidate("URI", uri) // Finally delete block from DB. _, err = r.db.NewDelete(). @@ -251,22 +294,20 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri defer func() { // Invalidate all account's incoming / outoing blocks on return. - r.state.Caches.GTS.Block().Invalidate("AccountID", accountID) - r.state.Caches.GTS.Block().Invalidate("TargetAccountID", accountID) + r.state.Caches.GTS.Block.Invalidate("AccountID", accountID) + r.state.Caches.GTS.Block.Invalidate("TargetAccountID", accountID) }() // Load all blocks into cache, this *really* isn't great // but it is the only way we can ensure we invalidate all // related caches correctly (e.g. visibility). - for _, id := range blockIDs { - _, err := r.GetBlockByID(ctx, id) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return err - } + _, err := r.GetAccountBlocks(ctx, accountID, nil) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return err } // Finally delete all from DB. - _, err := r.db.NewDelete(). + _, err = r.db.NewDelete(). Table("blocks"). Where("? IN (?)", bun.Ident("id"), bun.In(blockIDs)). Exec(ctx) |