summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
authorLibravatar kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com>2023-11-08 14:32:17 +0000
committerLibravatar GitHub <noreply@github.com>2023-11-08 14:32:17 +0000
commite9e5dc5a40926e5320cb131b035c46b1e1b0bd59 (patch)
tree52edc9fa5742f28e1e5223f51cda628ec1c35a24 /internal/db
parent[chore]: Bump github.com/spf13/cobra from 1.7.0 to 1.8.0 (#2338) (diff)
downloadgotosocial-e9e5dc5a40926e5320cb131b035c46b1e1b0bd59.tar.xz
[feature] add support for polls + receiving federated status edits (#2330)
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/bundb/account_test.go8
-rw-r--r--internal/db/bundb/basic_test.go2
-rw-r--r--internal/db/bundb/bundb.go5
-rw-r--r--internal/db/bundb/bundb_test.go4
-rw-r--r--internal/db/bundb/instance_test.go4
-rw-r--r--internal/db/bundb/mention.go69
-rw-r--r--internal/db/bundb/migrations/20231002153327_add_status_polls.go65
-rw-r--r--internal/db/bundb/poll.go536
-rw-r--r--internal/db/bundb/poll_test.go318
-rw-r--r--internal/db/bundb/relationship.go21
-rw-r--r--internal/db/bundb/status.go22
-rw-r--r--internal/db/bundb/timeline_test.go31
-rw-r--r--internal/db/db.go1
-rw-r--r--internal/db/mention.go3
-rw-r--r--internal/db/poll.go71
15 files changed, 1095 insertions, 65 deletions
diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go
index b410bb3ed..8c2de5519 100644
--- a/internal/db/bundb/account_test.go
+++ b/internal/db/bundb/account_test.go
@@ -42,7 +42,7 @@ type AccountTestSuite struct {
func (suite *AccountTestSuite) TestGetAccountStatuses() {
statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, false, false, "", "", false, false)
suite.NoError(err)
- suite.Len(statuses, 5)
+ suite.Len(statuses, 6)
}
func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() {
@@ -65,7 +65,7 @@ func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() {
if err != nil {
suite.FailNow(err.Error())
}
- suite.Len(statuses, 1)
+ suite.Len(statuses, 2)
// try to get the last page (should be empty)
statuses, err = suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, statuses[len(statuses)-1].ID, "", false, false)
@@ -76,7 +76,7 @@ func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() {
func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogs() {
statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false)
suite.NoError(err)
- suite.Len(statuses, 5)
+ suite.Len(statuses, 6)
}
func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogsPublicOnly() {
@@ -306,7 +306,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {
func (suite *AccountTestSuite) TestGetAccountLastPosted() {
lastPosted, err := suite.db.GetAccountLastPosted(context.Background(), suite.testAccounts["local_account_1"].ID, false)
suite.NoError(err)
- suite.EqualValues(1653046675, lastPosted.Unix())
+ suite.EqualValues(1653046870, lastPosted.Unix())
}
func (suite *AccountTestSuite) TestGetAccountLastPostedWebOnly() {
diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go
index a24deac9e..cef0617b7 100644
--- a/internal/db/bundb/basic_test.go
+++ b/internal/db/bundb/basic_test.go
@@ -121,7 +121,7 @@ func (suite *BasicTestSuite) TestGetAllStatuses() {
s := []*gtsmodel.Status{}
err := suite.db.GetAll(context.Background(), &s)
suite.NoError(err)
- suite.Len(s, 17)
+ suite.Len(s, 20)
}
func (suite *BasicTestSuite) TestGetAllNotNull() {
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index 393f32eec..a86a20274 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -71,6 +71,7 @@ type DBService struct {
db.Media
db.Mention
db.Notification
+ db.Poll
db.Relationship
db.Report
db.Rule
@@ -203,6 +204,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db,
state: state,
},
+ Poll: &pollDB{
+ db: db,
+ state: state,
+ },
Relationship: &relationshipDB{
db: db,
state: state,
diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go
index 8245937b9..037727090 100644
--- a/internal/db/bundb/bundb_test.go
+++ b/internal/db/bundb/bundb_test.go
@@ -54,6 +54,8 @@ type BunDBStandardTestSuite struct {
testMarkers map[string]*gtsmodel.Marker
testRules map[string]*gtsmodel.Rule
testThreads map[string]*gtsmodel.Thread
+ testPolls map[string]*gtsmodel.Poll
+ testPollVotes map[string]*gtsmodel.PollVote
}
func (suite *BunDBStandardTestSuite) SetupSuite() {
@@ -77,6 +79,8 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {
suite.testMarkers = testrig.NewTestMarkers()
suite.testRules = testrig.NewTestRules()
suite.testThreads = testrig.NewTestThreads()
+ suite.testPolls = testrig.NewTestPolls()
+ suite.testPollVotes = testrig.NewTestPollVotes()
}
func (suite *BunDBStandardTestSuite) SetupTest() {
diff --git a/internal/db/bundb/instance_test.go b/internal/db/bundb/instance_test.go
index a825a3341..d88825a33 100644
--- a/internal/db/bundb/instance_test.go
+++ b/internal/db/bundb/instance_test.go
@@ -47,13 +47,13 @@ func (suite *InstanceTestSuite) TestCountInstanceUsersRemote() {
func (suite *InstanceTestSuite) TestCountInstanceStatuses() {
count, err := suite.db.CountInstanceStatuses(context.Background(), config.GetHost())
suite.NoError(err)
- suite.Equal(16, count)
+ suite.Equal(18, count)
}
func (suite *InstanceTestSuite) TestCountInstanceStatusesRemote() {
count, err := suite.db.CountInstanceStatuses(context.Background(), "fossbros-anonymous.io")
suite.NoError(err)
- suite.Equal(1, count)
+ suite.Equal(2, count)
}
func (suite *InstanceTestSuite) TestCountInstanceDomains() {
diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go
index 547d8d0a8..30a20b0c1 100644
--- a/internal/db/bundb/mention.go
+++ b/internal/db/bundb/mention.go
@@ -20,10 +20,10 @@ package bundb
import (
"context"
"errors"
- "fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@@ -54,31 +54,9 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio
return nil, err
}
- // Set the mention originating status.
- mention.Status, err = m.state.DB.GetStatusByID(
- gtscontext.SetBarebones(ctx),
- mention.StatusID,
- )
- if err != nil {
- return nil, fmt.Errorf("error populating mention status: %w", err)
- }
-
- // Set the mention origin account model.
- mention.OriginAccount, err = m.state.DB.GetAccountByID(
- gtscontext.SetBarebones(ctx),
- mention.OriginAccountID,
- )
- if err != nil {
- return nil, fmt.Errorf("error populating mention origin account: %w", err)
- }
-
- // Set the mention target account model.
- mention.TargetAccount, err = m.state.DB.GetAccountByID(
- gtscontext.SetBarebones(ctx),
- mention.TargetAccountID,
- )
- if err != nil {
- return nil, fmt.Errorf("error populating mention target account: %w", err)
+ // Further populate the mention fields where applicable.
+ if err := m.PopulateMention(ctx, mention); err != nil {
+ return nil, err
}
return mention, nil
@@ -102,6 +80,45 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.
return mentions, nil
}
+func (m *mentionDB) PopulateMention(ctx context.Context, mention *gtsmodel.Mention) (err error) {
+ var errs gtserror.MultiError
+
+ if mention.Status == nil {
+ // Set the mention originating status.
+ mention.Status, err = m.state.DB.GetStatusByID(
+ gtscontext.SetBarebones(ctx),
+ mention.StatusID,
+ )
+ if err != nil {
+ return gtserror.Newf("error populating mention status: %w", err)
+ }
+ }
+
+ if mention.OriginAccount == nil {
+ // Set the mention origin account model.
+ mention.OriginAccount, err = m.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ mention.OriginAccountID,
+ )
+ if err != nil {
+ return gtserror.Newf("error populating mention origin account: %w", err)
+ }
+ }
+
+ if mention.TargetAccount == nil {
+ // Set the mention target account model.
+ mention.TargetAccount, err = m.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ mention.TargetAccountID,
+ )
+ if err != nil {
+ return gtserror.Newf("error populating mention target account: %w", err)
+ }
+ }
+
+ return errs.Combine()
+}
+
func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error {
return m.state.Caches.GTS.Mention().Store(mention, func() error {
_, err := m.db.NewInsert().Model(mention).Exec(ctx)
diff --git a/internal/db/bundb/migrations/20231002153327_add_status_polls.go b/internal/db/bundb/migrations/20231002153327_add_status_polls.go
new file mode 100644
index 000000000..5e525cc27
--- /dev/null
+++ b/internal/db/bundb/migrations/20231002153327_add_status_polls.go
@@ -0,0 +1,65 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package migrations
+
+import (
+ "context"
+ "strings"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+func init() {
+ up := func(ctx context.Context, db *bun.DB) error {
+ // Create `polls` + `poll_votes` tables.
+ for _, model := range []any{
+ &gtsmodel.Poll{},
+ &gtsmodel.PollVote{},
+ } {
+ _, err := db.NewCreateTable().
+ IfNotExists().
+ Model(model).
+ Exec(ctx)
+ if err != nil && !(strings.Contains(err.Error(), "already exists") || strings.Contains(err.Error(), "duplicate column name") || strings.Contains(err.Error(), "SQLSTATE 42701")) {
+ return err
+ }
+ }
+
+ // Add the new status `poll_id` column.
+ _, err := db.NewAddColumn().
+ Model(&gtsmodel.Status{}).
+ ColumnExpr("? CHAR(26)", bun.Ident("poll_id")).
+ Exec(ctx)
+ if err != nil && !(strings.Contains(err.Error(), "already exists") || strings.Contains(err.Error(), "duplicate column name") || strings.Contains(err.Error(), "SQLSTATE 42701")) {
+ return err
+ }
+
+ return nil
+ }
+
+ down := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ return nil
+ })
+ }
+
+ if err := Migrations.Register(up, down); err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/db/bundb/poll.go b/internal/db/bundb/poll.go
new file mode 100644
index 000000000..84f160987
--- /dev/null
+++ b/internal/db/bundb/poll.go
@@ -0,0 +1,536 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb
+
+import (
+ "context"
+ "errors"
+ "time"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/uptrace/bun"
+)
+
+type pollDB struct {
+ db *DB
+ state *state.State
+}
+
+func (p *pollDB) GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, error) {
+ return p.getPoll(
+ ctx,
+ "ID",
+ func(poll *gtsmodel.Poll) error {
+ return p.db.NewSelect().
+ Model(poll).
+ Where("? = ?", bun.Ident("poll.id"), id).
+ Scan(ctx)
+ },
+ id,
+ )
+}
+
+func (p *pollDB) GetPollByStatusID(ctx context.Context, statusID string) (*gtsmodel.Poll, error) {
+ return p.getPoll(
+ ctx,
+ "StatusID",
+ func(poll *gtsmodel.Poll) error {
+ return p.db.NewSelect().
+ Model(poll).
+ Where("? = ?", bun.Ident("poll.status_id"), statusID).
+ Scan(ctx)
+ },
+ statusID,
+ )
+}
+
+func (p *pollDB) getPoll(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Poll) error, keyParts ...any) (*gtsmodel.Poll, error) {
+ // Fetch poll from database cache with loader callback
+ poll, err := p.state.Caches.GTS.Poll().Load(lookup, func() (*gtsmodel.Poll, error) {
+ var poll gtsmodel.Poll
+
+ // Not cached! Perform database query.
+ if err := dbQuery(&poll); err != nil {
+ return nil, err
+ }
+
+ // Ensure vote slice
+ // is non nil and set.
+ poll.CheckVotes()
+
+ return &poll, nil
+ }, keyParts...)
+ if err != nil {
+ return nil, err
+ }
+
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return poll, nil
+ }
+
+ // Further populate the poll fields where applicable.
+ if err := p.PopulatePoll(ctx, poll); err != nil {
+ return nil, err
+ }
+
+ return poll, nil
+}
+
+func (p *pollDB) GetOpenPolls(ctx context.Context) ([]*gtsmodel.Poll, error) {
+ var pollIDs []string
+
+ // Select all polls with unset `closed_at` time.
+ if err := p.db.NewSelect().
+ Table("polls").
+ Column("polls.id").
+ Join("JOIN ? ON ? = ?", bun.Ident("statuses"), bun.Ident("polls.id"), bun.Ident("statuses.poll_id")).
+ Where("? = true", bun.Ident("statuses.local")).
+ Where("? IS NULL", bun.Ident("polls.closed_at")).
+ Scan(ctx, &pollIDs); err != nil {
+ return nil, err
+ }
+
+ // Preallocate a slice to contain the poll models.
+ polls := make([]*gtsmodel.Poll, 0, len(pollIDs))
+
+ for _, id := range pollIDs {
+ // Attempt to fetch poll from DB.
+ poll, err := p.GetPollByID(ctx, id)
+ if err != nil {
+ log.Errorf(ctx, "error getting poll %s: %v", id, err)
+ continue
+ }
+
+ // Append poll to return slice.
+ polls = append(polls, poll)
+ }
+
+ return polls, nil
+}
+
+func (p *pollDB) PopulatePoll(ctx context.Context, poll *gtsmodel.Poll) error {
+ var (
+ err error
+ errs gtserror.MultiError
+ )
+
+ if poll.Status == nil {
+ // Vote account is not set, fetch from database.
+ poll.Status, err = p.state.DB.GetStatusByID(
+ gtscontext.SetBarebones(ctx),
+ poll.StatusID,
+ )
+ if err != nil {
+ errs.Appendf("error populating poll status: %w", err)
+ }
+ }
+
+ return errs.Combine()
+}
+
+func (p *pollDB) PutPoll(ctx context.Context, poll *gtsmodel.Poll) error {
+ // Ensure vote slice
+ // is non nil and set.
+ poll.CheckVotes()
+
+ return p.state.Caches.GTS.Poll().Store(poll, func() error {
+ _, err := p.db.NewInsert().Model(poll).Exec(ctx)
+ return err
+ })
+}
+
+func (p *pollDB) UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...string) error {
+ // Ensure vote slice
+ // is non nil and set.
+ poll.CheckVotes()
+
+ return p.state.Caches.GTS.Poll().Store(poll, func() error {
+ return p.db.RunInTx(ctx, func(tx Tx) error {
+ // Update the status' "updated_at" field.
+ if _, err := tx.NewUpdate().
+ Table("statuses").
+ Where("? = ?", bun.Ident("id"), poll.StatusID).
+ SetColumn("updated_at", "?", time.Now()).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Finally, update poll
+ // columns in database.
+ _, err := tx.NewUpdate().
+ Model(poll).
+ Column(cols...).
+ Where("? = ?", bun.Ident("id"), poll.ID).
+ Exec(ctx)
+ return err
+ })
+ })
+}
+
+func (p *pollDB) DeletePollByID(ctx context.Context, id string) error {
+ // Delete poll by ID from database.
+ if _, err := p.db.NewDelete().
+ Table("polls").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Invalidate poll by ID from cache.
+ p.state.Caches.GTS.Poll().Invalidate("ID", id)
+ p.state.Caches.GTS.PollVoteIDs().Invalidate(id)
+
+ return nil
+}
+
+func (p *pollDB) GetPollVoteByID(ctx context.Context, id string) (*gtsmodel.PollVote, error) {
+ return p.getPollVote(
+ ctx,
+ "ID",
+ func(vote *gtsmodel.PollVote) error {
+ return p.db.NewSelect().
+ Model(vote).
+ Where("? = ?", bun.Ident("poll_vote.id"), id).
+ Scan(ctx)
+ },
+ id,
+ )
+}
+
+func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID string) (*gtsmodel.PollVote, error) {
+ return p.getPollVote(
+ ctx,
+ "PollID.AccountID",
+ func(vote *gtsmodel.PollVote) error {
+ return p.db.NewSelect().
+ Model(vote).
+ Where("? = ?", bun.Ident("poll_vote.account_id"), accountID).
+ Where("? = ?", bun.Ident("poll_vote.poll_id"), pollID).
+ Scan(ctx)
+ },
+ pollID,
+ accountID,
+ )
+}
+
+func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.PollVote) error, keyParts ...any) (*gtsmodel.PollVote, error) {
+ // Fetch vote from database cache with loader callback
+ vote, err := p.state.Caches.GTS.PollVote().Load(lookup, func() (*gtsmodel.PollVote, error) {
+ var vote gtsmodel.PollVote
+
+ // Not cached! Perform database query.
+ if err := dbQuery(&vote); err != nil {
+ return nil, err
+ }
+
+ return &vote, nil
+ }, keyParts...)
+ if err != nil {
+ return nil, err
+ }
+
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return vote, nil
+ }
+
+ // Further populate the vote fields where applicable.
+ if err := p.PopulatePollVote(ctx, vote); err != nil {
+ return nil, err
+ }
+
+ return vote, nil
+}
+
+func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.PollVote, error) {
+ voteIDs, err := p.state.Caches.GTS.PollVoteIDs().Load(pollID, func() ([]string, error) {
+ var voteIDs []string
+
+ // Vote IDs not in cache, perform DB query!
+ q := newSelectPollVotes(p.db, pollID)
+ if _, err := q.Exec(ctx, &voteIDs); // nocollapse
+ err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return nil, err
+ }
+
+ return voteIDs, nil
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Preallocate slice of expected length.
+ votes := make([]*gtsmodel.PollVote, 0, len(voteIDs))
+
+ for _, id := range voteIDs {
+ // Fetch poll vote model for this ID.
+ vote, err := p.GetPollVoteByID(ctx, id)
+ if err != nil {
+ log.Errorf(ctx, "error getting poll vote %s: %v", id, err)
+ continue
+ }
+
+ // Append to return slice.
+ votes = append(votes, vote)
+ }
+
+ return votes, nil
+}
+
+func (p *pollDB) PopulatePollVote(ctx context.Context, vote *gtsmodel.PollVote) error {
+ var (
+ err error
+ errs gtserror.MultiError
+ )
+
+ if vote.Account == nil {
+ // Vote account is not set, fetch from database.
+ vote.Account, err = p.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ vote.AccountID,
+ )
+ if err != nil {
+ errs.Appendf("error populating vote account: %w", err)
+ }
+ }
+
+ if vote.Poll == nil {
+ // Vote poll is not set, fetch from database.
+ vote.Poll, err = p.GetPollByID(
+ gtscontext.SetBarebones(ctx),
+ vote.PollID,
+ )
+ if err != nil {
+ errs.Appendf("error populating vote poll: %w", err)
+ }
+ }
+
+ return errs.Combine()
+}
+
+func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error {
+ return p.state.Caches.GTS.PollVote().Store(vote, func() error {
+ return p.db.RunInTx(ctx, func(tx Tx) error {
+ // Try insert vote into database.
+ if _, err := tx.NewInsert().
+ Model(vote).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ var poll gtsmodel.Poll
+
+ // Select poll counts from DB.
+ if err := tx.NewSelect().
+ Model(&poll).
+ Where("? = ?", bun.Ident("id"), vote.PollID).
+ Scan(ctx); err != nil {
+ return err
+ }
+
+ // Increment poll votes for choices.
+ poll.IncrementVotes(vote.Choices)
+
+ // Finally, update the poll entry.
+ _, err := tx.NewUpdate().
+ Model(&poll).
+ Column("votes", "voters").
+ Where("? = ?", bun.Ident("id"), vote.PollID).
+ Exec(ctx)
+ return err
+ })
+ })
+}
+
+func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
+ err := p.db.RunInTx(ctx, func(tx Tx) error {
+ // Delete all vote in poll,
+ // returning all vote choices.
+ switch _, err := tx.NewDelete().
+ Table("poll_votes").
+ Where("? = ?", bun.Ident("poll_id"), pollID).
+ Exec(ctx); {
+
+ case err == nil:
+ // no issue.
+
+ case errors.Is(err, db.ErrNoEntries):
+ // no votes found,
+ // return here.
+ return nil
+
+ default:
+ // irrecoverable.
+ return err
+ }
+
+ var poll gtsmodel.Poll
+
+ // Select poll counts from DB.
+ switch err := tx.NewSelect().
+ Model(&poll).
+ Where("? = ?", bun.Ident("id"), pollID).
+ Scan(ctx); {
+
+ case err == nil:
+ // no issue.
+
+ case errors.Is(err, db.ErrNoEntries):
+ // no votes found,
+ // return here.
+ return nil
+
+ default:
+ // irrecoverable.
+ return err
+ }
+
+ // Zero all counts.
+ poll.ResetVotes()
+
+ // Finally, update the poll entry.
+ _, err := tx.NewUpdate().
+ Model(&poll).
+ Column("votes", "voters").
+ Where("? = ?", bun.Ident("id"), pollID).
+ Exec(ctx)
+ return err
+ })
+
+ if err != nil {
+ return err
+ }
+
+ // Invalidate poll vote and poll entry from caches.
+ p.state.Caches.GTS.Poll().Invalidate("ID", pollID)
+ p.state.Caches.GTS.PollVote().Invalidate("PollID", pollID)
+ p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID)
+
+ return nil
+}
+
+func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error {
+ err := p.db.RunInTx(ctx, func(tx Tx) error {
+ var choices []int
+
+ // Delete vote in poll by account,
+ // returning the ID + choices of the vote.
+ switch err := tx.NewDelete().
+ Table("poll_votes").
+ Where("? = ?", bun.Ident("poll_id"), pollID).
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Returning("choices").
+ Scan(ctx, &choices); {
+
+ case err == nil:
+ // no issue.
+
+ case errors.Is(err, db.ErrNoEntries):
+ // no votes found,
+ // return here.
+ return nil
+
+ default:
+ // irrecoverable.
+ return err
+ }
+
+ var poll gtsmodel.Poll
+
+ // Select poll counts from DB.
+ switch err := tx.NewSelect().
+ Model(&poll).
+ Where("? = ?", bun.Ident("id"), pollID).
+ Scan(ctx); {
+
+ case err == nil:
+ // no issue.
+
+ case errors.Is(err, db.ErrNoEntries):
+ // no votes found,
+ // return here.
+ return nil
+
+ default:
+ // irrecoverable.
+ return err
+ }
+
+ // Decrement votes for choices.
+ poll.IncrementVotes(choices)
+
+ // Finally, update the poll entry.
+ _, err := tx.NewUpdate().
+ Model(&poll).
+ Column("votes", "voters").
+ Where("? = ?", bun.Ident("id"), pollID).
+ Exec(ctx)
+ return err
+ })
+
+ if err != nil {
+ return err
+ }
+
+ // Invalidate poll vote and poll entry from caches.
+ p.state.Caches.GTS.Poll().Invalidate("ID", pollID)
+ p.state.Caches.GTS.PollVote().Invalidate("PollID.AccountID", pollID, accountID)
+ p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID)
+
+ return nil
+}
+
+func (p *pollDB) DeletePollVotesByAccountID(ctx context.Context, accountID string) error {
+ var pollIDs []string
+
+ // Select all polls this account
+ // has registered a poll vote in.
+ if err := p.db.NewSelect().
+ Table("poll_votes").
+ Column("poll_id").
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Scan(ctx, &pollIDs); err != nil &&
+ !errors.Is(err, db.ErrNoEntries) {
+ return err
+ }
+
+ for _, id := range pollIDs {
+ // Delete all votes by this account in each of the polls,
+ // this way ensures that all necessary caches are invalidated.
+ if err := p.DeletePollVoteBy(ctx, id, accountID); err != nil {
+ log.Errorf(ctx, "error deleting vote by %s in %s: %v", accountID, id, err)
+ }
+ }
+
+ return nil
+}
+
+// newSelectPollVotes returns a new select query for all rows in the poll_votes table with poll_id = pollID.
+func newSelectPollVotes(db *DB, pollID string) *bun.SelectQuery {
+ return db.NewSelect().
+ TableExpr("?", bun.Ident("poll_votes")).
+ ColumnExpr("?", bun.Ident("id")).
+ Where("? = ?", bun.Ident("poll_id"), pollID).
+ OrderExpr("? DESC", bun.Ident("id"))
+}
diff --git a/internal/db/bundb/poll_test.go b/internal/db/bundb/poll_test.go
new file mode 100644
index 000000000..53da2514b
--- /dev/null
+++ b/internal/db/bundb/poll_test.go
@@ -0,0 +1,318 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb_test
+
+import (
+ "context"
+ "errors"
+ "math/rand"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
+)
+
+type PollTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *PollTestSuite) TestGetPollBy() {
+ t := suite.T()
+
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ // Sentinel error to mark avoiding a test case.
+ sentinelErr := errors.New("sentinel")
+
+ // isEqual checks if 2 poll models are equal.
+ isEqual := func(p1, p2 gtsmodel.Poll) bool {
+ // Clear populated sub-models.
+ p1.Status = nil
+ p2.Status = nil
+
+ // Localize all of the time fields.
+ p1.ExpiresAt = p1.ExpiresAt.Local()
+ p2.ExpiresAt = p2.ExpiresAt.Local()
+ p1.ClosedAt = p1.ClosedAt.Local()
+ p2.ClosedAt = p2.ClosedAt.Local()
+
+ // Perform the comparison.
+ return suite.Equal(p1, p2)
+ }
+
+ for _, poll := range suite.testPolls {
+ for lookup, dbfunc := range map[string]func() (*gtsmodel.Poll, error){
+ "id": func() (*gtsmodel.Poll, error) {
+ return suite.db.GetPollByID(ctx, poll.ID)
+ },
+
+ "status_id": func() (*gtsmodel.Poll, error) {
+ return suite.db.GetPollByStatusID(ctx, poll.StatusID)
+ },
+ } {
+
+ // Clear database caches.
+ suite.state.Caches.Init()
+
+ t.Logf("checking database lookup %q", lookup)
+
+ // Perform database function.
+ checkPoll, err := dbfunc()
+ if err != nil {
+ if err == sentinelErr {
+ continue
+ }
+
+ t.Errorf("error encountered for database lookup %q: %v", lookup, err)
+ continue
+ }
+
+ // Check received account data.
+ if !isEqual(*checkPoll, *poll) {
+ t.Errorf("poll does not contain expected data: %+v", checkPoll)
+ continue
+ }
+
+ // Check that poll source status populated.
+ if poll.StatusID != (*checkPoll).Status.ID {
+ t.Errorf("poll source status not correctly populated for: %+v", poll)
+ continue
+ }
+ }
+ }
+}
+
+func (suite *PollTestSuite) TestGetPollVoteBy() {
+ t := suite.T()
+
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ // Sentinel error to mark avoiding a test case.
+ sentinelErr := errors.New("sentinel")
+
+ // isEqual checks if 2 poll vote models are equal.
+ isEqual := func(v1, v2 gtsmodel.PollVote) bool {
+ // Clear populated sub-models.
+ v1.Poll = nil
+ v2.Poll = nil
+ v1.Account = nil
+ v2.Account = nil
+
+ // Localize all of the time fields.
+ v1.CreatedAt = v1.CreatedAt.Local()
+ v2.CreatedAt = v2.CreatedAt.Local()
+
+ // Perform the comparison.
+ return suite.Equal(v1, v2)
+ }
+
+ for _, vote := range suite.testPollVotes {
+ for lookup, dbfunc := range map[string]func() (*gtsmodel.PollVote, error){
+ "id": func() (*gtsmodel.PollVote, error) {
+ return suite.db.GetPollVoteByID(ctx, vote.ID)
+ },
+
+ "poll_id_account_id": func() (*gtsmodel.PollVote, error) {
+ return suite.db.GetPollVoteBy(ctx, vote.PollID, vote.AccountID)
+ },
+ } {
+
+ // Clear database caches.
+ suite.state.Caches.Init()
+
+ t.Logf("checking database lookup %q", lookup)
+
+ // Perform database function.
+ checkVote, err := dbfunc()
+ if err != nil {
+ if err == sentinelErr {
+ continue
+ }
+
+ t.Errorf("error encountered for database lookup %q: %v", lookup, err)
+ continue
+ }
+
+ // Check received account data.
+ if !isEqual(*checkVote, *vote) {
+ t.Errorf("poll vote does not contain expected data: %+v", checkVote)
+ continue
+ }
+
+ // Check that vote source poll populated.
+ if checkVote.PollID != (*checkVote).Poll.ID {
+ t.Errorf("vote source poll not correctly populated for: %+v", vote)
+ continue
+ }
+
+ // Check that vote author account populated.
+ if checkVote.AccountID != (*checkVote).Account.ID {
+ t.Errorf("vote author account not correctly populated for: %+v", vote)
+ continue
+ }
+ }
+ }
+}
+
+func (suite *PollTestSuite) TestUpdatePoll() {
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ for _, poll := range suite.testPolls {
+ // Take copy of poll.
+ poll := util.Ptr(*poll)
+
+ // Update the poll closed field.
+ poll.ClosedAt = time.Now()
+
+ // Update poll model in the database.
+ err := suite.db.UpdatePoll(ctx, poll)
+ suite.NoError(err)
+
+ // Refetch poll from database to get latest.
+ latest, err := suite.db.GetPollByID(ctx, poll.ID)
+ suite.NoError(err)
+
+ // The latest poll should have updated closedAt.
+ suite.Equal(poll.ClosedAt, latest.ClosedAt)
+ }
+}
+
+func (suite *PollTestSuite) TestPutPoll() {
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ for _, poll := range suite.testPolls {
+ // Delete this poll from the database.
+ err := suite.db.DeletePollByID(ctx, poll.ID)
+ suite.NoError(err)
+
+ // Ensure that afterwards we can
+ // enter it again into database.
+ err = suite.db.PutPoll(ctx, poll)
+
+ // Ensure that afterwards we can fetch poll.
+ _, err = suite.db.GetPollByID(ctx, poll.ID)
+ suite.NoError(err)
+ }
+}
+
+func (suite *PollTestSuite) TestPutPollVote() {
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ // randomChoices generates random vote choices in poll.
+ randomChoices := func(poll *gtsmodel.Poll) []int {
+ var max int
+ if *poll.Multiple {
+ max = len(poll.Options)
+ } else {
+ max = 1
+ }
+ count := 1 + rand.Intn(max)
+ choices := make([]int, count)
+ for i := range choices {
+ choices[i] = rand.Intn(len(poll.Options))
+ }
+ return choices
+ }
+
+ for _, poll := range suite.testPolls {
+ // Create a new vote to insert for poll.
+ vote := &gtsmodel.PollVote{
+ ID: id.NewULID(),
+ Choices: randomChoices(poll),
+ PollID: poll.ID,
+ AccountID: id.NewULID(), // random account, doesn't matter
+ }
+
+ // Insert this new vote into database.
+ err := suite.db.PutPollVote(ctx, vote)
+ suite.NoError(err)
+
+ // Fetch latest version of poll from database.
+ latest, err := suite.db.GetPollByID(ctx, poll.ID)
+ suite.NoError(err)
+
+ // Decr latest version choices by new vote's.
+ for _, choice := range vote.Choices {
+ latest.Votes[choice]--
+ }
+ (*latest.Voters)--
+
+ // Old poll and latest model after decr
+ // should have equal vote + voter counts.
+ suite.Equal(poll.Voters, latest.Voters)
+ suite.Equal(poll.Votes, latest.Votes)
+ }
+}
+
+func (suite *PollTestSuite) TestDeletePoll() {
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ for _, poll := range suite.testPolls {
+ // Delete this poll from the database.
+ err := suite.db.DeletePollByID(ctx, poll.ID)
+ suite.NoError(err)
+
+ // Ensure that afterwards we cannot fetch poll.
+ _, err = suite.db.GetPollByID(ctx, poll.ID)
+ suite.ErrorIs(err, db.ErrNoEntries)
+
+ // Or again by the status it's attached to.
+ _, err = suite.db.GetPollByStatusID(ctx, poll.StatusID)
+ suite.ErrorIs(err, db.ErrNoEntries)
+ }
+}
+
+func (suite *PollTestSuite) TestDeletePollVotes() {
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ for _, poll := range suite.testPolls {
+ // Delete votes associated with poll from database.
+ err := suite.db.DeletePollVotes(ctx, poll.ID)
+ suite.NoError(err)
+
+ // Fetch latest version of poll from database.
+ poll, err = suite.db.GetPollByID(ctx, poll.ID)
+ suite.NoError(err)
+
+ // Check that poll counts are all zero.
+ suite.Equal(*poll.Voters, 0)
+ suite.Equal(poll.Votes, make([]int, len(poll.Options)))
+ }
+}
+
+func TestPollTestSuite(t *testing.T) {
+ suite.Run(t, new(PollTestSuite))
+}
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go
index 822e697c1..138a5aa17 100644
--- a/internal/db/bundb/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -199,7 +199,8 @@ func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID stri
// Follow IDs not in cache, perform DB query!
q := newSelectFollows(r.db, accountID)
- if _, err := q.Exec(ctx, &followIDs); err != nil {
+ if _, err := q.Exec(ctx, &followIDs); // nocollapse
+ err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
}
@@ -213,7 +214,8 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID
// Follow IDs not in cache, perform DB query!
q := newSelectLocalFollows(r.db, accountID)
- if _, err := q.Exec(ctx, &followIDs); err != nil {
+ if _, err := q.Exec(ctx, &followIDs); // nocollapse
+ err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
}
@@ -227,7 +229,8 @@ func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID st
// Follow IDs not in cache, perform DB query!
q := newSelectFollowers(r.db, accountID)
- if _, err := q.Exec(ctx, &followIDs); err != nil {
+ if _, err := q.Exec(ctx, &followIDs); // nocollapse
+ err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
}
@@ -241,7 +244,8 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account
// Follow IDs not in cache, perform DB query!
q := newSelectLocalFollowers(r.db, accountID)
- if _, err := q.Exec(ctx, &followIDs); err != nil {
+ if _, err := q.Exec(ctx, &followIDs); // nocollapse
+ err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
}
@@ -255,7 +259,8 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account
// Follow request IDs not in cache, perform DB query!
q := newSelectFollowRequests(r.db, accountID)
- if _, err := q.Exec(ctx, &followReqIDs); err != nil {
+ if _, err := q.Exec(ctx, &followReqIDs); // nocollapse
+ err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
}
@@ -269,7 +274,8 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco
// Follow request IDs not in cache, perform DB query!
q := newSelectFollowRequesting(r.db, accountID)
- if _, err := q.Exec(ctx, &followReqIDs); err != nil {
+ if _, err := q.Exec(ctx, &followReqIDs); // nocollapse
+ err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
}
@@ -283,7 +289,8 @@ func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID strin
// Block IDs not in cache, perform DB query!
q := newSelectBlocks(r.db, accountID)
- if _, err := q.Exec(ctx, &blockIDs); err != nil {
+ if _, err := q.Exec(ctx, &blockIDs); // nocollapse
+ err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
}
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
index 0bd4ba1a9..7f274d693 100644
--- a/internal/db/bundb/status.go
+++ b/internal/db/bundb/status.go
@@ -154,17 +154,6 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
}
}
- if status.InReplyToID != "" && status.InReplyTo == nil {
- // Status parent is not set, fetch from database.
- status.InReplyTo, err = s.GetStatusByID(
- gtscontext.SetBarebones(ctx),
- status.InReplyToID,
- )
- if err != nil {
- errs.Appendf("error populating status parent: %w", err)
- }
- }
-
if status.InReplyToID != "" {
if status.InReplyTo == nil {
// Status parent is not set, fetch from database.
@@ -213,6 +202,17 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
}
}
+ if status.PollID != "" && status.Poll == nil {
+ // Status poll is not set, fetch from database.
+ status.Poll, err = s.state.DB.GetPollByID(
+ gtscontext.SetBarebones(ctx),
+ status.PollID,
+ )
+ if err != nil {
+ errs.Appendf("error populating status poll: %w", err)
+ }
+ }
+
if !status.AttachmentsPopulated() {
// Status attachments are out-of-date with IDs, repopulate.
status.Attachments, err = s.state.DB.GetAttachmentsByIDs(
diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go
index ac169ec4a..b3ce91755 100644
--- a/internal/db/bundb/timeline_test.go
+++ b/internal/db/bundb/timeline_test.go
@@ -22,6 +22,7 @@ import (
"testing"
"time"
+ "codeberg.org/gruf/go-kv"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
@@ -73,20 +74,18 @@ func getFutureStatus() *gtsmodel.Status {
func (suite *TimelineTestSuite) publicCount() int {
var publicCount int
-
for _, status := range suite.testStatuses {
if status.Visibility == gtsmodel.VisibilityPublic &&
status.BoostOfID == "" {
publicCount++
}
}
-
return publicCount
}
func (suite *TimelineTestSuite) checkStatuses(statuses []*gtsmodel.Status, maxID string, minID string, expectedLength int) {
if l := len(statuses); l != expectedLength {
- suite.FailNow("", "expected %d statuses in slice, got %d", expectedLength, l)
+ suite.FailNowf("", "expected %d statuses in slice, got %d", expectedLength, l)
} else if l == 0 {
// Can't test empty slice.
return
@@ -98,15 +97,15 @@ func (suite *TimelineTestSuite) checkStatuses(statuses []*gtsmodel.Status, maxID
id := status.ID
if id >= maxID {
- suite.FailNow("", "%s greater than maxID %s", id, maxID)
+ suite.FailNowf("", "%s greater than maxID %s", id, maxID)
}
if id <= minID {
- suite.FailNow("", "%s smaller than minID %s", id, minID)
+ suite.FailNowf("", "%s smaller than minID %s", id, minID)
}
if id > highest {
- suite.FailNow("", "statuses in slice were not ordered highest -> lowest ID")
+ suite.FailNowf("", "statuses in slice were not ordered highest -> lowest ID")
}
highest = id
@@ -121,6 +120,10 @@ func (suite *TimelineTestSuite) TestGetPublicTimeline() {
suite.FailNow(err.Error())
}
+ suite.T().Log(kv.Field{
+ K: "statuses", V: s,
+ })
+
suite.checkStatuses(s, id.Highest, id.Lowest, suite.publicCount())
}
@@ -154,7 +157,7 @@ func (suite *TimelineTestSuite) TestGetHomeTimeline() {
suite.FailNow(err.Error())
}
- suite.checkStatuses(s, id.Highest, id.Lowest, 16)
+ suite.checkStatuses(s, id.Highest, id.Lowest, 18)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineNoFollowing() {
@@ -186,7 +189,7 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineNoFollowing() {
suite.FailNow(err.Error())
}
- suite.checkStatuses(s, id.Highest, id.Lowest, 5)
+ suite.checkStatuses(s, id.Highest, id.Lowest, 6)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() {
@@ -208,7 +211,7 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() {
}
suite.NotContains(s, futureStatus)
- suite.checkStatuses(s, id.Highest, id.Lowest, 16)
+ suite.checkStatuses(s, id.Highest, id.Lowest, 18)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineBackToFront() {
@@ -239,8 +242,8 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineFromHighest() {
}
suite.checkStatuses(s, id.Highest, id.Lowest, 5)
- suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID)
- suite.Equal("01FCTA44PW9H1TB328S9AQXKDS", s[len(s)-1].ID)
+ suite.Equal("01HEN2RZ8BG29Y5Z9VJC73HZW7", s[0].ID)
+ suite.Equal("01FN3VJGFH10KR7S2PB0GFJZYG", s[len(s)-1].ID)
}
func (suite *TimelineTestSuite) TestGetListTimelineNoParams() {
@@ -254,7 +257,7 @@ func (suite *TimelineTestSuite) TestGetListTimelineNoParams() {
suite.FailNow(err.Error())
}
- suite.checkStatuses(s, id.Highest, id.Lowest, 11)
+ suite.checkStatuses(s, id.Highest, id.Lowest, 12)
}
func (suite *TimelineTestSuite) TestGetListTimelineMaxID() {
@@ -269,8 +272,8 @@ func (suite *TimelineTestSuite) TestGetListTimelineMaxID() {
}
suite.checkStatuses(s, id.Highest, id.Lowest, 5)
- suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID)
- suite.Equal("01FCQSQ667XHJ9AV9T27SJJSX5", s[len(s)-1].ID)
+ suite.Equal("01HEN2PRXT0TF4YDRA64FZZRN7", s[0].ID)
+ suite.Equal("01FF25D5Q0DH7CHD57CTRS6WK0", s[len(s)-1].ID)
}
func (suite *TimelineTestSuite) TestGetListTimelineMinID() {
diff --git a/internal/db/db.go b/internal/db/db.go
index 41b253834..2914d9b59 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -36,6 +36,7 @@ type DB interface {
Media
Mention
Notification
+ Poll
Relationship
Report
Rule
diff --git a/internal/db/mention.go b/internal/db/mention.go
index d4125031e..994ec04b5 100644
--- a/internal/db/mention.go
+++ b/internal/db/mention.go
@@ -31,6 +31,9 @@ type Mention interface {
// GetMentions gets multiple mentions.
GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, error)
+ // PopulateMention ensures that all sub-models of a mention are populated (e.g. accounts).
+ PopulateMention(ctx context.Context, mention *gtsmodel.Mention) error
+
// PutMention will insert the given mention into the database.
PutMention(ctx context.Context, mention *gtsmodel.Mention) error
diff --git a/internal/db/poll.go b/internal/db/poll.go
new file mode 100644
index 000000000..b59d27c73
--- /dev/null
+++ b/internal/db/poll.go
@@ -0,0 +1,71 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package db
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type Poll interface {
+ // GetPollByID fetches the Poll with given ID from the database.
+ GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, error)
+
+ // GetPollByStatusID fetches the Poll with given status ID column value from the database.
+ GetPollByStatusID(ctx context.Context, statusID string) (*gtsmodel.Poll, error)
+
+ // GetOpenPolls fetches all local Polls in the database with an unset `closed_at` column.
+ GetOpenPolls(ctx context.Context) ([]*gtsmodel.Poll, error)
+
+ // PopulatePoll ensures the given Poll is fully populated with all other related database models.
+ PopulatePoll(ctx context.Context, poll *gtsmodel.Poll) error
+
+ // PutPoll puts the given Poll in the database.
+ PutPoll(ctx context.Context, poll *gtsmodel.Poll) error
+
+ // UpdatePoll updates the Poll in the database, only on selected columns if provided (else, all).
+ UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...string) error
+
+ // DeletePollByID deletes the Poll with given ID from the database.
+ DeletePollByID(ctx context.Context, id string) error
+
+ // GetPollVoteByID gets the PollVote with given ID from the database.
+ GetPollVoteByID(ctx context.Context, id string) (*gtsmodel.PollVote, error)
+
+ // GetPollVotesBy fetches the PollVote in Poll with ID, by account ID, from the database.
+ GetPollVoteBy(ctx context.Context, pollID string, accountID string) (*gtsmodel.PollVote, error)
+
+ // GetPollVotes fetches all PollVotes in Poll with ID, from the database.
+ GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.PollVote, error)
+
+ // PopulatePollVote ensures the given PollVote is fully populated with all other related database models.
+ PopulatePollVote(ctx context.Context, votes *gtsmodel.PollVote) error
+
+ // PutPollVote puts the given PollVote in the database.
+ PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error
+
+ // DeletePollVotes deletes all PollVotes in Poll with given ID from the database.
+ DeletePollVotes(ctx context.Context, pollID string) error
+
+ // DeletePollVoteBy deletes the PollVote in Poll with ID, by account ID, from the database.
+ DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error
+
+ // DeletePollVotesByAccountID deletes all PollVotes in all Polls, by account ID, from the database.
+ DeletePollVotesByAccountID(ctx context.Context, accountID string) error
+}