diff options
Diffstat (limited to 'internal/db')
| -rw-r--r-- | internal/db/bundb/account_test.go | 8 | ||||
| -rw-r--r-- | internal/db/bundb/basic_test.go | 2 | ||||
| -rw-r--r-- | internal/db/bundb/bundb.go | 5 | ||||
| -rw-r--r-- | internal/db/bundb/bundb_test.go | 4 | ||||
| -rw-r--r-- | internal/db/bundb/instance_test.go | 4 | ||||
| -rw-r--r-- | internal/db/bundb/mention.go | 69 | ||||
| -rw-r--r-- | internal/db/bundb/migrations/20231002153327_add_status_polls.go | 65 | ||||
| -rw-r--r-- | internal/db/bundb/poll.go | 536 | ||||
| -rw-r--r-- | internal/db/bundb/poll_test.go | 318 | ||||
| -rw-r--r-- | internal/db/bundb/relationship.go | 21 | ||||
| -rw-r--r-- | internal/db/bundb/status.go | 22 | ||||
| -rw-r--r-- | internal/db/bundb/timeline_test.go | 31 | ||||
| -rw-r--r-- | internal/db/db.go | 1 | ||||
| -rw-r--r-- | internal/db/mention.go | 3 | ||||
| -rw-r--r-- | internal/db/poll.go | 71 | 
15 files changed, 1095 insertions, 65 deletions
| diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index b410bb3ed..8c2de5519 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -42,7 +42,7 @@ type AccountTestSuite struct {  func (suite *AccountTestSuite) TestGetAccountStatuses() {  	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, false, false, "", "", false, false)  	suite.NoError(err) -	suite.Len(statuses, 5) +	suite.Len(statuses, 6)  }  func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() { @@ -65,7 +65,7 @@ func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() {  	if err != nil {  		suite.FailNow(err.Error())  	} -	suite.Len(statuses, 1) +	suite.Len(statuses, 2)  	// try to get the last page (should be empty)  	statuses, err = suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, statuses[len(statuses)-1].ID, "", false, false) @@ -76,7 +76,7 @@ func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() {  func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogs() {  	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false)  	suite.NoError(err) -	suite.Len(statuses, 5) +	suite.Len(statuses, 6)  }  func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogsPublicOnly() { @@ -306,7 +306,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {  func (suite *AccountTestSuite) TestGetAccountLastPosted() {  	lastPosted, err := suite.db.GetAccountLastPosted(context.Background(), suite.testAccounts["local_account_1"].ID, false)  	suite.NoError(err) -	suite.EqualValues(1653046675, lastPosted.Unix()) +	suite.EqualValues(1653046870, lastPosted.Unix())  }  func (suite *AccountTestSuite) TestGetAccountLastPostedWebOnly() { diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go index a24deac9e..cef0617b7 100644 --- a/internal/db/bundb/basic_test.go +++ b/internal/db/bundb/basic_test.go @@ -121,7 +121,7 @@ func (suite *BasicTestSuite) TestGetAllStatuses() {  	s := []*gtsmodel.Status{}  	err := suite.db.GetAll(context.Background(), &s)  	suite.NoError(err) -	suite.Len(s, 17) +	suite.Len(s, 20)  }  func (suite *BasicTestSuite) TestGetAllNotNull() { diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 393f32eec..a86a20274 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -71,6 +71,7 @@ type DBService struct {  	db.Media  	db.Mention  	db.Notification +	db.Poll  	db.Relationship  	db.Report  	db.Rule @@ -203,6 +204,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {  			db:    db,  			state: state,  		}, +		Poll: &pollDB{ +			db:    db, +			state: state, +		},  		Relationship: &relationshipDB{  			db:    db,  			state: state, diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go index 8245937b9..037727090 100644 --- a/internal/db/bundb/bundb_test.go +++ b/internal/db/bundb/bundb_test.go @@ -54,6 +54,8 @@ type BunDBStandardTestSuite struct {  	testMarkers      map[string]*gtsmodel.Marker  	testRules        map[string]*gtsmodel.Rule  	testThreads      map[string]*gtsmodel.Thread +	testPolls        map[string]*gtsmodel.Poll +	testPollVotes    map[string]*gtsmodel.PollVote  }  func (suite *BunDBStandardTestSuite) SetupSuite() { @@ -77,6 +79,8 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {  	suite.testMarkers = testrig.NewTestMarkers()  	suite.testRules = testrig.NewTestRules()  	suite.testThreads = testrig.NewTestThreads() +	suite.testPolls = testrig.NewTestPolls() +	suite.testPollVotes = testrig.NewTestPollVotes()  }  func (suite *BunDBStandardTestSuite) SetupTest() { diff --git a/internal/db/bundb/instance_test.go b/internal/db/bundb/instance_test.go index a825a3341..d88825a33 100644 --- a/internal/db/bundb/instance_test.go +++ b/internal/db/bundb/instance_test.go @@ -47,13 +47,13 @@ func (suite *InstanceTestSuite) TestCountInstanceUsersRemote() {  func (suite *InstanceTestSuite) TestCountInstanceStatuses() {  	count, err := suite.db.CountInstanceStatuses(context.Background(), config.GetHost())  	suite.NoError(err) -	suite.Equal(16, count) +	suite.Equal(18, count)  }  func (suite *InstanceTestSuite) TestCountInstanceStatusesRemote() {  	count, err := suite.db.CountInstanceStatuses(context.Background(), "fossbros-anonymous.io")  	suite.NoError(err) -	suite.Equal(1, count) +	suite.Equal(2, count)  }  func (suite *InstanceTestSuite) TestCountInstanceDomains() { diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go index 547d8d0a8..30a20b0c1 100644 --- a/internal/db/bundb/mention.go +++ b/internal/db/bundb/mention.go @@ -20,10 +20,10 @@ package bundb  import (  	"context"  	"errors" -	"fmt"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/state" @@ -54,31 +54,9 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio  		return nil, err  	} -	// Set the mention originating status. -	mention.Status, err = m.state.DB.GetStatusByID( -		gtscontext.SetBarebones(ctx), -		mention.StatusID, -	) -	if err != nil { -		return nil, fmt.Errorf("error populating mention status: %w", err) -	} - -	// Set the mention origin account model. -	mention.OriginAccount, err = m.state.DB.GetAccountByID( -		gtscontext.SetBarebones(ctx), -		mention.OriginAccountID, -	) -	if err != nil { -		return nil, fmt.Errorf("error populating mention origin account: %w", err) -	} - -	// Set the mention target account model. -	mention.TargetAccount, err = m.state.DB.GetAccountByID( -		gtscontext.SetBarebones(ctx), -		mention.TargetAccountID, -	) -	if err != nil { -		return nil, fmt.Errorf("error populating mention target account: %w", err) +	// Further populate the mention fields where applicable. +	if err := m.PopulateMention(ctx, mention); err != nil { +		return nil, err  	}  	return mention, nil @@ -102,6 +80,45 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.  	return mentions, nil  } +func (m *mentionDB) PopulateMention(ctx context.Context, mention *gtsmodel.Mention) (err error) { +	var errs gtserror.MultiError + +	if mention.Status == nil { +		// Set the mention originating status. +		mention.Status, err = m.state.DB.GetStatusByID( +			gtscontext.SetBarebones(ctx), +			mention.StatusID, +		) +		if err != nil { +			return gtserror.Newf("error populating mention status: %w", err) +		} +	} + +	if mention.OriginAccount == nil { +		// Set the mention origin account model. +		mention.OriginAccount, err = m.state.DB.GetAccountByID( +			gtscontext.SetBarebones(ctx), +			mention.OriginAccountID, +		) +		if err != nil { +			return gtserror.Newf("error populating mention origin account: %w", err) +		} +	} + +	if mention.TargetAccount == nil { +		// Set the mention target account model. +		mention.TargetAccount, err = m.state.DB.GetAccountByID( +			gtscontext.SetBarebones(ctx), +			mention.TargetAccountID, +		) +		if err != nil { +			return gtserror.Newf("error populating mention target account: %w", err) +		} +	} + +	return errs.Combine() +} +  func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error {  	return m.state.Caches.GTS.Mention().Store(mention, func() error {  		_, err := m.db.NewInsert().Model(mention).Exec(ctx) diff --git a/internal/db/bundb/migrations/20231002153327_add_status_polls.go b/internal/db/bundb/migrations/20231002153327_add_status_polls.go new file mode 100644 index 000000000..5e525cc27 --- /dev/null +++ b/internal/db/bundb/migrations/20231002153327_add_status_polls.go @@ -0,0 +1,65 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( +	"context" +	"strings" + +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/uptrace/bun" +) + +func init() { +	up := func(ctx context.Context, db *bun.DB) error { +		// Create `polls` + `poll_votes` tables. +		for _, model := range []any{ +			>smodel.Poll{}, +			>smodel.PollVote{}, +		} { +			_, err := db.NewCreateTable(). +				IfNotExists(). +				Model(model). +				Exec(ctx) +			if err != nil && !(strings.Contains(err.Error(), "already exists") || strings.Contains(err.Error(), "duplicate column name") || strings.Contains(err.Error(), "SQLSTATE 42701")) { +				return err +			} +		} + +		// Add the new status `poll_id` column. +		_, err := db.NewAddColumn(). +			Model(>smodel.Status{}). +			ColumnExpr("? CHAR(26)", bun.Ident("poll_id")). +			Exec(ctx) +		if err != nil && !(strings.Contains(err.Error(), "already exists") || strings.Contains(err.Error(), "duplicate column name") || strings.Contains(err.Error(), "SQLSTATE 42701")) { +			return err +		} + +		return nil +	} + +	down := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			return nil +		}) +	} + +	if err := Migrations.Register(up, down); err != nil { +		panic(err) +	} +} diff --git a/internal/db/bundb/poll.go b/internal/db/bundb/poll.go new file mode 100644 index 000000000..84f160987 --- /dev/null +++ b/internal/db/bundb/poll.go @@ -0,0 +1,536 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( +	"context" +	"errors" +	"time" + +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/uptrace/bun" +) + +type pollDB struct { +	db    *DB +	state *state.State +} + +func (p *pollDB) GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, error) { +	return p.getPoll( +		ctx, +		"ID", +		func(poll *gtsmodel.Poll) error { +			return p.db.NewSelect(). +				Model(poll). +				Where("? = ?", bun.Ident("poll.id"), id). +				Scan(ctx) +		}, +		id, +	) +} + +func (p *pollDB) GetPollByStatusID(ctx context.Context, statusID string) (*gtsmodel.Poll, error) { +	return p.getPoll( +		ctx, +		"StatusID", +		func(poll *gtsmodel.Poll) error { +			return p.db.NewSelect(). +				Model(poll). +				Where("? = ?", bun.Ident("poll.status_id"), statusID). +				Scan(ctx) +		}, +		statusID, +	) +} + +func (p *pollDB) getPoll(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Poll) error, keyParts ...any) (*gtsmodel.Poll, error) { +	// Fetch poll from database cache with loader callback +	poll, err := p.state.Caches.GTS.Poll().Load(lookup, func() (*gtsmodel.Poll, error) { +		var poll gtsmodel.Poll + +		// Not cached! Perform database query. +		if err := dbQuery(&poll); err != nil { +			return nil, err +		} + +		// Ensure vote slice +		// is non nil and set. +		poll.CheckVotes() + +		return &poll, nil +	}, keyParts...) +	if err != nil { +		return nil, err +	} + +	if gtscontext.Barebones(ctx) { +		// no need to fully populate. +		return poll, nil +	} + +	// Further populate the poll fields where applicable. +	if err := p.PopulatePoll(ctx, poll); err != nil { +		return nil, err +	} + +	return poll, nil +} + +func (p *pollDB) GetOpenPolls(ctx context.Context) ([]*gtsmodel.Poll, error) { +	var pollIDs []string + +	// Select all polls with unset `closed_at` time. +	if err := p.db.NewSelect(). +		Table("polls"). +		Column("polls.id"). +		Join("JOIN ? ON ? = ?", bun.Ident("statuses"), bun.Ident("polls.id"), bun.Ident("statuses.poll_id")). +		Where("? = true", bun.Ident("statuses.local")). +		Where("? IS NULL", bun.Ident("polls.closed_at")). +		Scan(ctx, &pollIDs); err != nil { +		return nil, err +	} + +	// Preallocate a slice to contain the poll models. +	polls := make([]*gtsmodel.Poll, 0, len(pollIDs)) + +	for _, id := range pollIDs { +		// Attempt to fetch poll from DB. +		poll, err := p.GetPollByID(ctx, id) +		if err != nil { +			log.Errorf(ctx, "error getting poll %s: %v", id, err) +			continue +		} + +		// Append poll to return slice. +		polls = append(polls, poll) +	} + +	return polls, nil +} + +func (p *pollDB) PopulatePoll(ctx context.Context, poll *gtsmodel.Poll) error { +	var ( +		err  error +		errs gtserror.MultiError +	) + +	if poll.Status == nil { +		// Vote account is not set, fetch from database. +		poll.Status, err = p.state.DB.GetStatusByID( +			gtscontext.SetBarebones(ctx), +			poll.StatusID, +		) +		if err != nil { +			errs.Appendf("error populating poll status: %w", err) +		} +	} + +	return errs.Combine() +} + +func (p *pollDB) PutPoll(ctx context.Context, poll *gtsmodel.Poll) error { +	// Ensure vote slice +	// is non nil and set. +	poll.CheckVotes() + +	return p.state.Caches.GTS.Poll().Store(poll, func() error { +		_, err := p.db.NewInsert().Model(poll).Exec(ctx) +		return err +	}) +} + +func (p *pollDB) UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...string) error { +	// Ensure vote slice +	// is non nil and set. +	poll.CheckVotes() + +	return p.state.Caches.GTS.Poll().Store(poll, func() error { +		return p.db.RunInTx(ctx, func(tx Tx) error { +			// Update the status' "updated_at" field. +			if _, err := tx.NewUpdate(). +				Table("statuses"). +				Where("? = ?", bun.Ident("id"), poll.StatusID). +				SetColumn("updated_at", "?", time.Now()). +				Exec(ctx); err != nil { +				return err +			} + +			// Finally, update poll +			// columns in database. +			_, err := tx.NewUpdate(). +				Model(poll). +				Column(cols...). +				Where("? = ?", bun.Ident("id"), poll.ID). +				Exec(ctx) +			return err +		}) +	}) +} + +func (p *pollDB) DeletePollByID(ctx context.Context, id string) error { +	// Delete poll by ID from database. +	if _, err := p.db.NewDelete(). +		Table("polls"). +		Where("? = ?", bun.Ident("id"), id). +		Exec(ctx); err != nil { +		return err +	} + +	// Invalidate poll by ID from cache. +	p.state.Caches.GTS.Poll().Invalidate("ID", id) +	p.state.Caches.GTS.PollVoteIDs().Invalidate(id) + +	return nil +} + +func (p *pollDB) GetPollVoteByID(ctx context.Context, id string) (*gtsmodel.PollVote, error) { +	return p.getPollVote( +		ctx, +		"ID", +		func(vote *gtsmodel.PollVote) error { +			return p.db.NewSelect(). +				Model(vote). +				Where("? = ?", bun.Ident("poll_vote.id"), id). +				Scan(ctx) +		}, +		id, +	) +} + +func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID string) (*gtsmodel.PollVote, error) { +	return p.getPollVote( +		ctx, +		"PollID.AccountID", +		func(vote *gtsmodel.PollVote) error { +			return p.db.NewSelect(). +				Model(vote). +				Where("? = ?", bun.Ident("poll_vote.account_id"), accountID). +				Where("? = ?", bun.Ident("poll_vote.poll_id"), pollID). +				Scan(ctx) +		}, +		pollID, +		accountID, +	) +} + +func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.PollVote) error, keyParts ...any) (*gtsmodel.PollVote, error) { +	// Fetch vote from database cache with loader callback +	vote, err := p.state.Caches.GTS.PollVote().Load(lookup, func() (*gtsmodel.PollVote, error) { +		var vote gtsmodel.PollVote + +		// Not cached! Perform database query. +		if err := dbQuery(&vote); err != nil { +			return nil, err +		} + +		return &vote, nil +	}, keyParts...) +	if err != nil { +		return nil, err +	} + +	if gtscontext.Barebones(ctx) { +		// no need to fully populate. +		return vote, nil +	} + +	// Further populate the vote fields where applicable. +	if err := p.PopulatePollVote(ctx, vote); err != nil { +		return nil, err +	} + +	return vote, nil +} + +func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.PollVote, error) { +	voteIDs, err := p.state.Caches.GTS.PollVoteIDs().Load(pollID, func() ([]string, error) { +		var voteIDs []string + +		// Vote IDs not in cache, perform DB query! +		q := newSelectPollVotes(p.db, pollID) +		if _, err := q.Exec(ctx, &voteIDs); // nocollapse +		err != nil && !errors.Is(err, db.ErrNoEntries) { +			return nil, err +		} + +		return voteIDs, nil +	}) +	if err != nil { +		return nil, err +	} + +	// Preallocate slice of expected length. +	votes := make([]*gtsmodel.PollVote, 0, len(voteIDs)) + +	for _, id := range voteIDs { +		// Fetch poll vote model for this ID. +		vote, err := p.GetPollVoteByID(ctx, id) +		if err != nil { +			log.Errorf(ctx, "error getting poll vote %s: %v", id, err) +			continue +		} + +		// Append to return slice. +		votes = append(votes, vote) +	} + +	return votes, nil +} + +func (p *pollDB) PopulatePollVote(ctx context.Context, vote *gtsmodel.PollVote) error { +	var ( +		err  error +		errs gtserror.MultiError +	) + +	if vote.Account == nil { +		// Vote account is not set, fetch from database. +		vote.Account, err = p.state.DB.GetAccountByID( +			gtscontext.SetBarebones(ctx), +			vote.AccountID, +		) +		if err != nil { +			errs.Appendf("error populating vote account: %w", err) +		} +	} + +	if vote.Poll == nil { +		// Vote poll is not set, fetch from database. +		vote.Poll, err = p.GetPollByID( +			gtscontext.SetBarebones(ctx), +			vote.PollID, +		) +		if err != nil { +			errs.Appendf("error populating vote poll: %w", err) +		} +	} + +	return errs.Combine() +} + +func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error { +	return p.state.Caches.GTS.PollVote().Store(vote, func() error { +		return p.db.RunInTx(ctx, func(tx Tx) error { +			// Try insert vote into database. +			if _, err := tx.NewInsert(). +				Model(vote). +				Exec(ctx); err != nil { +				return err +			} + +			var poll gtsmodel.Poll + +			// Select poll counts from DB. +			if err := tx.NewSelect(). +				Model(&poll). +				Where("? = ?", bun.Ident("id"), vote.PollID). +				Scan(ctx); err != nil { +				return err +			} + +			// Increment poll votes for choices. +			poll.IncrementVotes(vote.Choices) + +			// Finally, update the poll entry. +			_, err := tx.NewUpdate(). +				Model(&poll). +				Column("votes", "voters"). +				Where("? = ?", bun.Ident("id"), vote.PollID). +				Exec(ctx) +			return err +		}) +	}) +} + +func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { +	err := p.db.RunInTx(ctx, func(tx Tx) error { +		// Delete all vote in poll, +		// returning all vote choices. +		switch _, err := tx.NewDelete(). +			Table("poll_votes"). +			Where("? = ?", bun.Ident("poll_id"), pollID). +			Exec(ctx); { + +		case err == nil: +			// no issue. + +		case errors.Is(err, db.ErrNoEntries): +			// no votes found, +			// return here. +			return nil + +		default: +			// irrecoverable. +			return err +		} + +		var poll gtsmodel.Poll + +		// Select poll counts from DB. +		switch err := tx.NewSelect(). +			Model(&poll). +			Where("? = ?", bun.Ident("id"), pollID). +			Scan(ctx); { + +		case err == nil: +			// no issue. + +		case errors.Is(err, db.ErrNoEntries): +			// no votes found, +			// return here. +			return nil + +		default: +			// irrecoverable. +			return err +		} + +		// Zero all counts. +		poll.ResetVotes() + +		// Finally, update the poll entry. +		_, err := tx.NewUpdate(). +			Model(&poll). +			Column("votes", "voters"). +			Where("? = ?", bun.Ident("id"), pollID). +			Exec(ctx) +		return err +	}) + +	if err != nil { +		return err +	} + +	// Invalidate poll vote and poll entry from caches. +	p.state.Caches.GTS.Poll().Invalidate("ID", pollID) +	p.state.Caches.GTS.PollVote().Invalidate("PollID", pollID) +	p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID) + +	return nil +} + +func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error { +	err := p.db.RunInTx(ctx, func(tx Tx) error { +		var choices []int + +		// Delete vote in poll by account, +		// returning the ID + choices of the vote. +		switch err := tx.NewDelete(). +			Table("poll_votes"). +			Where("? = ?", bun.Ident("poll_id"), pollID). +			Where("? = ?", bun.Ident("account_id"), accountID). +			Returning("choices"). +			Scan(ctx, &choices); { + +		case err == nil: +			// no issue. + +		case errors.Is(err, db.ErrNoEntries): +			// no votes found, +			// return here. +			return nil + +		default: +			// irrecoverable. +			return err +		} + +		var poll gtsmodel.Poll + +		// Select poll counts from DB. +		switch err := tx.NewSelect(). +			Model(&poll). +			Where("? = ?", bun.Ident("id"), pollID). +			Scan(ctx); { + +		case err == nil: +			// no issue. + +		case errors.Is(err, db.ErrNoEntries): +			// no votes found, +			// return here. +			return nil + +		default: +			// irrecoverable. +			return err +		} + +		// Decrement votes for choices. +		poll.IncrementVotes(choices) + +		// Finally, update the poll entry. +		_, err := tx.NewUpdate(). +			Model(&poll). +			Column("votes", "voters"). +			Where("? = ?", bun.Ident("id"), pollID). +			Exec(ctx) +		return err +	}) + +	if err != nil { +		return err +	} + +	// Invalidate poll vote and poll entry from caches. +	p.state.Caches.GTS.Poll().Invalidate("ID", pollID) +	p.state.Caches.GTS.PollVote().Invalidate("PollID.AccountID", pollID, accountID) +	p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID) + +	return nil +} + +func (p *pollDB) DeletePollVotesByAccountID(ctx context.Context, accountID string) error { +	var pollIDs []string + +	// Select all polls this account +	// has registered a poll vote in. +	if err := p.db.NewSelect(). +		Table("poll_votes"). +		Column("poll_id"). +		Where("? = ?", bun.Ident("account_id"), accountID). +		Scan(ctx, &pollIDs); err != nil && +		!errors.Is(err, db.ErrNoEntries) { +		return err +	} + +	for _, id := range pollIDs { +		// Delete all votes by this account in each of the polls, +		// this way ensures that all necessary caches are invalidated. +		if err := p.DeletePollVoteBy(ctx, id, accountID); err != nil { +			log.Errorf(ctx, "error deleting vote by %s in %s: %v", accountID, id, err) +		} +	} + +	return nil +} + +// newSelectPollVotes returns a new select query for all rows in the poll_votes table with poll_id = pollID. +func newSelectPollVotes(db *DB, pollID string) *bun.SelectQuery { +	return db.NewSelect(). +		TableExpr("?", bun.Ident("poll_votes")). +		ColumnExpr("?", bun.Ident("id")). +		Where("? = ?", bun.Ident("poll_id"), pollID). +		OrderExpr("? DESC", bun.Ident("id")) +} diff --git a/internal/db/bundb/poll_test.go b/internal/db/bundb/poll_test.go new file mode 100644 index 000000000..53da2514b --- /dev/null +++ b/internal/db/bundb/poll_test.go @@ -0,0 +1,318 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package bundb_test + +import ( +	"context" +	"errors" +	"math/rand" +	"testing" +	"time" + +	"github.com/stretchr/testify/suite" +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/id" +	"github.com/superseriousbusiness/gotosocial/internal/util" +) + +type PollTestSuite struct { +	BunDBStandardTestSuite +} + +func (suite *PollTestSuite) TestGetPollBy() { +	t := suite.T() + +	// Create a new context for this test. +	ctx, cncl := context.WithCancel(context.Background()) +	defer cncl() + +	// Sentinel error to mark avoiding a test case. +	sentinelErr := errors.New("sentinel") + +	// isEqual checks if 2 poll models are equal. +	isEqual := func(p1, p2 gtsmodel.Poll) bool { +		// Clear populated sub-models. +		p1.Status = nil +		p2.Status = nil + +		// Localize all of the time fields. +		p1.ExpiresAt = p1.ExpiresAt.Local() +		p2.ExpiresAt = p2.ExpiresAt.Local() +		p1.ClosedAt = p1.ClosedAt.Local() +		p2.ClosedAt = p2.ClosedAt.Local() + +		// Perform the comparison. +		return suite.Equal(p1, p2) +	} + +	for _, poll := range suite.testPolls { +		for lookup, dbfunc := range map[string]func() (*gtsmodel.Poll, error){ +			"id": func() (*gtsmodel.Poll, error) { +				return suite.db.GetPollByID(ctx, poll.ID) +			}, + +			"status_id": func() (*gtsmodel.Poll, error) { +				return suite.db.GetPollByStatusID(ctx, poll.StatusID) +			}, +		} { + +			// Clear database caches. +			suite.state.Caches.Init() + +			t.Logf("checking database lookup %q", lookup) + +			// Perform database function. +			checkPoll, err := dbfunc() +			if err != nil { +				if err == sentinelErr { +					continue +				} + +				t.Errorf("error encountered for database lookup %q: %v", lookup, err) +				continue +			} + +			// Check received account data. +			if !isEqual(*checkPoll, *poll) { +				t.Errorf("poll does not contain expected data: %+v", checkPoll) +				continue +			} + +			// Check that poll source status populated. +			if poll.StatusID != (*checkPoll).Status.ID { +				t.Errorf("poll source status not correctly populated for: %+v", poll) +				continue +			} +		} +	} +} + +func (suite *PollTestSuite) TestGetPollVoteBy() { +	t := suite.T() + +	// Create a new context for this test. +	ctx, cncl := context.WithCancel(context.Background()) +	defer cncl() + +	// Sentinel error to mark avoiding a test case. +	sentinelErr := errors.New("sentinel") + +	// isEqual checks if 2 poll vote models are equal. +	isEqual := func(v1, v2 gtsmodel.PollVote) bool { +		// Clear populated sub-models. +		v1.Poll = nil +		v2.Poll = nil +		v1.Account = nil +		v2.Account = nil + +		// Localize all of the time fields. +		v1.CreatedAt = v1.CreatedAt.Local() +		v2.CreatedAt = v2.CreatedAt.Local() + +		// Perform the comparison. +		return suite.Equal(v1, v2) +	} + +	for _, vote := range suite.testPollVotes { +		for lookup, dbfunc := range map[string]func() (*gtsmodel.PollVote, error){ +			"id": func() (*gtsmodel.PollVote, error) { +				return suite.db.GetPollVoteByID(ctx, vote.ID) +			}, + +			"poll_id_account_id": func() (*gtsmodel.PollVote, error) { +				return suite.db.GetPollVoteBy(ctx, vote.PollID, vote.AccountID) +			}, +		} { + +			// Clear database caches. +			suite.state.Caches.Init() + +			t.Logf("checking database lookup %q", lookup) + +			// Perform database function. +			checkVote, err := dbfunc() +			if err != nil { +				if err == sentinelErr { +					continue +				} + +				t.Errorf("error encountered for database lookup %q: %v", lookup, err) +				continue +			} + +			// Check received account data. +			if !isEqual(*checkVote, *vote) { +				t.Errorf("poll vote does not contain expected data: %+v", checkVote) +				continue +			} + +			// Check that vote source poll populated. +			if checkVote.PollID != (*checkVote).Poll.ID { +				t.Errorf("vote source poll not correctly populated for: %+v", vote) +				continue +			} + +			// Check that vote author account populated. +			if checkVote.AccountID != (*checkVote).Account.ID { +				t.Errorf("vote author account not correctly populated for: %+v", vote) +				continue +			} +		} +	} +} + +func (suite *PollTestSuite) TestUpdatePoll() { +	// Create a new context for this test. +	ctx, cncl := context.WithCancel(context.Background()) +	defer cncl() + +	for _, poll := range suite.testPolls { +		// Take copy of poll. +		poll := util.Ptr(*poll) + +		// Update the poll closed field. +		poll.ClosedAt = time.Now() + +		// Update poll model in the database. +		err := suite.db.UpdatePoll(ctx, poll) +		suite.NoError(err) + +		// Refetch poll from database to get latest. +		latest, err := suite.db.GetPollByID(ctx, poll.ID) +		suite.NoError(err) + +		// The latest poll should have updated closedAt. +		suite.Equal(poll.ClosedAt, latest.ClosedAt) +	} +} + +func (suite *PollTestSuite) TestPutPoll() { +	// Create a new context for this test. +	ctx, cncl := context.WithCancel(context.Background()) +	defer cncl() + +	for _, poll := range suite.testPolls { +		// Delete this poll from the database. +		err := suite.db.DeletePollByID(ctx, poll.ID) +		suite.NoError(err) + +		// Ensure that afterwards we can +		// enter it again into database. +		err = suite.db.PutPoll(ctx, poll) + +		// Ensure that afterwards we can fetch poll. +		_, err = suite.db.GetPollByID(ctx, poll.ID) +		suite.NoError(err) +	} +} + +func (suite *PollTestSuite) TestPutPollVote() { +	// Create a new context for this test. +	ctx, cncl := context.WithCancel(context.Background()) +	defer cncl() + +	// randomChoices generates random vote choices in poll. +	randomChoices := func(poll *gtsmodel.Poll) []int { +		var max int +		if *poll.Multiple { +			max = len(poll.Options) +		} else { +			max = 1 +		} +		count := 1 + rand.Intn(max) +		choices := make([]int, count) +		for i := range choices { +			choices[i] = rand.Intn(len(poll.Options)) +		} +		return choices +	} + +	for _, poll := range suite.testPolls { +		// Create a new vote to insert for poll. +		vote := >smodel.PollVote{ +			ID:        id.NewULID(), +			Choices:   randomChoices(poll), +			PollID:    poll.ID, +			AccountID: id.NewULID(), // random account, doesn't matter +		} + +		// Insert this new vote into database. +		err := suite.db.PutPollVote(ctx, vote) +		suite.NoError(err) + +		// Fetch latest version of poll from database. +		latest, err := suite.db.GetPollByID(ctx, poll.ID) +		suite.NoError(err) + +		// Decr latest version choices by new vote's. +		for _, choice := range vote.Choices { +			latest.Votes[choice]-- +		} +		(*latest.Voters)-- + +		// Old poll and latest model after decr +		// should have equal vote + voter counts. +		suite.Equal(poll.Voters, latest.Voters) +		suite.Equal(poll.Votes, latest.Votes) +	} +} + +func (suite *PollTestSuite) TestDeletePoll() { +	// Create a new context for this test. +	ctx, cncl := context.WithCancel(context.Background()) +	defer cncl() + +	for _, poll := range suite.testPolls { +		// Delete this poll from the database. +		err := suite.db.DeletePollByID(ctx, poll.ID) +		suite.NoError(err) + +		// Ensure that afterwards we cannot fetch poll. +		_, err = suite.db.GetPollByID(ctx, poll.ID) +		suite.ErrorIs(err, db.ErrNoEntries) + +		// Or again by the status it's attached to. +		_, err = suite.db.GetPollByStatusID(ctx, poll.StatusID) +		suite.ErrorIs(err, db.ErrNoEntries) +	} +} + +func (suite *PollTestSuite) TestDeletePollVotes() { +	// Create a new context for this test. +	ctx, cncl := context.WithCancel(context.Background()) +	defer cncl() + +	for _, poll := range suite.testPolls { +		// Delete votes associated with poll from database. +		err := suite.db.DeletePollVotes(ctx, poll.ID) +		suite.NoError(err) + +		// Fetch latest version of poll from database. +		poll, err = suite.db.GetPollByID(ctx, poll.ID) +		suite.NoError(err) + +		// Check that poll counts are all zero. +		suite.Equal(*poll.Voters, 0) +		suite.Equal(poll.Votes, make([]int, len(poll.Options))) +	} +} + +func TestPollTestSuite(t *testing.T) { +	suite.Run(t, new(PollTestSuite)) +} diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index 822e697c1..138a5aa17 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -199,7 +199,8 @@ func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID stri  		// Follow IDs not in cache, perform DB query!  		q := newSelectFollows(r.db, accountID) -		if _, err := q.Exec(ctx, &followIDs); err != nil { +		if _, err := q.Exec(ctx, &followIDs); // nocollapse +		err != nil && !errors.Is(err, db.ErrNoEntries) {  			return nil, err  		} @@ -213,7 +214,8 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID  		// Follow IDs not in cache, perform DB query!  		q := newSelectLocalFollows(r.db, accountID) -		if _, err := q.Exec(ctx, &followIDs); err != nil { +		if _, err := q.Exec(ctx, &followIDs); // nocollapse +		err != nil && !errors.Is(err, db.ErrNoEntries) {  			return nil, err  		} @@ -227,7 +229,8 @@ func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID st  		// Follow IDs not in cache, perform DB query!  		q := newSelectFollowers(r.db, accountID) -		if _, err := q.Exec(ctx, &followIDs); err != nil { +		if _, err := q.Exec(ctx, &followIDs); // nocollapse +		err != nil && !errors.Is(err, db.ErrNoEntries) {  			return nil, err  		} @@ -241,7 +244,8 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account  		// Follow IDs not in cache, perform DB query!  		q := newSelectLocalFollowers(r.db, accountID) -		if _, err := q.Exec(ctx, &followIDs); err != nil { +		if _, err := q.Exec(ctx, &followIDs); // nocollapse +		err != nil && !errors.Is(err, db.ErrNoEntries) {  			return nil, err  		} @@ -255,7 +259,8 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account  		// Follow request IDs not in cache, perform DB query!  		q := newSelectFollowRequests(r.db, accountID) -		if _, err := q.Exec(ctx, &followReqIDs); err != nil { +		if _, err := q.Exec(ctx, &followReqIDs); // nocollapse +		err != nil && !errors.Is(err, db.ErrNoEntries) {  			return nil, err  		} @@ -269,7 +274,8 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco  		// Follow request IDs not in cache, perform DB query!  		q := newSelectFollowRequesting(r.db, accountID) -		if _, err := q.Exec(ctx, &followReqIDs); err != nil { +		if _, err := q.Exec(ctx, &followReqIDs); // nocollapse +		err != nil && !errors.Is(err, db.ErrNoEntries) {  			return nil, err  		} @@ -283,7 +289,8 @@ func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID strin  		// Block IDs not in cache, perform DB query!  		q := newSelectBlocks(r.db, accountID) -		if _, err := q.Exec(ctx, &blockIDs); err != nil { +		if _, err := q.Exec(ctx, &blockIDs); // nocollapse +		err != nil && !errors.Is(err, db.ErrNoEntries) {  			return nil, err  		} diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index 0bd4ba1a9..7f274d693 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -154,17 +154,6 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)  		}  	} -	if status.InReplyToID != "" && status.InReplyTo == nil { -		// Status parent is not set, fetch from database. -		status.InReplyTo, err = s.GetStatusByID( -			gtscontext.SetBarebones(ctx), -			status.InReplyToID, -		) -		if err != nil { -			errs.Appendf("error populating status parent: %w", err) -		} -	} -  	if status.InReplyToID != "" {  		if status.InReplyTo == nil {  			// Status parent is not set, fetch from database. @@ -213,6 +202,17 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)  		}  	} +	if status.PollID != "" && status.Poll == nil { +		// Status poll is not set, fetch from database. +		status.Poll, err = s.state.DB.GetPollByID( +			gtscontext.SetBarebones(ctx), +			status.PollID, +		) +		if err != nil { +			errs.Appendf("error populating status poll: %w", err) +		} +	} +  	if !status.AttachmentsPopulated() {  		// Status attachments are out-of-date with IDs, repopulate.  		status.Attachments, err = s.state.DB.GetAttachmentsByIDs( diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go index ac169ec4a..b3ce91755 100644 --- a/internal/db/bundb/timeline_test.go +++ b/internal/db/bundb/timeline_test.go @@ -22,6 +22,7 @@ import (  	"testing"  	"time" +	"codeberg.org/gruf/go-kv"  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/ap"  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" @@ -73,20 +74,18 @@ func getFutureStatus() *gtsmodel.Status {  func (suite *TimelineTestSuite) publicCount() int {  	var publicCount int -  	for _, status := range suite.testStatuses {  		if status.Visibility == gtsmodel.VisibilityPublic &&  			status.BoostOfID == "" {  			publicCount++  		}  	} -  	return publicCount  }  func (suite *TimelineTestSuite) checkStatuses(statuses []*gtsmodel.Status, maxID string, minID string, expectedLength int) {  	if l := len(statuses); l != expectedLength { -		suite.FailNow("", "expected %d statuses in slice, got %d", expectedLength, l) +		suite.FailNowf("", "expected %d statuses in slice, got %d", expectedLength, l)  	} else if l == 0 {  		// Can't test empty slice.  		return @@ -98,15 +97,15 @@ func (suite *TimelineTestSuite) checkStatuses(statuses []*gtsmodel.Status, maxID  		id := status.ID  		if id >= maxID { -			suite.FailNow("", "%s greater than maxID %s", id, maxID) +			suite.FailNowf("", "%s greater than maxID %s", id, maxID)  		}  		if id <= minID { -			suite.FailNow("", "%s smaller than minID %s", id, minID) +			suite.FailNowf("", "%s smaller than minID %s", id, minID)  		}  		if id > highest { -			suite.FailNow("", "statuses in slice were not ordered highest -> lowest ID") +			suite.FailNowf("", "statuses in slice were not ordered highest -> lowest ID")  		}  		highest = id @@ -121,6 +120,10 @@ func (suite *TimelineTestSuite) TestGetPublicTimeline() {  		suite.FailNow(err.Error())  	} +	suite.T().Log(kv.Field{ +		K: "statuses", V: s, +	}) +  	suite.checkStatuses(s, id.Highest, id.Lowest, suite.publicCount())  } @@ -154,7 +157,7 @@ func (suite *TimelineTestSuite) TestGetHomeTimeline() {  		suite.FailNow(err.Error())  	} -	suite.checkStatuses(s, id.Highest, id.Lowest, 16) +	suite.checkStatuses(s, id.Highest, id.Lowest, 18)  }  func (suite *TimelineTestSuite) TestGetHomeTimelineNoFollowing() { @@ -186,7 +189,7 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineNoFollowing() {  		suite.FailNow(err.Error())  	} -	suite.checkStatuses(s, id.Highest, id.Lowest, 5) +	suite.checkStatuses(s, id.Highest, id.Lowest, 6)  }  func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() { @@ -208,7 +211,7 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() {  	}  	suite.NotContains(s, futureStatus) -	suite.checkStatuses(s, id.Highest, id.Lowest, 16) +	suite.checkStatuses(s, id.Highest, id.Lowest, 18)  }  func (suite *TimelineTestSuite) TestGetHomeTimelineBackToFront() { @@ -239,8 +242,8 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineFromHighest() {  	}  	suite.checkStatuses(s, id.Highest, id.Lowest, 5) -	suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID) -	suite.Equal("01FCTA44PW9H1TB328S9AQXKDS", s[len(s)-1].ID) +	suite.Equal("01HEN2RZ8BG29Y5Z9VJC73HZW7", s[0].ID) +	suite.Equal("01FN3VJGFH10KR7S2PB0GFJZYG", s[len(s)-1].ID)  }  func (suite *TimelineTestSuite) TestGetListTimelineNoParams() { @@ -254,7 +257,7 @@ func (suite *TimelineTestSuite) TestGetListTimelineNoParams() {  		suite.FailNow(err.Error())  	} -	suite.checkStatuses(s, id.Highest, id.Lowest, 11) +	suite.checkStatuses(s, id.Highest, id.Lowest, 12)  }  func (suite *TimelineTestSuite) TestGetListTimelineMaxID() { @@ -269,8 +272,8 @@ func (suite *TimelineTestSuite) TestGetListTimelineMaxID() {  	}  	suite.checkStatuses(s, id.Highest, id.Lowest, 5) -	suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID) -	suite.Equal("01FCQSQ667XHJ9AV9T27SJJSX5", s[len(s)-1].ID) +	suite.Equal("01HEN2PRXT0TF4YDRA64FZZRN7", s[0].ID) +	suite.Equal("01FF25D5Q0DH7CHD57CTRS6WK0", s[len(s)-1].ID)  }  func (suite *TimelineTestSuite) TestGetListTimelineMinID() { diff --git a/internal/db/db.go b/internal/db/db.go index 41b253834..2914d9b59 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -36,6 +36,7 @@ type DB interface {  	Media  	Mention  	Notification +	Poll  	Relationship  	Report  	Rule diff --git a/internal/db/mention.go b/internal/db/mention.go index d4125031e..994ec04b5 100644 --- a/internal/db/mention.go +++ b/internal/db/mention.go @@ -31,6 +31,9 @@ type Mention interface {  	// GetMentions gets multiple mentions.  	GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, error) +	// PopulateMention ensures that all sub-models of a mention are populated (e.g. accounts). +	PopulateMention(ctx context.Context, mention *gtsmodel.Mention) error +  	// PutMention will insert the given mention into the database.  	PutMention(ctx context.Context, mention *gtsmodel.Mention) error diff --git a/internal/db/poll.go b/internal/db/poll.go new file mode 100644 index 000000000..b59d27c73 --- /dev/null +++ b/internal/db/poll.go @@ -0,0 +1,71 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package db + +import ( +	"context" + +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type Poll interface { +	// GetPollByID fetches the Poll with given ID from the database. +	GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, error) + +	// GetPollByStatusID fetches the Poll with given status ID column value from the database. +	GetPollByStatusID(ctx context.Context, statusID string) (*gtsmodel.Poll, error) + +	// GetOpenPolls fetches all local Polls in the database with an unset `closed_at` column. +	GetOpenPolls(ctx context.Context) ([]*gtsmodel.Poll, error) + +	// PopulatePoll ensures the given Poll is fully populated with all other related database models. +	PopulatePoll(ctx context.Context, poll *gtsmodel.Poll) error + +	// PutPoll puts the given Poll in the database. +	PutPoll(ctx context.Context, poll *gtsmodel.Poll) error + +	// UpdatePoll updates the Poll in the database, only on selected columns if provided (else, all). +	UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...string) error + +	// DeletePollByID deletes the Poll with given ID from the database. +	DeletePollByID(ctx context.Context, id string) error + +	// GetPollVoteByID gets the PollVote with given ID from the database. +	GetPollVoteByID(ctx context.Context, id string) (*gtsmodel.PollVote, error) + +	// GetPollVotesBy fetches the PollVote in Poll with ID, by account ID, from the database. +	GetPollVoteBy(ctx context.Context, pollID string, accountID string) (*gtsmodel.PollVote, error) + +	// GetPollVotes fetches all PollVotes in Poll with ID, from the database. +	GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.PollVote, error) + +	// PopulatePollVote ensures the given PollVote is fully populated with all other related database models. +	PopulatePollVote(ctx context.Context, votes *gtsmodel.PollVote) error + +	// PutPollVote puts the given PollVote in the database. +	PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error + +	// DeletePollVotes deletes all PollVotes in Poll with given ID from the database. +	DeletePollVotes(ctx context.Context, pollID string) error + +	// DeletePollVoteBy deletes the PollVote in Poll with ID, by account ID, from the database. +	DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error + +	// DeletePollVotesByAccountID deletes all PollVotes in all Polls, by account ID, from the database. +	DeletePollVotesByAccountID(ctx context.Context, accountID string) error +} | 
