diff options
Diffstat (limited to 'internal/db/bundb/poll.go')
-rw-r--r-- | internal/db/bundb/poll.go | 226 |
1 files changed, 88 insertions, 138 deletions
diff --git a/internal/db/bundb/poll.go b/internal/db/bundb/poll.go index 5c1d9c6dd..f5c33ce9b 100644 --- a/internal/db/bundb/poll.go +++ b/internal/db/bundb/poll.go @@ -177,17 +177,36 @@ func (p *pollDB) UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...st } func (p *pollDB) DeletePollByID(ctx context.Context, id string) error { - // Delete poll by ID from database. - if _, err := p.db.NewDelete(). - Table("polls"). - Where("? = ?", bun.Ident("id"), id). - Exec(ctx); err != nil { + // Delete poll vote with ID, and its associated votes from the database. + if err := p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + + // Delete poll from database. + if _, err := tx.NewDelete(). + Table("polls"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return err + } + + // Delete the poll votes. + _, err := tx.NewDelete(). + Table("poll_votes"). + Where("? = ?", bun.Ident("poll_id"), id). + Exec(ctx) + return err + }); err != nil { return err } - // Invalidate poll by ID from cache. + // Wrap provided ID in a poll + // model for calling cache hook. + var deleted gtsmodel.Poll + deleted.ID = id + + // Invalidate cached poll with ID, manually + // call invalidate hook in case not cached. p.state.Caches.DB.Poll.Invalidate("ID", id) - p.state.Caches.DB.PollVoteIDs.Invalidate(id) + p.state.Caches.OnInvalidatePoll(&deleted) return nil } @@ -274,15 +293,8 @@ func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.P votes, err := p.state.Caches.DB.PollVote.LoadIDs("ID", voteIDs, func(uncached []string) ([]*gtsmodel.PollVote, error) { - // Avoid querying - // if none uncached. - count := len(uncached) - if count == 0 { - return nil, nil - } - // Preallocate expected length of uncached votes. - votes := make([]*gtsmodel.PollVote, 0, count) + votes := make([]*gtsmodel.PollVote, 0, len(uncached)) // Perform database query scanning // the remaining (uncached) IDs. @@ -391,148 +403,44 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error }) } -func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { - err := p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { - // Delete all votes in poll. - res, err := tx.NewDelete(). - Table("poll_votes"). - Where("? = ?", bun.Ident("poll_id"), pollID). - Exec(ctx) - if err != nil { - // irrecoverable - return err - } - - ra, err := res.RowsAffected() - if err != nil { - // irrecoverable - return err - } - - if ra == 0 { - // No poll votes deleted, - // nothing to update. - return nil - } - - // Select current poll counts from DB, - // taking minimal columns needed to - // increment/decrement votes. - var poll gtsmodel.Poll - switch err := tx.NewSelect(). - Model(&poll). - Column("options", "votes", "voters"). - Where("? = ?", bun.Ident("id"), pollID). - Scan(ctx); { - - case err == nil: - // no issue. - - case errors.Is(err, db.ErrNoEntries): - // no votes found, - // return here. - return nil - - default: - // irrecoverable. - return err - } - - // Zero all counts. - poll.ResetVotes() - - // Finally, update the poll entry. - _, err = tx.NewUpdate(). - Model(&poll). - Column("votes", "voters"). - Where("? = ?", bun.Ident("id"), pollID). - Exec(ctx) - return err - }) - - if err != nil { - return err - } - - // Invalidate poll vote and poll entry from caches. - p.state.Caches.DB.Poll.Invalidate("ID", pollID) - p.state.Caches.DB.PollVote.Invalidate("PollID", pollID) - p.state.Caches.DB.PollVoteIDs.Invalidate(pollID) - - return nil -} - func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error { - err := p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { - // Slice should only ever be of length - // 0 or 1; it's a slice of slices only - // because we can't LIMIT deletes to 1. - var choicesSlice [][]int + // Gather necessary fields from + // deleted for cache invaliation. + var deleted gtsmodel.PollVote + deleted.AccountID = accountID + deleted.PollID = pollID + + // Delete the poll vote with given poll and account IDs, and update vote counts. + if err := p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { // Delete vote in poll by account, - // returning the ID + choices of the vote. - if err := tx.NewDelete(). - Table("poll_votes"). + // returning deleted model info. + switch _, err := tx.NewDelete(). + Model(&deleted). Where("? = ?", bun.Ident("poll_id"), pollID). Where("? = ?", bun.Ident("account_id"), accountID). Returning("?", bun.Ident("choices")). - Scan(ctx, &choicesSlice); err != nil { - // irrecoverable. - return err - } - - if len(choicesSlice) != 1 { - // No poll votes by this - // acct on this poll. - return nil - } - - // Extract the *actual* choices. - choices := choicesSlice[0] - - // Select current poll counts from DB, - // taking minimal columns needed to - // increment/decrement votes. - var poll gtsmodel.Poll - switch err := tx.NewSelect(). - Model(&poll). - Column("options", "votes", "voters"). - Where("? = ?", bun.Ident("id"), pollID). - Scan(ctx); { + Exec(ctx); { case err == nil: - // no issue. - + // no issue case errors.Is(err, db.ErrNoEntries): - // no poll found, - // return here. return nil - default: - // irrecoverable. return err } - // Decrement votes for choices. - poll.DecrementVotes(choices) - - // Finally, update the poll entry. - _, err := tx.NewUpdate(). - Model(&poll). - Column("votes", "voters"). - Where("? = ?", bun.Ident("id"), pollID). - Exec(ctx) + // Update the votes for this deleted poll. + err := updatePollCounts(ctx, tx, &deleted) return err - }) - - if err != nil { + }); err != nil { return err } - // Invalidate poll vote and poll entry from caches. - p.state.Caches.DB.Poll.Invalidate("ID", pollID) + // Invalidate the poll vote cache by given poll + account IDs, also + // manually call invalidation hook in case not actually stored in cache. p.state.Caches.DB.PollVote.Invalidate("PollID,AccountID", pollID, accountID) - p.state.Caches.DB.PollVoteIDs.Invalidate(pollID) + p.state.Caches.OnInvalidatePollVote(&deleted) return nil } @@ -562,6 +470,48 @@ func (p *pollDB) DeletePollVotesByAccountID(ctx context.Context, accountID strin return nil } +// updatePollCounts updates the vote counts on a poll for the given deleted PollVote model. +func updatePollCounts(ctx context.Context, tx bun.Tx, deleted *gtsmodel.PollVote) error { + + // Select current poll counts from DB, + // taking minimal columns needed to + // increment/decrement votes. + var poll gtsmodel.Poll + switch err := tx.NewSelect(). + Model(&poll). + Column("options", "votes", "voters"). + Where("? = ?", bun.Ident("id"), deleted.PollID). + Scan(ctx); { + + case err == nil: + // no issue. + + case errors.Is(err, db.ErrNoEntries): + // no poll found, + // return here. + return nil + + default: + // irrecoverable. + return err + } + + // Decrement votes for these choices. + poll.DecrementVotes(deleted.Choices) + + // Finally, update the poll entry. + if _, err := tx.NewUpdate(). + Model(&poll). + Column("votes", "voters"). + Where("? = ?", bun.Ident("id"), deleted.PollID). + Exec(ctx); err != nil && + !errors.Is(err, db.ErrNoEntries) { + return err + } + + return nil +} + // newSelectPollVotes returns a new select query for all rows in the poll_votes table with poll_id = pollID. func newSelectPollVotes(db *bun.DB, pollID string) *bun.SelectQuery { return db.NewSelect(). |