diff options
Diffstat (limited to 'internal/db/bundb/relationship_follow_req.go')
-rw-r--r-- | internal/db/bundb/relationship_follow_req.go | 97 |
1 files changed, 69 insertions, 28 deletions
diff --git a/internal/db/bundb/relationship_follow_req.go b/internal/db/bundb/relationship_follow_req.go index 51aceafe1..690b97cf0 100644 --- a/internal/db/bundb/relationship_follow_req.go +++ b/internal/db/bundb/relationship_follow_req.go @@ -20,6 +20,7 @@ package bundb import ( "context" "errors" + "slices" "time" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -27,6 +28,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -61,7 +63,7 @@ func (r *relationshipDB) GetFollowRequestByURI(ctx context.Context, uri string) func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) { return r.getFollowRequest( ctx, - "AccountID.TargetAccountID", + "AccountID,TargetAccountID", func(followReq *gtsmodel.FollowRequest) error { return r.db.NewSelect(). Model(followReq). @@ -75,22 +77,63 @@ func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID s } func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.FollowRequest, error) { - // Preallocate slice of expected length. - followReqs := make([]*gtsmodel.FollowRequest, 0, len(ids)) + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) + + // Load all follow IDs via cache loader callbacks. + follows, err := r.state.Caches.GTS.FollowRequest.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, - for _, id := range ids { - // Fetch follow request model for this ID. - followReq, err := r.GetFollowRequestByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting follow request %q: %v", id, err) - continue - } + // Uncached follow req loader function. + func() ([]*gtsmodel.FollowRequest, error) { + // Preallocate expected length of uncached followReqs. + follows := make([]*gtsmodel.FollowRequest, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := r.db.NewSelect(). + Model(&follows). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return follows, nil + }, + ) + if err != nil { + return nil, err + } + + // Reorder the requests by their + // IDs to ensure in correct order. + getID := func(f *gtsmodel.FollowRequest) string { return f.ID } + util.OrderBy(follows, ids, getID) - // Append to return slice. - followReqs = append(followReqs, followReq) + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return follows, nil } - return followReqs, nil + // Populate all loaded followreqs, removing those we fail to + // populate (removes needing so many nil checks everywhere). + follows = slices.DeleteFunc(follows, func(follow *gtsmodel.FollowRequest) bool { + if err := r.PopulateFollowRequest(ctx, follow); err != nil { + log.Errorf(ctx, "error populating follow request %s: %v", follow.ID, err) + return true + } + return false + }) + + return follows, nil } func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) { @@ -107,7 +150,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, dbQuery func(*gtsmodel.FollowRequest) error, keyParts ...any) (*gtsmodel.FollowRequest, error) { // Fetch follow request from database cache with loader callback - followReq, err := r.state.Caches.GTS.FollowRequest().Load(lookup, func() (*gtsmodel.FollowRequest, error) { + followReq, err := r.state.Caches.GTS.FollowRequest.LoadOne(lookup, func() (*gtsmodel.FollowRequest, error) { var followReq gtsmodel.FollowRequest // Not cached! Perform database query @@ -166,7 +209,7 @@ func (r *relationshipDB) PopulateFollowRequest(ctx context.Context, follow *gtsm } func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error { - return r.state.Caches.GTS.FollowRequest().Store(follow, func() error { + return r.state.Caches.GTS.FollowRequest.Store(follow, func() error { _, err := r.db.NewInsert().Model(follow).Exec(ctx) return err }) @@ -179,7 +222,7 @@ func (r *relationshipDB) UpdateFollowRequest(ctx context.Context, followRequest columns = append(columns, "updated_at") } - return r.state.Caches.GTS.FollowRequest().Store(followRequest, func() error { + return r.state.Caches.GTS.FollowRequest.Store(followRequest, func() error { if _, err := r.db.NewUpdate(). Model(followRequest). Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). @@ -212,7 +255,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI Notify: followReq.Notify, } - if err := r.state.Caches.GTS.Follow().Store(follow, func() error { + if err := r.state.Caches.GTS.Follow.Store(follow, func() error { // If the follow already exists, just // replace the URI with the new one. _, err := r.db. @@ -274,7 +317,7 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI } // Drop this now-cached follow request on return after delete. - defer r.state.Caches.GTS.FollowRequest().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID) + defer r.state.Caches.GTS.FollowRequest.Invalidate("AccountID,TargetAccountID", sourceAccountID, targetAccountID) // Finally delete followreq from DB. _, err = r.db.NewDelete(). @@ -298,7 +341,7 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) } // Drop this now-cached follow request on return after delete. - defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", id) + defer r.state.Caches.GTS.FollowRequest.Invalidate("ID", id) // Finally delete followreq from DB. _, err = r.db.NewDelete(). @@ -322,7 +365,7 @@ func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri strin } // Drop this now-cached follow request on return after delete. - defer r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri) + defer r.state.Caches.GTS.FollowRequest.Invalidate("URI", uri) // Finally delete followreq from DB. _, err = r.db.NewDelete(). @@ -352,22 +395,20 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun defer func() { // Invalidate all account's incoming / outoing follow requests on return. - r.state.Caches.GTS.FollowRequest().Invalidate("AccountID", accountID) - r.state.Caches.GTS.FollowRequest().Invalidate("TargetAccountID", accountID) + r.state.Caches.GTS.FollowRequest.Invalidate("AccountID", accountID) + r.state.Caches.GTS.FollowRequest.Invalidate("TargetAccountID", accountID) }() // Load all followreqs into cache, this *really* isn't // great but it is the only way we can ensure we invalidate // all related caches correctly (e.g. visibility). - for _, id := range followReqIDs { - _, err := r.GetFollowRequestByID(ctx, id) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return err - } + _, err := r.GetAccountFollowRequests(ctx, accountID, nil) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return err } // Finally delete all from DB. - _, err := r.db.NewDelete(). + _, err = r.db.NewDelete(). Table("follow_requests"). Where("? IN (?)", bun.Ident("id"), bun.In(followReqIDs)). Exec(ctx) |