summaryrefslogtreecommitdiff
path: root/internal/db/bundb/poll.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/poll.go')
-rw-r--r--internal/db/bundb/poll.go226
1 files changed, 88 insertions, 138 deletions
diff --git a/internal/db/bundb/poll.go b/internal/db/bundb/poll.go
index 5c1d9c6dd..f5c33ce9b 100644
--- a/internal/db/bundb/poll.go
+++ b/internal/db/bundb/poll.go
@@ -177,17 +177,36 @@ func (p *pollDB) UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...st
}
func (p *pollDB) DeletePollByID(ctx context.Context, id string) error {
- // Delete poll by ID from database.
- if _, err := p.db.NewDelete().
- Table("polls").
- Where("? = ?", bun.Ident("id"), id).
- Exec(ctx); err != nil {
+ // Delete poll vote with ID, and its associated votes from the database.
+ if err := p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+
+ // Delete poll from database.
+ if _, err := tx.NewDelete().
+ Table("polls").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Delete the poll votes.
+ _, err := tx.NewDelete().
+ Table("poll_votes").
+ Where("? = ?", bun.Ident("poll_id"), id).
+ Exec(ctx)
+ return err
+ }); err != nil {
return err
}
- // Invalidate poll by ID from cache.
+ // Wrap provided ID in a poll
+ // model for calling cache hook.
+ var deleted gtsmodel.Poll
+ deleted.ID = id
+
+ // Invalidate cached poll with ID, manually
+ // call invalidate hook in case not cached.
p.state.Caches.DB.Poll.Invalidate("ID", id)
- p.state.Caches.DB.PollVoteIDs.Invalidate(id)
+ p.state.Caches.OnInvalidatePoll(&deleted)
return nil
}
@@ -274,15 +293,8 @@ func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.P
votes, err := p.state.Caches.DB.PollVote.LoadIDs("ID",
voteIDs,
func(uncached []string) ([]*gtsmodel.PollVote, error) {
- // Avoid querying
- // if none uncached.
- count := len(uncached)
- if count == 0 {
- return nil, nil
- }
-
// Preallocate expected length of uncached votes.
- votes := make([]*gtsmodel.PollVote, 0, count)
+ votes := make([]*gtsmodel.PollVote, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
@@ -391,148 +403,44 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error
})
}
-func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
- err := p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
- // Delete all votes in poll.
- res, err := tx.NewDelete().
- Table("poll_votes").
- Where("? = ?", bun.Ident("poll_id"), pollID).
- Exec(ctx)
- if err != nil {
- // irrecoverable
- return err
- }
-
- ra, err := res.RowsAffected()
- if err != nil {
- // irrecoverable
- return err
- }
-
- if ra == 0 {
- // No poll votes deleted,
- // nothing to update.
- return nil
- }
-
- // Select current poll counts from DB,
- // taking minimal columns needed to
- // increment/decrement votes.
- var poll gtsmodel.Poll
- switch err := tx.NewSelect().
- Model(&poll).
- Column("options", "votes", "voters").
- Where("? = ?", bun.Ident("id"), pollID).
- Scan(ctx); {
-
- case err == nil:
- // no issue.
-
- case errors.Is(err, db.ErrNoEntries):
- // no votes found,
- // return here.
- return nil
-
- default:
- // irrecoverable.
- return err
- }
-
- // Zero all counts.
- poll.ResetVotes()
-
- // Finally, update the poll entry.
- _, err = tx.NewUpdate().
- Model(&poll).
- Column("votes", "voters").
- Where("? = ?", bun.Ident("id"), pollID).
- Exec(ctx)
- return err
- })
-
- if err != nil {
- return err
- }
-
- // Invalidate poll vote and poll entry from caches.
- p.state.Caches.DB.Poll.Invalidate("ID", pollID)
- p.state.Caches.DB.PollVote.Invalidate("PollID", pollID)
- p.state.Caches.DB.PollVoteIDs.Invalidate(pollID)
-
- return nil
-}
-
func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error {
- err := p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
- // Slice should only ever be of length
- // 0 or 1; it's a slice of slices only
- // because we can't LIMIT deletes to 1.
- var choicesSlice [][]int
+ // Gather necessary fields from
+ // deleted for cache invaliation.
+ var deleted gtsmodel.PollVote
+ deleted.AccountID = accountID
+ deleted.PollID = pollID
+
+ // Delete the poll vote with given poll and account IDs, and update vote counts.
+ if err := p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Delete vote in poll by account,
- // returning the ID + choices of the vote.
- if err := tx.NewDelete().
- Table("poll_votes").
+ // returning deleted model info.
+ switch _, err := tx.NewDelete().
+ Model(&deleted).
Where("? = ?", bun.Ident("poll_id"), pollID).
Where("? = ?", bun.Ident("account_id"), accountID).
Returning("?", bun.Ident("choices")).
- Scan(ctx, &choicesSlice); err != nil {
- // irrecoverable.
- return err
- }
-
- if len(choicesSlice) != 1 {
- // No poll votes by this
- // acct on this poll.
- return nil
- }
-
- // Extract the *actual* choices.
- choices := choicesSlice[0]
-
- // Select current poll counts from DB,
- // taking minimal columns needed to
- // increment/decrement votes.
- var poll gtsmodel.Poll
- switch err := tx.NewSelect().
- Model(&poll).
- Column("options", "votes", "voters").
- Where("? = ?", bun.Ident("id"), pollID).
- Scan(ctx); {
+ Exec(ctx); {
case err == nil:
- // no issue.
-
+ // no issue
case errors.Is(err, db.ErrNoEntries):
- // no poll found,
- // return here.
return nil
-
default:
- // irrecoverable.
return err
}
- // Decrement votes for choices.
- poll.DecrementVotes(choices)
-
- // Finally, update the poll entry.
- _, err := tx.NewUpdate().
- Model(&poll).
- Column("votes", "voters").
- Where("? = ?", bun.Ident("id"), pollID).
- Exec(ctx)
+ // Update the votes for this deleted poll.
+ err := updatePollCounts(ctx, tx, &deleted)
return err
- })
-
- if err != nil {
+ }); err != nil {
return err
}
- // Invalidate poll vote and poll entry from caches.
- p.state.Caches.DB.Poll.Invalidate("ID", pollID)
+ // Invalidate the poll vote cache by given poll + account IDs, also
+ // manually call invalidation hook in case not actually stored in cache.
p.state.Caches.DB.PollVote.Invalidate("PollID,AccountID", pollID, accountID)
- p.state.Caches.DB.PollVoteIDs.Invalidate(pollID)
+ p.state.Caches.OnInvalidatePollVote(&deleted)
return nil
}
@@ -562,6 +470,48 @@ func (p *pollDB) DeletePollVotesByAccountID(ctx context.Context, accountID strin
return nil
}
+// updatePollCounts updates the vote counts on a poll for the given deleted PollVote model.
+func updatePollCounts(ctx context.Context, tx bun.Tx, deleted *gtsmodel.PollVote) error {
+
+ // Select current poll counts from DB,
+ // taking minimal columns needed to
+ // increment/decrement votes.
+ var poll gtsmodel.Poll
+ switch err := tx.NewSelect().
+ Model(&poll).
+ Column("options", "votes", "voters").
+ Where("? = ?", bun.Ident("id"), deleted.PollID).
+ Scan(ctx); {
+
+ case err == nil:
+ // no issue.
+
+ case errors.Is(err, db.ErrNoEntries):
+ // no poll found,
+ // return here.
+ return nil
+
+ default:
+ // irrecoverable.
+ return err
+ }
+
+ // Decrement votes for these choices.
+ poll.DecrementVotes(deleted.Choices)
+
+ // Finally, update the poll entry.
+ if _, err := tx.NewUpdate().
+ Model(&poll).
+ Column("votes", "voters").
+ Where("? = ?", bun.Ident("id"), deleted.PollID).
+ Exec(ctx); err != nil &&
+ !errors.Is(err, db.ErrNoEntries) {
+ return err
+ }
+
+ return nil
+}
+
// newSelectPollVotes returns a new select query for all rows in the poll_votes table with poll_id = pollID.
func newSelectPollVotes(db *bun.DB, pollID string) *bun.SelectQuery {
return db.NewSelect().