diff options
Diffstat (limited to 'internal/db/bundb/relationship_follow.go')
-rw-r--r-- | internal/db/bundb/relationship_follow.go | 107 |
1 files changed, 79 insertions, 28 deletions
diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go index 6c5a75e4c..93ee69bd7 100644 --- a/internal/db/bundb/relationship_follow.go +++ b/internal/db/bundb/relationship_follow.go @@ -21,6 +21,7 @@ import ( "context" "errors" "fmt" + "slices" "time" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -28,6 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -62,7 +64,7 @@ func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmo func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) { return r.getFollow( ctx, - "AccountID.TargetAccountID", + "AccountID,TargetAccountID", func(follow *gtsmodel.Follow) error { return r.db.NewSelect(). Model(follow). @@ -76,21 +78,62 @@ func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, } func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Follow, error) { - // Preallocate slice of expected length. - follows := make([]*gtsmodel.Follow, 0, len(ids)) + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) + + // Load all follow IDs via cache loader callbacks. + follows, err := r.state.Caches.GTS.Follow.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, - for _, id := range ids { - // Fetch follow model for this ID. - follow, err := r.GetFollowByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting follow %q: %v", id, err) - continue - } + // Uncached follow loader function. + func() ([]*gtsmodel.Follow, error) { + // Preallocate expected length of uncached follows. + follows := make([]*gtsmodel.Follow, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := r.db.NewSelect(). + Model(&follows). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return follows, nil + }, + ) + if err != nil { + return nil, err + } + + // Reorder the follows by their + // IDs to ensure in correct order. + getID := func(f *gtsmodel.Follow) string { return f.ID } + util.OrderBy(follows, ids, getID) - // Append to return slice. - follows = append(follows, follow) + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return follows, nil } + // Populate all loaded follows, removing those we fail to + // populate (removes needing so many nil checks everywhere). + follows = slices.DeleteFunc(follows, func(follow *gtsmodel.Follow) bool { + if err := r.PopulateFollow(ctx, follow); err != nil { + log.Errorf(ctx, "error populating follow %s: %v", follow.ID, err) + return true + } + return false + }) + return follows, nil } @@ -130,7 +173,7 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 strin func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Follow) error, keyParts ...any) (*gtsmodel.Follow, error) { // Fetch follow from database cache with loader callback - follow, err := r.state.Caches.GTS.Follow().Load(lookup, func() (*gtsmodel.Follow, error) { + follow, err := r.state.Caches.GTS.Follow.LoadOne(lookup, func() (*gtsmodel.Follow, error) { var follow gtsmodel.Follow // Not cached! Perform database query @@ -189,7 +232,7 @@ func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Fo } func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error { - return r.state.Caches.GTS.Follow().Store(follow, func() error { + return r.state.Caches.GTS.Follow.Store(follow, func() error { _, err := r.db.NewInsert().Model(follow).Exec(ctx) return err }) @@ -202,7 +245,7 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll columns = append(columns, "updated_at") } - return r.state.Caches.GTS.Follow().Store(follow, func() error { + return r.state.Caches.GTS.Follow.Store(follow, func() error { if _, err := r.db.NewUpdate(). Model(follow). Where("? = ?", bun.Ident("follow.id"), follow.ID). @@ -250,7 +293,7 @@ func (r *relationshipDB) DeleteFollow(ctx context.Context, sourceAccountID strin } // Drop this now-cached follow on return after delete. - defer r.state.Caches.GTS.Follow().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) + defer r.state.Caches.GTS.Follow.Invalidate("AccountID,TargetAccountID", sourceAccountID, targetAccountID) // Finally delete follow from DB. return r.deleteFollow(ctx, follow.ID) @@ -270,7 +313,7 @@ func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error } // Drop this now-cached follow on return after delete. - defer r.state.Caches.GTS.Follow().Invalidate("ID", id) + defer r.state.Caches.GTS.Follow.Invalidate("ID", id) // Finally delete follow from DB. return r.deleteFollow(ctx, follow.ID) @@ -290,7 +333,7 @@ func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) erro } // Drop this now-cached follow on return after delete. - defer r.state.Caches.GTS.Follow().Invalidate("URI", uri) + defer r.state.Caches.GTS.Follow.Invalidate("URI", uri) // Finally delete follow from DB. return r.deleteFollow(ctx, follow.ID) @@ -316,22 +359,30 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str defer func() { // Invalidate all account's incoming / outoing follows on return. - r.state.Caches.GTS.Follow().Invalidate("AccountID", accountID) - r.state.Caches.GTS.Follow().Invalidate("TargetAccountID", accountID) + r.state.Caches.GTS.Follow.Invalidate("AccountID", accountID) + r.state.Caches.GTS.Follow.Invalidate("TargetAccountID", accountID) }() // Load all follows into cache, this *really* isn't great // but it is the only way we can ensure we invalidate all // related caches correctly (e.g. visibility). - for _, id := range followIDs { - follow, err := r.GetFollowByID(ctx, id) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return err - } + _, err := r.GetAccountFollows(ctx, accountID, nil) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return err + } - // Delete each follow from DB. - if err := r.deleteFollow(ctx, follow.ID); err != nil && - !errors.Is(err, db.ErrNoEntries) { + // Delete all follows from DB. + _, err = r.db.NewDelete(). + Table("follows"). + Where("? IN (?)", bun.Ident("id"), bun.In(followIDs)). + Exec(ctx) + if err != nil { + return err + } + + for _, id := range followIDs { + // Finally, delete all list entries associated with each follow ID. + if err := r.state.DB.DeleteListEntriesForFollowID(ctx, id); err != nil { return err } } |