summaryrefslogtreecommitdiff
path: root/internal/db/bundb/relationship_follow.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/relationship_follow.go')
-rw-r--r--internal/db/bundb/relationship_follow.go107
1 files changed, 79 insertions, 28 deletions
diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go
index 6c5a75e4c..93ee69bd7 100644
--- a/internal/db/bundb/relationship_follow.go
+++ b/internal/db/bundb/relationship_follow.go
@@ -21,6 +21,7 @@ import (
"context"
"errors"
"fmt"
+ "slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
@@ -28,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@@ -62,7 +64,7 @@ func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmo
func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
return r.getFollow(
ctx,
- "AccountID.TargetAccountID",
+ "AccountID,TargetAccountID",
func(follow *gtsmodel.Follow) error {
return r.db.NewSelect().
Model(follow).
@@ -76,21 +78,62 @@ func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string,
}
func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Follow, error) {
- // Preallocate slice of expected length.
- follows := make([]*gtsmodel.Follow, 0, len(ids))
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(ids))
+
+ // Load all follow IDs via cache loader callbacks.
+ follows, err := r.state.Caches.GTS.Follow.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range ids {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
- for _, id := range ids {
- // Fetch follow model for this ID.
- follow, err := r.GetFollowByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting follow %q: %v", id, err)
- continue
- }
+ // Uncached follow loader function.
+ func() ([]*gtsmodel.Follow, error) {
+ // Preallocate expected length of uncached follows.
+ follows := make([]*gtsmodel.Follow, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := r.db.NewSelect().
+ Model(&follows).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return follows, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Reorder the follows by their
+ // IDs to ensure in correct order.
+ getID := func(f *gtsmodel.Follow) string { return f.ID }
+ util.OrderBy(follows, ids, getID)
- // Append to return slice.
- follows = append(follows, follow)
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return follows, nil
}
+ // Populate all loaded follows, removing those we fail to
+ // populate (removes needing so many nil checks everywhere).
+ follows = slices.DeleteFunc(follows, func(follow *gtsmodel.Follow) bool {
+ if err := r.PopulateFollow(ctx, follow); err != nil {
+ log.Errorf(ctx, "error populating follow %s: %v", follow.ID, err)
+ return true
+ }
+ return false
+ })
+
return follows, nil
}
@@ -130,7 +173,7 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 strin
func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Follow) error, keyParts ...any) (*gtsmodel.Follow, error) {
// Fetch follow from database cache with loader callback
- follow, err := r.state.Caches.GTS.Follow().Load(lookup, func() (*gtsmodel.Follow, error) {
+ follow, err := r.state.Caches.GTS.Follow.LoadOne(lookup, func() (*gtsmodel.Follow, error) {
var follow gtsmodel.Follow
// Not cached! Perform database query
@@ -189,7 +232,7 @@ func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Fo
}
func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error {
- return r.state.Caches.GTS.Follow().Store(follow, func() error {
+ return r.state.Caches.GTS.Follow.Store(follow, func() error {
_, err := r.db.NewInsert().Model(follow).Exec(ctx)
return err
})
@@ -202,7 +245,7 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll
columns = append(columns, "updated_at")
}
- return r.state.Caches.GTS.Follow().Store(follow, func() error {
+ return r.state.Caches.GTS.Follow.Store(follow, func() error {
if _, err := r.db.NewUpdate().
Model(follow).
Where("? = ?", bun.Ident("follow.id"), follow.ID).
@@ -250,7 +293,7 @@ func (r *relationshipDB) DeleteFollow(ctx context.Context, sourceAccountID strin
}
// Drop this now-cached follow on return after delete.
- defer r.state.Caches.GTS.Follow().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID)
+ defer r.state.Caches.GTS.Follow.Invalidate("AccountID,TargetAccountID", sourceAccountID, targetAccountID)
// Finally delete follow from DB.
return r.deleteFollow(ctx, follow.ID)
@@ -270,7 +313,7 @@ func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error
}
// Drop this now-cached follow on return after delete.
- defer r.state.Caches.GTS.Follow().Invalidate("ID", id)
+ defer r.state.Caches.GTS.Follow.Invalidate("ID", id)
// Finally delete follow from DB.
return r.deleteFollow(ctx, follow.ID)
@@ -290,7 +333,7 @@ func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) erro
}
// Drop this now-cached follow on return after delete.
- defer r.state.Caches.GTS.Follow().Invalidate("URI", uri)
+ defer r.state.Caches.GTS.Follow.Invalidate("URI", uri)
// Finally delete follow from DB.
return r.deleteFollow(ctx, follow.ID)
@@ -316,22 +359,30 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str
defer func() {
// Invalidate all account's incoming / outoing follows on return.
- r.state.Caches.GTS.Follow().Invalidate("AccountID", accountID)
- r.state.Caches.GTS.Follow().Invalidate("TargetAccountID", accountID)
+ r.state.Caches.GTS.Follow.Invalidate("AccountID", accountID)
+ r.state.Caches.GTS.Follow.Invalidate("TargetAccountID", accountID)
}()
// Load all follows into cache, this *really* isn't great
// but it is the only way we can ensure we invalidate all
// related caches correctly (e.g. visibility).
- for _, id := range followIDs {
- follow, err := r.GetFollowByID(ctx, id)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return err
- }
+ _, err := r.GetAccountFollows(ctx, accountID, nil)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return err
+ }
- // Delete each follow from DB.
- if err := r.deleteFollow(ctx, follow.ID); err != nil &&
- !errors.Is(err, db.ErrNoEntries) {
+ // Delete all follows from DB.
+ _, err = r.db.NewDelete().
+ Table("follows").
+ Where("? IN (?)", bun.Ident("id"), bun.In(followIDs)).
+ Exec(ctx)
+ if err != nil {
+ return err
+ }
+
+ for _, id := range followIDs {
+ // Finally, delete all list entries associated with each follow ID.
+ if err := r.state.DB.DeleteListEntriesForFollowID(ctx, id); err != nil {
return err
}
}