diff options
Diffstat (limited to 'internal/db')
| -rw-r--r-- | internal/db/bundb/migrations/20231110142330_small_poll_table_tweaks.go | 2 | ||||
| -rw-r--r-- | internal/db/bundb/poll.go | 77 | ||||
| -rw-r--r-- | internal/db/bundb/poll_test.go | 54 | 
3 files changed, 96 insertions, 37 deletions
| diff --git a/internal/db/bundb/migrations/20231110142330_small_poll_table_tweaks.go b/internal/db/bundb/migrations/20231110142330_small_poll_table_tweaks.go index c9f2b3d0f..dad943efa 100644 --- a/internal/db/bundb/migrations/20231110142330_small_poll_table_tweaks.go +++ b/internal/db/bundb/migrations/20231110142330_small_poll_table_tweaks.go @@ -44,7 +44,7 @@ func init() {  				Table("polls").  				Column("expires_at_new").  				Set("? = ?", bun.Ident("expires_at_new"), bun.Ident("expires_at")). -				Where("1"). // bun gets angry performing update over all rows +				Where("TRUE"). // bun gets angry performing update over all rows  				Exec(ctx); err != nil {  				return err  			} diff --git a/internal/db/bundb/poll.go b/internal/db/bundb/poll.go index ab6edb4b9..830fb88ec 100644 --- a/internal/db/bundb/poll.go +++ b/internal/db/bundb/poll.go @@ -341,9 +341,12 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error  			var poll gtsmodel.Poll -			// Select poll counts from DB. +			// Select current poll counts from DB, +			// taking minimal columns needed to +			// increment/decrement votes.  			if err := tx.NewSelect().  				Model(&poll). +				Column("options", "votes", "voters").  				Where("? = ?", bun.Ident("id"), vote.PollID).  				Scan(ctx); err != nil {  				return err @@ -365,31 +368,35 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error  func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {  	err := p.db.RunInTx(ctx, func(tx Tx) error { -		// Delete all vote in poll, -		// returning all vote choices. -		switch _, err := tx.NewDelete(). +		// Delete all votes in poll. +		res, err := tx.NewDelete().  			Table("poll_votes").  			Where("? = ?", bun.Ident("poll_id"), pollID). -			Exec(ctx); { +			Exec(ctx) +		if err != nil { +			// irrecoverable +			return err +		} -		case err == nil: -			// no issue. +		ra, err := res.RowsAffected() +		if err != nil { +			// irrecoverable +			return err +		} -		case errors.Is(err, db.ErrNoEntries): -			// no votes found, -			// return here. +		if ra == 0 { +			// No poll votes deleted, +			// nothing to update.  			return nil - -		default: -			// irrecoverable. -			return err  		} +		// Select current poll counts from DB, +		// taking minimal columns needed to +		// increment/decrement votes.  		var poll gtsmodel.Poll - -		// Select poll counts from DB.  		switch err := tx.NewSelect().  			Model(&poll). +			Column("options", "votes", "voters").  			Where("? = ?", bun.Ident("id"), pollID).  			Scan(ctx); { @@ -410,7 +417,7 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {  		poll.ResetVotes()  		// Finally, update the poll entry. -		_, err := tx.NewUpdate(). +		_, err = tx.NewUpdate().  			Model(&poll).  			Column("votes", "voters").  			Where("? = ?", bun.Ident("id"), pollID). @@ -432,35 +439,37 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {  func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error {  	err := p.db.RunInTx(ctx, func(tx Tx) error { -		var choices []int +		// Slice should only ever be of length +		// 0 or 1; it's a slice of slices only +		// because we can't LIMIT deletes to 1. +		var choicesSl [][]int  		// Delete vote in poll by account,  		// returning the ID + choices of the vote. -		switch err := tx.NewDelete(). +		if err := tx.NewDelete().  			Table("poll_votes").  			Where("? = ?", bun.Ident("poll_id"), pollID).  			Where("? = ?", bun.Ident("account_id"), accountID). -			Returning("choices"). -			Scan(ctx, &choices); { - -		case err == nil: -			// no issue. - -		case errors.Is(err, db.ErrNoEntries): -			// no votes found, -			// return here. -			return nil - -		default: +			Returning("?", bun.Ident("choices")). +			Scan(ctx, &choicesSl); err != nil {  			// irrecoverable.  			return err  		} -		var poll gtsmodel.Poll +		if len(choicesSl) != 1 { +			// No poll votes by this +			// acct on this poll. +			return nil +		} +		choices := choicesSl[0] -		// Select poll counts from DB. +		// Select current poll counts from DB, +		// taking minimal columns needed to +		// increment/decrement votes. +		var poll gtsmodel.Poll  		switch err := tx.NewSelect().  			Model(&poll). +			Column("options", "votes", "voters").  			Where("? = ?", bun.Ident("id"), pollID).  			Scan(ctx); { @@ -468,7 +477,7 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID  			// no issue.  		case errors.Is(err, db.ErrNoEntries): -			// no votes found, +			// no poll found,  			// return here.  			return nil diff --git a/internal/db/bundb/poll_test.go b/internal/db/bundb/poll_test.go index 53da2514b..479557c55 100644 --- a/internal/db/bundb/poll_test.go +++ b/internal/db/bundb/poll_test.go @@ -26,6 +26,7 @@ import (  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/id"  	"github.com/superseriousbusiness/gotosocial/internal/util" @@ -304,15 +305,64 @@ func (suite *PollTestSuite) TestDeletePollVotes() {  		suite.NoError(err)  		// Fetch latest version of poll from database. -		poll, err = suite.db.GetPollByID(ctx, poll.ID) +		poll, err = suite.db.GetPollByID( +			gtscontext.SetBarebones(ctx), +			poll.ID, +		)  		suite.NoError(err)  		// Check that poll counts are all zero.  		suite.Equal(*poll.Voters, 0) -		suite.Equal(poll.Votes, make([]int, len(poll.Options))) +		suite.Equal(make([]int, len(poll.Options)), poll.Votes)  	}  } +func (suite *PollTestSuite) TestDeletePollVotesNoPoll() { +	// Create a new context for this test. +	ctx, cncl := context.WithCancel(context.Background()) +	defer cncl() + +	// Try to delete votes of nonexistent poll. +	nonPollID := "01HF6V4XWTSZWJ80JNPPDTD4DB" + +	err := suite.db.DeletePollVotes(ctx, nonPollID) +	suite.NoError(err) +} + +func (suite *PollTestSuite) TestDeletePollVotesBy() { +	ctx, cncl := context.WithCancel(context.Background()) +	defer cncl() + +	for _, vote := range suite.testPollVotes { +		// Fetch before version of pollBefore from database. +		pollBefore, err := suite.db.GetPollByID(ctx, vote.PollID) +		suite.NoError(err) + +		// Delete this poll vote. +		err = suite.db.DeletePollVoteBy(ctx, vote.PollID, vote.AccountID) +		suite.NoError(err) + +		// Fetch after version of poll from database. +		pollAfter, err := suite.db.GetPollByID(ctx, vote.PollID) +		suite.NoError(err) + +		// Voters count should be reduced by 1. +		suite.Equal(*pollBefore.Voters-1, *pollAfter.Voters) +	} +} + +func (suite *PollTestSuite) TestDeletePollVotesByNoAccount() { +	ctx, cncl := context.WithCancel(context.Background()) +	defer cncl() + +	// Try to delete a poll by nonexisting account. +	pollID := suite.testPolls["local_account_1_status_6_poll"].ID +	nonAccountID := "01HF6T545G1G8ZNMY1S3ZXJ608" + +	err := suite.db.DeletePollVoteBy(ctx, pollID, nonAccountID) +	suite.NoError(err) +} +  func TestPollTestSuite(t *testing.T) {  	suite.Run(t, new(PollTestSuite))  } | 
