diff options
Diffstat (limited to 'internal/db/bundb/poll.go')
-rw-r--r-- | internal/db/bundb/poll.go | 107 |
1 files changed, 77 insertions, 30 deletions
diff --git a/internal/db/bundb/poll.go b/internal/db/bundb/poll.go index 3e77fb6c5..0dfb15621 100644 --- a/internal/db/bundb/poll.go +++ b/internal/db/bundb/poll.go @@ -20,6 +20,7 @@ package bundb import ( "context" "errors" + "slices" "time" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -28,6 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) @@ -52,7 +54,7 @@ func (p *pollDB) GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, er func (p *pollDB) getPoll(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Poll) error, keyParts ...any) (*gtsmodel.Poll, error) { // Fetch poll from database cache with loader callback - poll, err := p.state.Caches.GTS.Poll().Load(lookup, func() (*gtsmodel.Poll, error) { + poll, err := p.state.Caches.GTS.Poll.LoadOne(lookup, func() (*gtsmodel.Poll, error) { var poll gtsmodel.Poll // Not cached! Perform database query. @@ -140,7 +142,7 @@ func (p *pollDB) PutPoll(ctx context.Context, poll *gtsmodel.Poll) error { // is non nil and set. poll.CheckVotes() - return p.state.Caches.GTS.Poll().Store(poll, func() error { + return p.state.Caches.GTS.Poll.Store(poll, func() error { _, err := p.db.NewInsert().Model(poll).Exec(ctx) return err }) @@ -151,7 +153,7 @@ func (p *pollDB) UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...st // is non nil and set. poll.CheckVotes() - return p.state.Caches.GTS.Poll().Store(poll, func() error { + return p.state.Caches.GTS.Poll.Store(poll, func() error { return p.db.RunInTx(ctx, func(tx Tx) error { // Update the status' "updated_at" field. if _, err := tx.NewUpdate(). @@ -184,8 +186,8 @@ func (p *pollDB) DeletePollByID(ctx context.Context, id string) error { } // Invalidate poll by ID from cache. - p.state.Caches.GTS.Poll().Invalidate("ID", id) - p.state.Caches.GTS.PollVoteIDs().Invalidate(id) + p.state.Caches.GTS.Poll.Invalidate("ID", id) + p.state.Caches.GTS.PollVoteIDs.Invalidate(id) return nil } @@ -207,7 +209,7 @@ func (p *pollDB) GetPollVoteByID(ctx context.Context, id string) (*gtsmodel.Poll func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID string) (*gtsmodel.PollVote, error) { return p.getPollVote( ctx, - "PollID.AccountID", + "PollID,AccountID", func(vote *gtsmodel.PollVote) error { return p.db.NewSelect(). Model(vote). @@ -222,7 +224,7 @@ func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID str func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.PollVote) error, keyParts ...any) (*gtsmodel.PollVote, error) { // Fetch vote from database cache with loader callback - vote, err := p.state.Caches.GTS.PollVote().Load(lookup, func() (*gtsmodel.PollVote, error) { + vote, err := p.state.Caches.GTS.PollVote.LoadOne(lookup, func() (*gtsmodel.PollVote, error) { var vote gtsmodel.PollVote // Not cached! Perform database query. @@ -250,7 +252,9 @@ func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*g } func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.PollVote, error) { - voteIDs, err := p.state.Caches.GTS.PollVoteIDs().Load(pollID, func() ([]string, error) { + + // Load vote IDs known for given poll ID using loader callback. + voteIDs, err := p.state.Caches.GTS.PollVoteIDs.Load(pollID, func() ([]string, error) { var voteIDs []string // Vote IDs not in cache, perform DB query! @@ -266,21 +270,62 @@ func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.P return nil, err } - // Preallocate slice of expected length. - votes := make([]*gtsmodel.PollVote, 0, len(voteIDs)) + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(voteIDs)) - for _, id := range voteIDs { - // Fetch poll vote model for this ID. - vote, err := p.GetPollVoteByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting poll vote %s: %v", id, err) - continue - } + // Load all votes from IDs via cache loader callbacks. + votes, err := p.state.Caches.GTS.PollVote.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range voteIDs { + if !load(id) { + uncached = append(uncached, id) + } + } + }, + + // Uncached poll vote loader function. + func() ([]*gtsmodel.PollVote, error) { + // Preallocate expected length of uncached votes. + votes := make([]*gtsmodel.PollVote, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := p.db.NewSelect(). + Model(&votes). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return votes, nil + }, + ) + if err != nil { + return nil, err + } + + // Reorder the poll votes by their + // IDs to ensure in correct order. + getID := func(v *gtsmodel.PollVote) string { return v.ID } + util.OrderBy(votes, voteIDs, getID) - // Append to return slice. - votes = append(votes, vote) + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return votes, nil } + // Populate all loaded votes, removing those we fail to + // populate (removes needing so many nil checks everywhere). + votes = slices.DeleteFunc(votes, func(vote *gtsmodel.PollVote) bool { + if err := p.PopulatePollVote(ctx, vote); err != nil { + log.Errorf(ctx, "error populating vote %s: %v", vote.ID, err) + return true + } + return false + }) + return votes, nil } @@ -316,7 +361,7 @@ func (p *pollDB) PopulatePollVote(ctx context.Context, vote *gtsmodel.PollVote) } func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error { - return p.state.Caches.GTS.PollVote().Store(vote, func() error { + return p.state.Caches.GTS.PollVote.Store(vote, func() error { return p.db.RunInTx(ctx, func(tx Tx) error { // Try insert vote into database. if _, err := tx.NewInsert(). @@ -416,9 +461,9 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { } // Invalidate poll vote and poll entry from caches. - p.state.Caches.GTS.Poll().Invalidate("ID", pollID) - p.state.Caches.GTS.PollVote().Invalidate("PollID", pollID) - p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID) + p.state.Caches.GTS.Poll.Invalidate("ID", pollID) + p.state.Caches.GTS.PollVote.Invalidate("PollID", pollID) + p.state.Caches.GTS.PollVoteIDs.Invalidate(pollID) return nil } @@ -428,7 +473,7 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID // Slice should only ever be of length // 0 or 1; it's a slice of slices only // because we can't LIMIT deletes to 1. - var choicesSl [][]int + var choicesSlice [][]int // Delete vote in poll by account, // returning the ID + choices of the vote. @@ -437,17 +482,19 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID Where("? = ?", bun.Ident("poll_id"), pollID). Where("? = ?", bun.Ident("account_id"), accountID). Returning("?", bun.Ident("choices")). - Scan(ctx, &choicesSl); err != nil { + Scan(ctx, &choicesSlice); err != nil { // irrecoverable. return err } - if len(choicesSl) != 1 { + if len(choicesSlice) != 1 { // No poll votes by this // acct on this poll. return nil } - choices := choicesSl[0] + + // Extract the *actual* choices. + choices := choicesSlice[0] // Select current poll counts from DB, // taking minimal columns needed to @@ -489,9 +536,9 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID } // Invalidate poll vote and poll entry from caches. - p.state.Caches.GTS.Poll().Invalidate("ID", pollID) - p.state.Caches.GTS.PollVote().Invalidate("PollID.AccountID", pollID, accountID) - p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID) + p.state.Caches.GTS.Poll.Invalidate("ID", pollID) + p.state.Caches.GTS.PollVote.Invalidate("PollID,AccountID", pollID, accountID) + p.state.Caches.GTS.PollVoteIDs.Invalidate(pollID) return nil } |