summaryrefslogtreecommitdiff
path: root/internal/db/bundb/relationship_block.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/relationship_block.go')
-rw-r--r--internal/db/bundb/relationship_block.go91
1 files changed, 66 insertions, 25 deletions
diff --git a/internal/db/bundb/relationship_block.go b/internal/db/bundb/relationship_block.go
index efaa6d1a9..178de6aa7 100644
--- a/internal/db/bundb/relationship_block.go
+++ b/internal/db/bundb/relationship_block.go
@@ -20,12 +20,14 @@ package bundb
import (
"context"
"errors"
+ "slices"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@@ -86,7 +88,7 @@ func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmod
func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Block, error) {
return r.getBlock(
ctx,
- "AccountID.TargetAccountID",
+ "AccountID,TargetAccountID",
func(block *gtsmodel.Block) error {
return r.db.NewSelect().Model(block).
Where("? = ?", bun.Ident("block.account_id"), sourceAccountID).
@@ -99,27 +101,68 @@ func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, t
}
func (r *relationshipDB) GetBlocksByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Block, error) {
- // Preallocate slice of expected length.
- blocks := make([]*gtsmodel.Block, 0, len(ids))
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(ids))
+
+ // Load all blocks IDs via cache loader callbacks.
+ blocks, err := r.state.Caches.GTS.Block.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range ids {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
- for _, id := range ids {
- // Fetch block model for this ID.
- block, err := r.GetBlockByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting block %q: %v", id, err)
- continue
- }
+ // Uncached block loader function.
+ func() ([]*gtsmodel.Block, error) {
+ // Preallocate expected length of uncached blocks.
+ blocks := make([]*gtsmodel.Block, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := r.db.NewSelect().
+ Model(&blocks).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return blocks, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Reorder the blocks by their
+ // IDs to ensure in correct order.
+ getID := func(b *gtsmodel.Block) string { return b.ID }
+ util.OrderBy(blocks, ids, getID)
- // Append to return slice.
- blocks = append(blocks, block)
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return blocks, nil
}
+ // Populate all loaded blocks, removing those we fail to
+ // populate (removes needing so many nil checks everywhere).
+ blocks = slices.DeleteFunc(blocks, func(block *gtsmodel.Block) bool {
+ if err := r.PopulateBlock(ctx, block); err != nil {
+ log.Errorf(ctx, "error populating block %s: %v", block.ID, err)
+ return true
+ }
+ return false
+ })
+
return blocks, nil
}
func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) {
// Fetch block from cache with loader callback
- block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) {
+ block, err := r.state.Caches.GTS.Block.LoadOne(lookup, func() (*gtsmodel.Block, error) {
var block gtsmodel.Block
// Not cached! Perform database query
@@ -148,8 +191,8 @@ func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery fu
func (r *relationshipDB) PopulateBlock(ctx context.Context, block *gtsmodel.Block) error {
var (
+ errs gtserror.MultiError
err error
- errs = gtserror.NewMultiError(2)
)
if block.Account == nil {
@@ -178,7 +221,7 @@ func (r *relationshipDB) PopulateBlock(ctx context.Context, block *gtsmodel.Bloc
}
func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error {
- return r.state.Caches.GTS.Block().Store(block, func() error {
+ return r.state.Caches.GTS.Block.Store(block, func() error {
_, err := r.db.NewInsert().Model(block).Exec(ctx)
return err
})
@@ -198,7 +241,7 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {
}
// Drop this now-cached block on return after delete.
- defer r.state.Caches.GTS.Block().Invalidate("ID", id)
+ defer r.state.Caches.GTS.Block.Invalidate("ID", id)
// Finally delete block from DB.
_, err = r.db.NewDelete().
@@ -222,7 +265,7 @@ func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error
}
// Drop this now-cached block on return after delete.
- defer r.state.Caches.GTS.Block().Invalidate("URI", uri)
+ defer r.state.Caches.GTS.Block.Invalidate("URI", uri)
// Finally delete block from DB.
_, err = r.db.NewDelete().
@@ -251,22 +294,20 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri
defer func() {
// Invalidate all account's incoming / outoing blocks on return.
- r.state.Caches.GTS.Block().Invalidate("AccountID", accountID)
- r.state.Caches.GTS.Block().Invalidate("TargetAccountID", accountID)
+ r.state.Caches.GTS.Block.Invalidate("AccountID", accountID)
+ r.state.Caches.GTS.Block.Invalidate("TargetAccountID", accountID)
}()
// Load all blocks into cache, this *really* isn't great
// but it is the only way we can ensure we invalidate all
// related caches correctly (e.g. visibility).
- for _, id := range blockIDs {
- _, err := r.GetBlockByID(ctx, id)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return err
- }
+ _, err := r.GetAccountBlocks(ctx, accountID, nil)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return err
}
// Finally delete all from DB.
- _, err := r.db.NewDelete().
+ _, err = r.db.NewDelete().
Table("blocks").
Where("? IN (?)", bun.Ident("id"), bun.In(blockIDs)).
Exec(ctx)