summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2023-11-14 13:43:27 +0100
committerLibravatar GitHub <noreply@github.com>2023-11-14 12:43:27 +0000
commit0b99f14d64d5d372824c4d7602543610f5c006a1 (patch)
tree953e9e12fb312184b85dda88e278075ac260dc9f
parent[feature/performance] Wrap incoming HTTP requests in timeout handler (#2353) (diff)
downloadgotosocial-0b99f14d64d5d372824c4d7602543610f5c006a1.tar.xz
[bugfix] Update poll delete/update db queries (#2361)
-rw-r--r--internal/db/bundb/migrations/20231110142330_small_poll_table_tweaks.go2
-rw-r--r--internal/db/bundb/poll.go77
-rw-r--r--internal/db/bundb/poll_test.go54
3 files changed, 96 insertions, 37 deletions
diff --git a/internal/db/bundb/migrations/20231110142330_small_poll_table_tweaks.go b/internal/db/bundb/migrations/20231110142330_small_poll_table_tweaks.go
index c9f2b3d0f..dad943efa 100644
--- a/internal/db/bundb/migrations/20231110142330_small_poll_table_tweaks.go
+++ b/internal/db/bundb/migrations/20231110142330_small_poll_table_tweaks.go
@@ -44,7 +44,7 @@ func init() {
Table("polls").
Column("expires_at_new").
Set("? = ?", bun.Ident("expires_at_new"), bun.Ident("expires_at")).
- Where("1"). // bun gets angry performing update over all rows
+ Where("TRUE"). // bun gets angry performing update over all rows
Exec(ctx); err != nil {
return err
}
diff --git a/internal/db/bundb/poll.go b/internal/db/bundb/poll.go
index ab6edb4b9..830fb88ec 100644
--- a/internal/db/bundb/poll.go
+++ b/internal/db/bundb/poll.go
@@ -341,9 +341,12 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error
var poll gtsmodel.Poll
- // Select poll counts from DB.
+ // Select current poll counts from DB,
+ // taking minimal columns needed to
+ // increment/decrement votes.
if err := tx.NewSelect().
Model(&poll).
+ Column("options", "votes", "voters").
Where("? = ?", bun.Ident("id"), vote.PollID).
Scan(ctx); err != nil {
return err
@@ -365,31 +368,35 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error
func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
err := p.db.RunInTx(ctx, func(tx Tx) error {
- // Delete all vote in poll,
- // returning all vote choices.
- switch _, err := tx.NewDelete().
+ // Delete all votes in poll.
+ res, err := tx.NewDelete().
Table("poll_votes").
Where("? = ?", bun.Ident("poll_id"), pollID).
- Exec(ctx); {
+ Exec(ctx)
+ if err != nil {
+ // irrecoverable
+ return err
+ }
- case err == nil:
- // no issue.
+ ra, err := res.RowsAffected()
+ if err != nil {
+ // irrecoverable
+ return err
+ }
- case errors.Is(err, db.ErrNoEntries):
- // no votes found,
- // return here.
+ if ra == 0 {
+ // No poll votes deleted,
+ // nothing to update.
return nil
-
- default:
- // irrecoverable.
- return err
}
+ // Select current poll counts from DB,
+ // taking minimal columns needed to
+ // increment/decrement votes.
var poll gtsmodel.Poll
-
- // Select poll counts from DB.
switch err := tx.NewSelect().
Model(&poll).
+ Column("options", "votes", "voters").
Where("? = ?", bun.Ident("id"), pollID).
Scan(ctx); {
@@ -410,7 +417,7 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
poll.ResetVotes()
// Finally, update the poll entry.
- _, err := tx.NewUpdate().
+ _, err = tx.NewUpdate().
Model(&poll).
Column("votes", "voters").
Where("? = ?", bun.Ident("id"), pollID).
@@ -432,35 +439,37 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error {
err := p.db.RunInTx(ctx, func(tx Tx) error {
- var choices []int
+ // Slice should only ever be of length
+ // 0 or 1; it's a slice of slices only
+ // because we can't LIMIT deletes to 1.
+ var choicesSl [][]int
// Delete vote in poll by account,
// returning the ID + choices of the vote.
- switch err := tx.NewDelete().
+ if err := tx.NewDelete().
Table("poll_votes").
Where("? = ?", bun.Ident("poll_id"), pollID).
Where("? = ?", bun.Ident("account_id"), accountID).
- Returning("choices").
- Scan(ctx, &choices); {
-
- case err == nil:
- // no issue.
-
- case errors.Is(err, db.ErrNoEntries):
- // no votes found,
- // return here.
- return nil
-
- default:
+ Returning("?", bun.Ident("choices")).
+ Scan(ctx, &choicesSl); err != nil {
// irrecoverable.
return err
}
- var poll gtsmodel.Poll
+ if len(choicesSl) != 1 {
+ // No poll votes by this
+ // acct on this poll.
+ return nil
+ }
+ choices := choicesSl[0]
- // Select poll counts from DB.
+ // Select current poll counts from DB,
+ // taking minimal columns needed to
+ // increment/decrement votes.
+ var poll gtsmodel.Poll
switch err := tx.NewSelect().
Model(&poll).
+ Column("options", "votes", "voters").
Where("? = ?", bun.Ident("id"), pollID).
Scan(ctx); {
@@ -468,7 +477,7 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID
// no issue.
case errors.Is(err, db.ErrNoEntries):
- // no votes found,
+ // no poll found,
// return here.
return nil
diff --git a/internal/db/bundb/poll_test.go b/internal/db/bundb/poll_test.go
index 53da2514b..479557c55 100644
--- a/internal/db/bundb/poll_test.go
+++ b/internal/db/bundb/poll_test.go
@@ -26,6 +26,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
@@ -304,15 +305,64 @@ func (suite *PollTestSuite) TestDeletePollVotes() {
suite.NoError(err)
// Fetch latest version of poll from database.
- poll, err = suite.db.GetPollByID(ctx, poll.ID)
+ poll, err = suite.db.GetPollByID(
+ gtscontext.SetBarebones(ctx),
+ poll.ID,
+ )
suite.NoError(err)
// Check that poll counts are all zero.
suite.Equal(*poll.Voters, 0)
- suite.Equal(poll.Votes, make([]int, len(poll.Options)))
+ suite.Equal(make([]int, len(poll.Options)), poll.Votes)
}
}
+func (suite *PollTestSuite) TestDeletePollVotesNoPoll() {
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ // Try to delete votes of nonexistent poll.
+ nonPollID := "01HF6V4XWTSZWJ80JNPPDTD4DB"
+
+ err := suite.db.DeletePollVotes(ctx, nonPollID)
+ suite.NoError(err)
+}
+
+func (suite *PollTestSuite) TestDeletePollVotesBy() {
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ for _, vote := range suite.testPollVotes {
+ // Fetch before version of pollBefore from database.
+ pollBefore, err := suite.db.GetPollByID(ctx, vote.PollID)
+ suite.NoError(err)
+
+ // Delete this poll vote.
+ err = suite.db.DeletePollVoteBy(ctx, vote.PollID, vote.AccountID)
+ suite.NoError(err)
+
+ // Fetch after version of poll from database.
+ pollAfter, err := suite.db.GetPollByID(ctx, vote.PollID)
+ suite.NoError(err)
+
+ // Voters count should be reduced by 1.
+ suite.Equal(*pollBefore.Voters-1, *pollAfter.Voters)
+ }
+}
+
+func (suite *PollTestSuite) TestDeletePollVotesByNoAccount() {
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ // Try to delete a poll by nonexisting account.
+ pollID := suite.testPolls["local_account_1_status_6_poll"].ID
+ nonAccountID := "01HF6T545G1G8ZNMY1S3ZXJ608"
+
+ err := suite.db.DeletePollVoteBy(ctx, pollID, nonAccountID)
+ suite.NoError(err)
+}
+
func TestPollTestSuite(t *testing.T) {
suite.Run(t, new(PollTestSuite))
}