diff options
Diffstat (limited to 'internal/db/bundb/interaction.go')
-rw-r--r-- | internal/db/bundb/interaction.go | 81 |
1 files changed, 62 insertions, 19 deletions
diff --git a/internal/db/bundb/interaction.go b/internal/db/bundb/interaction.go index 78abcc763..88a044b6f 100644 --- a/internal/db/bundb/interaction.go +++ b/internal/db/bundb/interaction.go @@ -26,8 +26,10 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -84,6 +86,53 @@ func (i *interactionDB) GetInteractionRequestByURI(ctx context.Context, uri stri ) } +func (i *interactionDB) GetInteractionRequestsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.InteractionRequest, error) { + // Load all interaction request IDs via cache loader callbacks. + requests, err := i.state.Caches.DB.InteractionRequest.LoadIDs("ID", + ids, + func(uncached []string) ([]*gtsmodel.InteractionRequest, error) { + // Preallocate expected length of uncached interaction requests. + requests := make([]*gtsmodel.InteractionRequest, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := i.db.NewSelect(). + Model(&requests). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return requests, nil + }, + ) + if err != nil { + return nil, err + } + + // Reorder the requests by their + // IDs to ensure in correct order. + getID := func(r *gtsmodel.InteractionRequest) string { return r.ID } + util.OrderBy(requests, ids, getID) + + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return requests, nil + } + + // Populate all loaded interaction requests, removing those we + // fail to populate (removes needing so many nil checks everywhere). + requests = slices.DeleteFunc(requests, func(request *gtsmodel.InteractionRequest) bool { + if err := i.PopulateInteractionRequest(ctx, request); err != nil { + log.Errorf(ctx, "error populating %s: %v", request.ID, err) + return true + } + return false + }) + + return requests, nil +} + func (i *interactionDB) getInteractionRequest( ctx context.Context, lookup string, @@ -205,13 +254,18 @@ func (i *interactionDB) UpdateInteractionRequest(ctx context.Context, request *g } func (i *interactionDB) DeleteInteractionRequestByID(ctx context.Context, id string) error { - defer i.state.Caches.DB.InteractionRequest.Invalidate("ID", id) + // Delete interaction request by ID. + if _, err := i.db.NewDelete(). + Table("interaction_requests"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return err + } + + // Invalidate cached interaction request with ID. + i.state.Caches.DB.InteractionRequest.Invalidate("ID", id) - _, err := i.db.NewDelete(). - TableExpr("? AS ?", bun.Ident("interaction_requests"), bun.Ident("interaction_request")). - Where("? = ?", bun.Ident("interaction_request.id"), id). - Exec(ctx) - return err + return nil } func (i *interactionDB) GetInteractionsRequestsForAcct( @@ -317,19 +371,8 @@ func (i *interactionDB) GetInteractionsRequestsForAcct( slices.Reverse(reqIDs) } - // For each interaction request ID, - // select the interaction request. - reqs := make([]*gtsmodel.InteractionRequest, 0, len(reqIDs)) - for _, id := range reqIDs { - req, err := i.GetInteractionRequestByID(ctx, id) - if err != nil { - return nil, err - } - - reqs = append(reqs, req) - } - - return reqs, nil + // Load all interaction requests by their IDs. + return i.GetInteractionRequestsByIDs(ctx, reqIDs) } func (i *interactionDB) IsInteractionRejected(ctx context.Context, interactionURI string) (bool, error) { |