summaryrefslogtreecommitdiff
path: root/internal/db/bundb/poll.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/poll.go')
-rw-r--r--internal/db/bundb/poll.go107
1 files changed, 77 insertions, 30 deletions
diff --git a/internal/db/bundb/poll.go b/internal/db/bundb/poll.go
index 3e77fb6c5..0dfb15621 100644
--- a/internal/db/bundb/poll.go
+++ b/internal/db/bundb/poll.go
@@ -20,6 +20,7 @@ package bundb
import (
"context"
"errors"
+ "slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
@@ -28,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@@ -52,7 +54,7 @@ func (p *pollDB) GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, er
func (p *pollDB) getPoll(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Poll) error, keyParts ...any) (*gtsmodel.Poll, error) {
// Fetch poll from database cache with loader callback
- poll, err := p.state.Caches.GTS.Poll().Load(lookup, func() (*gtsmodel.Poll, error) {
+ poll, err := p.state.Caches.GTS.Poll.LoadOne(lookup, func() (*gtsmodel.Poll, error) {
var poll gtsmodel.Poll
// Not cached! Perform database query.
@@ -140,7 +142,7 @@ func (p *pollDB) PutPoll(ctx context.Context, poll *gtsmodel.Poll) error {
// is non nil and set.
poll.CheckVotes()
- return p.state.Caches.GTS.Poll().Store(poll, func() error {
+ return p.state.Caches.GTS.Poll.Store(poll, func() error {
_, err := p.db.NewInsert().Model(poll).Exec(ctx)
return err
})
@@ -151,7 +153,7 @@ func (p *pollDB) UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...st
// is non nil and set.
poll.CheckVotes()
- return p.state.Caches.GTS.Poll().Store(poll, func() error {
+ return p.state.Caches.GTS.Poll.Store(poll, func() error {
return p.db.RunInTx(ctx, func(tx Tx) error {
// Update the status' "updated_at" field.
if _, err := tx.NewUpdate().
@@ -184,8 +186,8 @@ func (p *pollDB) DeletePollByID(ctx context.Context, id string) error {
}
// Invalidate poll by ID from cache.
- p.state.Caches.GTS.Poll().Invalidate("ID", id)
- p.state.Caches.GTS.PollVoteIDs().Invalidate(id)
+ p.state.Caches.GTS.Poll.Invalidate("ID", id)
+ p.state.Caches.GTS.PollVoteIDs.Invalidate(id)
return nil
}
@@ -207,7 +209,7 @@ func (p *pollDB) GetPollVoteByID(ctx context.Context, id string) (*gtsmodel.Poll
func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID string) (*gtsmodel.PollVote, error) {
return p.getPollVote(
ctx,
- "PollID.AccountID",
+ "PollID,AccountID",
func(vote *gtsmodel.PollVote) error {
return p.db.NewSelect().
Model(vote).
@@ -222,7 +224,7 @@ func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID str
func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.PollVote) error, keyParts ...any) (*gtsmodel.PollVote, error) {
// Fetch vote from database cache with loader callback
- vote, err := p.state.Caches.GTS.PollVote().Load(lookup, func() (*gtsmodel.PollVote, error) {
+ vote, err := p.state.Caches.GTS.PollVote.LoadOne(lookup, func() (*gtsmodel.PollVote, error) {
var vote gtsmodel.PollVote
// Not cached! Perform database query.
@@ -250,7 +252,9 @@ func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*g
}
func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.PollVote, error) {
- voteIDs, err := p.state.Caches.GTS.PollVoteIDs().Load(pollID, func() ([]string, error) {
+
+ // Load vote IDs known for given poll ID using loader callback.
+ voteIDs, err := p.state.Caches.GTS.PollVoteIDs.Load(pollID, func() ([]string, error) {
var voteIDs []string
// Vote IDs not in cache, perform DB query!
@@ -266,21 +270,62 @@ func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.P
return nil, err
}
- // Preallocate slice of expected length.
- votes := make([]*gtsmodel.PollVote, 0, len(voteIDs))
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(voteIDs))
- for _, id := range voteIDs {
- // Fetch poll vote model for this ID.
- vote, err := p.GetPollVoteByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting poll vote %s: %v", id, err)
- continue
- }
+ // Load all votes from IDs via cache loader callbacks.
+ votes, err := p.state.Caches.GTS.PollVote.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range voteIDs {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
+
+ // Uncached poll vote loader function.
+ func() ([]*gtsmodel.PollVote, error) {
+ // Preallocate expected length of uncached votes.
+ votes := make([]*gtsmodel.PollVote, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := p.db.NewSelect().
+ Model(&votes).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return votes, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Reorder the poll votes by their
+ // IDs to ensure in correct order.
+ getID := func(v *gtsmodel.PollVote) string { return v.ID }
+ util.OrderBy(votes, voteIDs, getID)
- // Append to return slice.
- votes = append(votes, vote)
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return votes, nil
}
+ // Populate all loaded votes, removing those we fail to
+ // populate (removes needing so many nil checks everywhere).
+ votes = slices.DeleteFunc(votes, func(vote *gtsmodel.PollVote) bool {
+ if err := p.PopulatePollVote(ctx, vote); err != nil {
+ log.Errorf(ctx, "error populating vote %s: %v", vote.ID, err)
+ return true
+ }
+ return false
+ })
+
return votes, nil
}
@@ -316,7 +361,7 @@ func (p *pollDB) PopulatePollVote(ctx context.Context, vote *gtsmodel.PollVote)
}
func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error {
- return p.state.Caches.GTS.PollVote().Store(vote, func() error {
+ return p.state.Caches.GTS.PollVote.Store(vote, func() error {
return p.db.RunInTx(ctx, func(tx Tx) error {
// Try insert vote into database.
if _, err := tx.NewInsert().
@@ -416,9 +461,9 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
}
// Invalidate poll vote and poll entry from caches.
- p.state.Caches.GTS.Poll().Invalidate("ID", pollID)
- p.state.Caches.GTS.PollVote().Invalidate("PollID", pollID)
- p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID)
+ p.state.Caches.GTS.Poll.Invalidate("ID", pollID)
+ p.state.Caches.GTS.PollVote.Invalidate("PollID", pollID)
+ p.state.Caches.GTS.PollVoteIDs.Invalidate(pollID)
return nil
}
@@ -428,7 +473,7 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID
// Slice should only ever be of length
// 0 or 1; it's a slice of slices only
// because we can't LIMIT deletes to 1.
- var choicesSl [][]int
+ var choicesSlice [][]int
// Delete vote in poll by account,
// returning the ID + choices of the vote.
@@ -437,17 +482,19 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID
Where("? = ?", bun.Ident("poll_id"), pollID).
Where("? = ?", bun.Ident("account_id"), accountID).
Returning("?", bun.Ident("choices")).
- Scan(ctx, &choicesSl); err != nil {
+ Scan(ctx, &choicesSlice); err != nil {
// irrecoverable.
return err
}
- if len(choicesSl) != 1 {
+ if len(choicesSlice) != 1 {
// No poll votes by this
// acct on this poll.
return nil
}
- choices := choicesSl[0]
+
+ // Extract the *actual* choices.
+ choices := choicesSlice[0]
// Select current poll counts from DB,
// taking minimal columns needed to
@@ -489,9 +536,9 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID
}
// Invalidate poll vote and poll entry from caches.
- p.state.Caches.GTS.Poll().Invalidate("ID", pollID)
- p.state.Caches.GTS.PollVote().Invalidate("PollID.AccountID", pollID, accountID)
- p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID)
+ p.state.Caches.GTS.Poll.Invalidate("ID", pollID)
+ p.state.Caches.GTS.PollVote.Invalidate("PollID,AccountID", pollID, accountID)
+ p.state.Caches.GTS.PollVoteIDs.Invalidate(pollID)
return nil
}