summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/bundb/bundb.go5
-rw-r--r--internal/db/bundb/bundb_test.go6
-rw-r--r--internal/db/bundb/list.go467
-rw-r--r--internal/db/bundb/list_test.go315
-rw-r--r--internal/db/bundb/migrations/20230515173919_lists.go92
-rw-r--r--internal/db/bundb/relationship_follow.go101
-rw-r--r--internal/db/bundb/relationship_test.go15
-rw-r--r--internal/db/bundb/timeline.go129
-rw-r--r--internal/db/bundb/timeline_test.go253
-rw-r--r--internal/db/db.go1
-rw-r--r--internal/db/list.go67
-rw-r--r--internal/db/relationship.go3
-rw-r--r--internal/db/timeline.go4
13 files changed, 1347 insertions, 111 deletions
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index f095d1728..f0329e898 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -65,6 +65,7 @@ type DBService struct {
db.Domain
db.Emoji
db.Instance
+ db.List
db.Media
db.Mention
db.Notification
@@ -179,6 +180,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
Instance: &instanceDB{
conn: conn,
},
+ List: &listDB{
+ conn: conn,
+ state: state,
+ },
Media: &mediaDB{
conn: conn,
state: state,
diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go
index 2566be2ba..84e11447a 100644
--- a/internal/db/bundb/bundb_test.go
+++ b/internal/db/bundb/bundb_test.go
@@ -22,6 +22,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -46,6 +47,8 @@ type BunDBStandardTestSuite struct {
testReports map[string]*gtsmodel.Report
testBookmarks map[string]*gtsmodel.StatusBookmark
testFaves map[string]*gtsmodel.StatusFave
+ testLists map[string]*gtsmodel.List
+ testListEntries map[string]*gtsmodel.ListEntry
}
func (suite *BunDBStandardTestSuite) SetupSuite() {
@@ -63,6 +66,8 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {
suite.testReports = testrig.NewTestReports()
suite.testBookmarks = testrig.NewTestBookmarks()
suite.testFaves = testrig.NewTestFaves()
+ suite.testLists = testrig.NewTestLists()
+ suite.testListEntries = testrig.NewTestListEntries()
}
func (suite *BunDBStandardTestSuite) SetupTest() {
@@ -70,6 +75,7 @@ func (suite *BunDBStandardTestSuite) SetupTest() {
testrig.InitTestLog()
suite.state.Caches.Init()
suite.db = testrig.NewTestDB(&suite.state)
+ testrig.StartTimelines(&suite.state, visibility.NewFilter(&suite.state), testrig.NewTestTypeConverter(suite.db))
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go
new file mode 100644
index 000000000..38701cc07
--- /dev/null
+++ b/internal/db/bundb/list.go
@@ -0,0 +1,467 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/uptrace/bun"
+)
+
+type listDB struct {
+ conn *DBConn
+ state *state.State
+}
+
+/*
+ LIST FUNCTIONS
+*/
+
+func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmodel.List) error, keyParts ...any) (*gtsmodel.List, error) {
+ list, err := l.state.Caches.GTS.List().Load(lookup, func() (*gtsmodel.List, error) {
+ var list gtsmodel.List
+
+ // Not cached! Perform database query.
+ if err := dbQuery(&list); err != nil {
+ return nil, l.conn.ProcessError(err)
+ }
+
+ return &list, nil
+ }, keyParts...)
+ if err != nil {
+ return nil, err // already processed
+ }
+
+ if gtscontext.Barebones(ctx) {
+ // Only a barebones model was requested.
+ return list, nil
+ }
+
+ if err := l.state.DB.PopulateList(ctx, list); err != nil {
+ return nil, err
+ }
+
+ return list, nil
+}
+
+func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, error) {
+ return l.getList(
+ ctx,
+ "ID",
+ func(list *gtsmodel.List) error {
+ return l.conn.NewSelect().
+ Model(list).
+ Where("? = ?", bun.Ident("list.id"), id).
+ Scan(ctx)
+ },
+ id,
+ )
+}
+
+func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) {
+ // Fetch IDs of all lists owned by this account.
+ var listIDs []string
+ if err := l.conn.
+ NewSelect().
+ TableExpr("? AS ?", bun.Ident("lists"), bun.Ident("list")).
+ Column("list.id").
+ Where("? = ?", bun.Ident("list.account_id"), accountID).
+ Order("list.id DESC").
+ Scan(ctx, &listIDs); err != nil {
+ return nil, l.conn.ProcessError(err)
+ }
+
+ if len(listIDs) == 0 {
+ return nil, nil
+ }
+
+ // Select each list using its ID to ensure cache used.
+ lists := make([]*gtsmodel.List, 0, len(listIDs))
+ for _, id := range listIDs {
+ list, err := l.state.DB.GetListByID(ctx, id)
+ if err != nil {
+ log.Errorf(ctx, "error fetching list %q: %v", id, err)
+ continue
+ }
+
+ // Append list.
+ lists = append(lists, list)
+ }
+
+ return lists, nil
+}
+
+func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {
+ var (
+ err error
+ errs = make(gtserror.MultiError, 0, 2)
+ )
+
+ if list.Account == nil {
+ // List account is not set, fetch from the database.
+ list.Account, err = l.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ list.AccountID,
+ )
+ if err != nil {
+ errs.Append(fmt.Errorf("error populating list account: %w", err))
+ }
+ }
+
+ if list.ListEntries == nil {
+ // List entries are not set, fetch from the database.
+ list.ListEntries, err = l.state.DB.GetListEntries(
+ gtscontext.SetBarebones(ctx),
+ list.ID,
+ "", "", "", 0,
+ )
+ if err != nil {
+ errs.Append(fmt.Errorf("error populating list entries: %w", err))
+ }
+ }
+
+ return errs.Combine()
+}
+
+func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error {
+ return l.state.Caches.GTS.List().Store(list, func() error {
+ _, err := l.conn.NewInsert().Model(list).Exec(ctx)
+ return l.conn.ProcessError(err)
+ })
+}
+
+func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ...string) error {
+ list.UpdatedAt = time.Now()
+ if len(columns) > 0 {
+ // If we're updating by column, ensure "updated_at" is included.
+ columns = append(columns, "updated_at")
+ }
+
+ return l.state.Caches.GTS.List().Store(list, func() error {
+ if _, err := l.conn.NewUpdate().
+ Model(list).
+ Where("? = ?", bun.Ident("list.id"), list.ID).
+ Column(columns...).
+ Exec(ctx); err != nil {
+ return l.conn.ProcessError(err)
+ }
+
+ return nil
+ })
+}
+
+func (l *listDB) DeleteListByID(ctx context.Context, id string) error {
+ defer l.state.Caches.GTS.List().Invalidate("ID", id)
+
+ // Select all entries that belong to this list.
+ listEntries, err := l.state.DB.GetListEntries(ctx, id, "", "", "", 0)
+ if err != nil {
+ return fmt.Errorf("error selecting entries from list %q: %w", id, err)
+ }
+
+ // Delete each list entry. This will
+ // invalidate the list timeline too.
+ for _, listEntry := range listEntries {
+ err := l.state.DB.DeleteListEntry(ctx, listEntry.ID)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return err
+ }
+ }
+
+ // Finally delete list itself from DB.
+ _, err = l.conn.NewDelete().
+ Table("lists").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx)
+ return l.conn.ProcessError(err)
+}
+
+/*
+ LIST ENTRY functions
+*/
+
+func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) {
+ listEntry, err := l.state.Caches.GTS.ListEntry().Load(lookup, func() (*gtsmodel.ListEntry, error) {
+ var listEntry gtsmodel.ListEntry
+
+ // Not cached! Perform database query.
+ if err := dbQuery(&listEntry); err != nil {
+ return nil, l.conn.ProcessError(err)
+ }
+
+ return &listEntry, nil
+ }, keyParts...)
+ if err != nil {
+ return nil, err // already processed
+ }
+
+ if gtscontext.Barebones(ctx) {
+ // Only a barebones model was requested.
+ return listEntry, nil
+ }
+
+ // Further populate the list entry fields where applicable.
+ if err := l.state.DB.PopulateListEntry(ctx, listEntry); err != nil {
+ return nil, err
+ }
+
+ return listEntry, nil
+}
+
+func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error) {
+ return l.getListEntry(
+ ctx,
+ "ID",
+ func(listEntry *gtsmodel.ListEntry) error {
+ return l.conn.NewSelect().
+ Model(listEntry).
+ Where("? = ?", bun.Ident("list_entry.id"), id).
+ Scan(ctx)
+ },
+ id,
+ )
+}
+
+func (l *listDB) GetListEntries(ctx context.Context,
+ listID string,
+ maxID string,
+ sinceID string,
+ minID string,
+ limit int,
+) ([]*gtsmodel.ListEntry, error) {
+ // Ensure reasonable
+ if limit < 0 {
+ limit = 0
+ }
+
+ // Make educated guess for slice size
+ var (
+ entryIDs = make([]string, 0, limit)
+ frontToBack = true
+ )
+
+ q := l.conn.
+ NewSelect().
+ TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")).
+ // Select only IDs from table
+ Column("entry.id").
+ // Select only entries belonging to listID.
+ Where("? = ?", bun.Ident("entry.list_id"), listID)
+
+ if maxID != "" {
+ // return only entries LOWER (ie., older) than maxID
+ q = q.Where("? < ?", bun.Ident("entry.id"), maxID)
+ }
+
+ if sinceID != "" {
+ // return only entries HIGHER (ie., newer) than sinceID
+ q = q.Where("? > ?", bun.Ident("entry.id"), sinceID)
+ }
+
+ if minID != "" {
+ // return only entries HIGHER (ie., newer) than minID
+ q = q.Where("? > ?", bun.Ident("entry.id"), minID)
+
+ // page up
+ frontToBack = false
+ }
+
+ if limit > 0 {
+ // limit amount of entries returned
+ q = q.Limit(limit)
+ }
+
+ if frontToBack {
+ // Page down.
+ q = q.Order("entry.id DESC")
+ } else {
+ // Page up.
+ q = q.Order("entry.id ASC")
+ }
+
+ if err := q.Scan(ctx, &entryIDs); err != nil {
+ return nil, l.conn.ProcessError(err)
+ }
+
+ if len(entryIDs) == 0 {
+ return nil, nil
+ }
+
+ // If we're paging up, we still want entries
+ // to be sorted by ID desc, so reverse ids slice.
+ // https://zchee.github.io/golang-wiki/SliceTricks/#reversing
+ if !frontToBack {
+ for l, r := 0, len(entryIDs)-1; l < r; l, r = l+1, r-1 {
+ entryIDs[l], entryIDs[r] = entryIDs[r], entryIDs[l]
+ }
+ }
+
+ // Select each list entry using its ID to ensure cache used.
+ listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs))
+ for _, id := range entryIDs {
+ listEntry, err := l.state.DB.GetListEntryByID(ctx, id)
+ if err != nil {
+ log.Errorf(ctx, "error fetching list entry %q: %v", id, err)
+ continue
+ }
+
+ // Append list entries.
+ listEntries = append(listEntries, listEntry)
+ }
+
+ return listEntries, nil
+}
+
+func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) {
+ entryIDs := []string{}
+
+ if err := l.conn.
+ NewSelect().
+ TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")).
+ // Select only IDs from table
+ Column("entry.id").
+ // Select only entries belonging with given followID.
+ Where("? = ?", bun.Ident("entry.follow_id"), followID).
+ Scan(ctx, &entryIDs); err != nil {
+ return nil, l.conn.ProcessError(err)
+ }
+
+ if len(entryIDs) == 0 {
+ return nil, nil
+ }
+
+ // Select each list entry using its ID to ensure cache used.
+ listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs))
+ for _, id := range entryIDs {
+ listEntry, err := l.state.DB.GetListEntryByID(ctx, id)
+ if err != nil {
+ log.Errorf(ctx, "error fetching list entry %q: %v", id, err)
+ continue
+ }
+
+ // Append list entries.
+ listEntries = append(listEntries, listEntry)
+ }
+
+ return listEntries, nil
+}
+
+func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error {
+ var err error
+
+ if listEntry.Follow == nil {
+ // ListEntry follow is not set, fetch from the database.
+ listEntry.Follow, err = l.state.DB.GetFollowByID(
+ gtscontext.SetBarebones(ctx),
+ listEntry.FollowID,
+ )
+ if err != nil {
+ return fmt.Errorf("error populating listEntry follow: %w", err)
+ }
+ }
+
+ return nil
+}
+
+func (l *listDB) PutListEntries(ctx context.Context, listEntries []*gtsmodel.ListEntry) error {
+ return l.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ for _, listEntry := range listEntries {
+ if _, err := tx.
+ NewInsert().
+ Model(listEntry).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Invalidate the timeline for the list this entry belongs to.
+ if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil {
+ log.Errorf(ctx, "PutListEntries: error invalidating list timeline: %q", err)
+ }
+ }
+
+ return nil
+ })
+}
+
+func (l *listDB) DeleteListEntry(ctx context.Context, id string) error {
+ defer l.state.Caches.GTS.ListEntry().Invalidate("ID", id)
+
+ // Load list entry into cache before attempting a delete,
+ // as we need the followID from it in order to trigger
+ // timeline invalidation.
+ listEntry, err := l.GetListEntryByID(
+ // Don't populate the entry;
+ // we only want the list ID.
+ gtscontext.SetBarebones(ctx),
+ id,
+ )
+ if err != nil {
+ if errors.Is(err, db.ErrNoEntries) {
+ // Already gone.
+ return nil
+ }
+ return err
+ }
+
+ defer func() {
+ // Invalidate the timeline for the list this entry belongs to.
+ if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil {
+ log.Errorf(ctx, "DeleteListEntry: error invalidating list timeline: %q", err)
+ }
+ }()
+
+ if _, err := l.conn.NewDelete().
+ Table("list_entries").
+ Where("? = ?", bun.Ident("id"), listEntry.ID).
+ Exec(ctx); err != nil {
+ return l.conn.ProcessError(err)
+ }
+
+ return nil
+}
+
+func (l *listDB) DeleteListEntriesForFollowID(ctx context.Context, followID string) error {
+ // Fetch IDs of all entries that pertain to this follow.
+ var listEntryIDs []string
+ if err := l.conn.
+ NewSelect().
+ TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("list_entry")).
+ Column("list_entry.id").
+ Where("? = ?", bun.Ident("list_entry.follow_id"), followID).
+ Order("list_entry.id DESC").
+ Scan(ctx, &listEntryIDs); err != nil {
+ return l.conn.ProcessError(err)
+ }
+
+ for _, id := range listEntryIDs {
+ if err := l.DeleteListEntry(ctx, id); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
diff --git a/internal/db/bundb/list_test.go b/internal/db/bundb/list_test.go
new file mode 100644
index 000000000..296ab7c1a
--- /dev/null
+++ b/internal/db/bundb/list_test.go
@@ -0,0 +1,315 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "golang.org/x/exp/slices"
+)
+
+type ListTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *ListTestSuite) testStructs() (*gtsmodel.List, *gtsmodel.Account) {
+ testList := &gtsmodel.List{}
+ *testList = *suite.testLists["local_account_1_list_1"]
+
+ // Populate entries on this list as we'd expect them back from the db.
+ entries := make([]*gtsmodel.ListEntry, 0, len(suite.testListEntries))
+ for _, entry := range suite.testListEntries {
+ entries = append(entries, entry)
+ }
+
+ // Sort by ID descending (again, as we'd expect from the db).
+ slices.SortFunc(entries, func(a, b *gtsmodel.ListEntry) bool {
+ return b.ID < a.ID
+ })
+
+ testList.ListEntries = entries
+
+ testAccount := &gtsmodel.Account{}
+ *testAccount = *suite.testAccounts["local_account_1"]
+
+ return testList, testAccount
+}
+
+func (suite *ListTestSuite) checkList(expected *gtsmodel.List, actual *gtsmodel.List) {
+ suite.Equal(expected.ID, actual.ID)
+ suite.Equal(expected.Title, actual.Title)
+ suite.Equal(expected.AccountID, actual.AccountID)
+ suite.Equal(expected.RepliesPolicy, actual.RepliesPolicy)
+ suite.NotNil(actual.Account)
+}
+
+func (suite *ListTestSuite) checkListEntry(expected *gtsmodel.ListEntry, actual *gtsmodel.ListEntry) {
+ suite.Equal(expected.ID, actual.ID)
+ suite.Equal(expected.ListID, actual.ListID)
+ suite.Equal(expected.FollowID, actual.FollowID)
+}
+
+func (suite *ListTestSuite) checkListEntries(expected []*gtsmodel.ListEntry, actual []*gtsmodel.ListEntry) {
+ var (
+ lExpected = len(expected)
+ lActual = len(actual)
+ )
+
+ if lExpected != lActual {
+ suite.FailNow("", "expected %d list entries, got %d", lExpected, lActual)
+ }
+
+ var topID string
+ for i, expectedEntry := range expected {
+ actualEntry := actual[i]
+
+ // Ensure ID descending.
+ if topID == "" {
+ topID = actualEntry.ID
+ } else {
+ suite.Less(actualEntry.ID, topID)
+ }
+
+ suite.checkListEntry(expectedEntry, actualEntry)
+ }
+}
+
+func (suite *ListTestSuite) TestGetListByID() {
+ testList, _ := suite.testStructs()
+
+ dbList, err := suite.db.GetListByID(context.Background(), testList.ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.checkList(testList, dbList)
+ suite.checkListEntries(testList.ListEntries, dbList.ListEntries)
+}
+
+func (suite *ListTestSuite) TestGetListsForAccountID() {
+ testList, testAccount := suite.testStructs()
+
+ dbLists, err := suite.db.GetListsForAccountID(context.Background(), testAccount.ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ if l := len(dbLists); l != 1 {
+ suite.FailNow("", "expected %d lists, got %d", 1, l)
+ }
+
+ suite.checkList(testList, dbLists[0])
+}
+
+func (suite *ListTestSuite) TestGetListEntries() {
+ testList, _ := suite.testStructs()
+
+ dbListEntries, err := suite.db.GetListEntries(context.Background(), testList.ID, "", "", "", 0)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.checkListEntries(testList.ListEntries, dbListEntries)
+}
+
+func (suite *ListTestSuite) TestPutList() {
+ ctx := context.Background()
+ _, testAccount := suite.testStructs()
+
+ testList := &gtsmodel.List{
+ ID: "01H0J2PMYM54618VCV8Y8QYAT4",
+ Title: "Test List!",
+ AccountID: testAccount.ID,
+ }
+
+ if err := suite.db.PutList(ctx, testList); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ dbList, err := suite.db.GetListByID(ctx, testList.ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // Bodge testlist as though default had been set.
+ testList.RepliesPolicy = gtsmodel.RepliesPolicyFollowed
+ suite.checkList(testList, dbList)
+}
+
+func (suite *ListTestSuite) TestUpdateList() {
+ ctx := context.Background()
+ testList, _ := suite.testStructs()
+
+ // Get List in the cache first.
+ dbList, err := suite.db.GetListByID(ctx, testList.ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // Now do the update.
+ testList.Title = "New Title!"
+ if err := suite.db.UpdateList(ctx, testList, "title"); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // Cache should be invalidated
+ // + we should have updated list.
+ dbList, err = suite.db.GetListByID(ctx, testList.ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.checkList(testList, dbList)
+}
+
+func (suite *ListTestSuite) TestDeleteList() {
+ ctx := context.Background()
+ testList, _ := suite.testStructs()
+
+ // Get List in the cache first.
+ if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // Now do the delete.
+ if err := suite.db.DeleteListByID(ctx, testList.ID); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // Cache should be invalidated
+ // + we should have no list.
+ _, err := suite.db.GetListByID(ctx, testList.ID)
+ suite.ErrorIs(err, db.ErrNoEntries)
+
+ // All entries belonging to this
+ // list should now be deleted.
+ listEntries, err := suite.db.GetListEntries(ctx, testList.ID, "", "", "", 0)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.Empty(listEntries)
+}
+
+func (suite *ListTestSuite) TestPutListEntries() {
+ ctx := context.Background()
+ testList, _ := suite.testStructs()
+
+ listEntries := []*gtsmodel.ListEntry{
+ {
+ ID: "01H0MKMQY69HWDSDR2SWGA17R4",
+ ListID: testList.ID,
+ FollowID: "01H0MKNFRFZS8R9WV6DBX31Y03", // random id, doesn't exist
+ },
+ {
+ ID: "01H0MKPGQF0E7QAVW5BKTHZ630",
+ ListID: testList.ID,
+ FollowID: "01H0MKP6RR8VEHN3GVWFBP2H30", // random id, doesn't exist
+ },
+ {
+ ID: "01H0MKPPP2DT68FRBMR1FJM32T",
+ ListID: testList.ID,
+ FollowID: "01H0MKQ0KA29C6NFJ27GTZD16J", // random id, doesn't exist
+ },
+ }
+
+ if err := suite.db.PutListEntries(ctx, listEntries); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // Add these entries to the test list, sort it again
+ // to reflect what we'd expect to get from the db.
+ testList.ListEntries = append(testList.ListEntries, listEntries...)
+ slices.SortFunc(testList.ListEntries, func(a, b *gtsmodel.ListEntry) bool {
+ return b.ID < a.ID
+ })
+
+ // Now get all list entries from the db.
+ // Use barebones for this because the ones
+ // we just added will fail if we try to get
+ // the nonexistent follows.
+ dbListEntries, err := suite.db.GetListEntries(
+ gtscontext.SetBarebones(ctx),
+ testList.ID,
+ "", "", "", 0)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.checkListEntries(testList.ListEntries, dbListEntries)
+}
+
+func (suite *ListTestSuite) TestDeleteListEntry() {
+ ctx := context.Background()
+ testList, _ := suite.testStructs()
+
+ // Get List in the cache first.
+ if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // Delete the first entry.
+ if err := suite.db.DeleteListEntry(ctx, testList.ListEntries[0].ID); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // Get list from the db again.
+ dbList, err := suite.db.GetListByID(ctx, testList.ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // Bodge the testlist as though
+ // we'd removed the first entry.
+ testList.ListEntries = testList.ListEntries[1:]
+ suite.checkList(testList, dbList)
+}
+
+func (suite *ListTestSuite) TestDeleteListEntriesForFollowID() {
+ ctx := context.Background()
+ testList, _ := suite.testStructs()
+
+ // Get List in the cache first.
+ if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // Delete the first entry.
+ if err := suite.db.DeleteListEntriesForFollowID(ctx, testList.ListEntries[0].FollowID); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // Get list from the db again.
+ dbList, err := suite.db.GetListByID(ctx, testList.ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // Bodge the testlist as though
+ // we'd removed the first entry.
+ testList.ListEntries = testList.ListEntries[1:]
+ suite.checkList(testList, dbList)
+}
+
+func TestListTestSuite(t *testing.T) {
+ suite.Run(t, new(ListTestSuite))
+}
diff --git a/internal/db/bundb/migrations/20230515173919_lists.go b/internal/db/bundb/migrations/20230515173919_lists.go
new file mode 100644
index 000000000..e0ea5c7b6
--- /dev/null
+++ b/internal/db/bundb/migrations/20230515173919_lists.go
@@ -0,0 +1,92 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package migrations
+
+import (
+ "context"
+
+ gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+func init() {
+ up := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ // List table.
+ if _, err := tx.
+ NewCreateTable().
+ Model(&gtsmodel.List{}).
+ IfNotExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Add indexes to the List table.
+ for index, columns := range map[string][]string{
+ "lists_id_idx": {"id"},
+ "lists_account_id_idx": {"account_id"},
+ } {
+ if _, err := tx.
+ NewCreateIndex().
+ Table("lists").
+ Index(index).
+ Column(columns...).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ // List entry table.
+ if _, err := tx.
+ NewCreateTable().
+ Model(&gtsmodel.ListEntry{}).
+ IfNotExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Add indexes to the List entry table.
+ for index, columns := range map[string][]string{
+ "list_entries_id_idx": {"id"},
+ "list_entries_list_id_idx": {"list_id"},
+ "list_entries_follow_id_idx": {"follow_id"},
+ } {
+ if _, err := tx.
+ NewCreateIndex().
+ Table("list_entries").
+ Index(index).
+ Column(columns...).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+ }
+
+ down := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ return nil
+ })
+ }
+
+ if err := Migrations.Register(up, down); err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go
index fe1f26bf1..39b85075c 100644
--- a/internal/db/bundb/relationship_follow.go
+++ b/internal/db/bundb/relationship_follow.go
@@ -25,6 +25,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
@@ -149,25 +150,42 @@ func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery f
return follow, nil
}
- // Set the follow source account
- follow.Account, err = r.state.DB.GetAccountByID(
- gtscontext.SetBarebones(ctx),
- follow.AccountID,
- )
- if err != nil {
- return nil, fmt.Errorf("error getting follow source account: %w", err)
+ if err := r.state.DB.PopulateFollow(ctx, follow); err != nil {
+ return nil, err
}
- // Set the follow target account
- follow.TargetAccount, err = r.state.DB.GetAccountByID(
- gtscontext.SetBarebones(ctx),
- follow.TargetAccountID,
+ return follow, nil
+}
+
+func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Follow) error {
+ var (
+ err error
+ errs = make(gtserror.MultiError, 0, 2)
)
- if err != nil {
- return nil, fmt.Errorf("error getting follow target account: %w", err)
+
+ if follow.Account == nil {
+ // Follow account is not set, fetch from the database.
+ follow.Account, err = r.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ follow.AccountID,
+ )
+ if err != nil {
+ errs.Append(fmt.Errorf("error populating follow account: %w", err))
+ }
}
- return follow, nil
+ if follow.TargetAccount == nil {
+ // Follow target account is not set, fetch from the database.
+ follow.TargetAccount, err = r.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ follow.TargetAccountID,
+ )
+ if err != nil {
+ errs.Append(fmt.Errorf("error populating follow target account: %w", err))
+ }
+ }
+
+ return errs.Combine()
}
func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error {
@@ -197,27 +215,40 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll
})
}
+func (r *relationshipDB) deleteFollow(ctx context.Context, id string) error {
+ // Delete the follow itself using the given ID.
+ if _, err := r.conn.NewDelete().
+ Table("follows").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ // Delete every list entry that used this followID.
+ if err := r.state.DB.DeleteListEntriesForFollowID(ctx, id); err != nil {
+ return fmt.Errorf("deleteFollow: error deleting list entries: %w", err)
+ }
+
+ return nil
+}
+
func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error {
defer r.state.Caches.GTS.Follow().Invalidate("ID", id)
// Load follow into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
// callback. This in turn invalidates others.
- _, err := r.GetFollowByID(gtscontext.SetBarebones(ctx), id)
+ follow, err := r.GetFollowByID(gtscontext.SetBarebones(ctx), id)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
- // not an issue.
- err = nil
+ // Already gone.
+ return nil
}
return err
}
// Finally delete follow from DB.
- _, err = r.conn.NewDelete().
- Table("follows").
- Where("? = ?", bun.Ident("id"), id).
- Exec(ctx)
- return r.conn.ProcessError(err)
+ return r.deleteFollow(ctx, follow.ID)
}
func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) error {
@@ -226,21 +257,17 @@ func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) erro
// Load follow into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
// callback. This in turn invalidates others.
- _, err := r.GetFollowByURI(gtscontext.SetBarebones(ctx), uri)
+ follow, err := r.GetFollowByURI(gtscontext.SetBarebones(ctx), uri)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
- // not an issue.
- err = nil
+ // Already gone.
+ return nil
}
return err
}
// Finally delete follow from DB.
- _, err = r.conn.NewDelete().
- Table("follows").
- Where("? = ?", bun.Ident("uri"), uri).
- Exec(ctx)
- return r.conn.ProcessError(err)
+ return r.deleteFollow(ctx, follow.ID)
}
func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID string) error {
@@ -272,16 +299,16 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str
// but it is the only way we can ensure we invalidate all
// related caches correctly (e.g. visibility).
for _, id := range followIDs {
- _, err := r.GetFollowByID(ctx, id)
+ follow, err := r.GetFollowByID(ctx, id)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
+
+ // Delete each follow from DB.
+ if err := r.deleteFollow(ctx, follow.ID); err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return err
+ }
}
- // Finally delete all from DB.
- _, err := r.conn.NewDelete().
- Table("follows").
- Where("? IN (?)", bun.Ident("id"), bun.In(followIDs)).
- Exec(ctx)
- return r.conn.ProcessError(err)
+ return nil
}
diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go
index 0e38d19fe..63fdb9632 100644
--- a/internal/db/bundb/relationship_test.go
+++ b/internal/db/bundb/relationship_test.go
@@ -807,16 +807,27 @@ func (suite *RelationshipTestSuite) TestUnfollowExisting() {
follow, err := suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.NotNil(follow)
+ followID := follow.ID
- err = suite.db.DeleteFollowByID(context.Background(), follow.ID)
+ // We should have list entries for this follow.
+ listEntries, err := suite.db.GetListEntriesForFollowID(context.Background(), followID)
+ suite.NoError(err)
+ suite.NotEmpty(listEntries)
+
+ err = suite.db.DeleteFollowByID(context.Background(), followID)
suite.NoError(err)
follow, err = suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID)
suite.EqualError(err, db.ErrNoEntries.Error())
suite.Nil(follow)
+
+ // ListEntries pertaining to this follow should be deleted too.
+ listEntries, err = suite.db.GetListEntriesForFollowID(context.Background(), followID)
+ suite.NoError(err)
+ suite.Empty(listEntries)
}
-func (suite *RelationshipTestSuite) TestUnfollowNotExisting() {
+func (suite *RelationshipTestSuite) TestGetFollowNotExisting() {
originAccount := suite.testAccounts["local_account_1"]
targetAccountID := "01GTVD9N484CZ6AM90PGGNY7GQ"
diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go
index 87e7751d2..d33840a7b 100644
--- a/internal/db/bundb/timeline.go
+++ b/internal/db/bundb/timeline.go
@@ -19,9 +19,11 @@ package bundb
import (
"context"
+ "fmt"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
@@ -281,3 +283,130 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max
prevMinID := faves[0].ID
return statuses, nextMaxID, prevMinID, nil
}
+
+func (t *timelineDB) GetListTimeline(
+ ctx context.Context,
+ listID string,
+ maxID string,
+ sinceID string,
+ minID string,
+ limit int,
+) ([]*gtsmodel.Status, error) {
+ // Ensure reasonable
+ if limit < 0 {
+ limit = 0
+ }
+
+ // Make educated guess for slice size
+ var (
+ statusIDs = make([]string, 0, limit)
+ frontToBack = true
+ )
+
+ // Fetch all listEntries entries from the database.
+ listEntries, err := t.state.DB.GetListEntries(
+ // Don't need actual follows
+ // for this, just the IDs.
+ gtscontext.SetBarebones(ctx),
+ listID,
+ "", "", "", 0,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error getting entries for list %s: %w", listID, err)
+ }
+
+ // Extract just the IDs of each follow.
+ followIDs := make([]string, 0, len(listEntries))
+ for _, listEntry := range listEntries {
+ followIDs = append(followIDs, listEntry.FollowID)
+ }
+
+ // Select target account IDs from follows.
+ subQ := t.conn.
+ NewSelect().
+ TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
+ Column("follow.target_account_id").
+ Where("? IN (?)", bun.Ident("follow.id"), bun.In(followIDs))
+
+ // Select only status IDs created
+ // by one of the followed accounts.
+ q := t.conn.
+ NewSelect().
+ TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
+ // Select only IDs from table
+ Column("status.id").
+ Where("? IN (?)", bun.Ident("status.account_id"), subQ)
+
+ if maxID == "" || maxID >= id.Highest {
+ const future = 24 * time.Hour
+
+ var err error
+
+ // don't return statuses more than 24hr in the future
+ maxID, err = id.NewULIDFromTime(time.Now().Add(future))
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // return only statuses LOWER (ie., older) than maxID
+ q = q.Where("? < ?", bun.Ident("status.id"), maxID)
+
+ if sinceID != "" {
+ // return only statuses HIGHER (ie., newer) than sinceID
+ q = q.Where("? > ?", bun.Ident("status.id"), sinceID)
+ }
+
+ if minID != "" {
+ // return only statuses HIGHER (ie., newer) than minID
+ q = q.Where("? > ?", bun.Ident("status.id"), minID)
+
+ // page up
+ frontToBack = false
+ }
+
+ if limit > 0 {
+ // limit amount of statuses returned
+ q = q.Limit(limit)
+ }
+
+ if frontToBack {
+ // Page down.
+ q = q.Order("status.id DESC")
+ } else {
+ // Page up.
+ q = q.Order("status.id ASC")
+ }
+
+ if err := q.Scan(ctx, &statusIDs); err != nil {
+ return nil, t.conn.ProcessError(err)
+ }
+
+ if len(statusIDs) == 0 {
+ return nil, nil
+ }
+
+ // If we're paging up, we still want statuses
+ // to be sorted by ID desc, so reverse ids slice.
+ // https://zchee.github.io/golang-wiki/SliceTricks/#reversing
+ if !frontToBack {
+ for l, r := 0, len(statusIDs)-1; l < r; l, r = l+1, r-1 {
+ statusIDs[l], statusIDs[r] = statusIDs[r], statusIDs[l]
+ }
+ }
+
+ statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
+ for _, id := range statusIDs {
+ // Fetch status from db for ID
+ status, err := t.state.DB.GetStatusByID(ctx, id)
+ if err != nil {
+ log.Errorf(ctx, "error fetching status %q: %v", id, err)
+ continue
+ }
+
+ // Append status to slice
+ statuses = append(statuses, status)
+ }
+
+ return statuses, nil
+}
diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go
index f954c78dd..7e8fd0838 100644
--- a/internal/db/bundb/timeline_test.go
+++ b/internal/db/bundb/timeline_test.go
@@ -33,134 +33,243 @@ type TimelineTestSuite struct {
BunDBStandardTestSuite
}
-func (suite *TimelineTestSuite) TestGetPublicTimeline() {
- var count int
+func getFutureStatus() *gtsmodel.Status {
+ theDistantFuture := time.Now().Add(876600 * time.Hour)
+ id, err := id.NewULIDFromTime(theDistantFuture)
+ if err != nil {
+ panic(err)
+ }
+
+ return &gtsmodel.Status{
+ ID: id,
+ URI: "http://localhost:8080/users/admin/statuses/" + id,
+ URL: "http://localhost:8080/@admin/statuses/" + id,
+ Content: "it's the future, wooooooooooooooooooooooooooooooooo",
+ Text: "it's the future, wooooooooooooooooooooooooooooooooo",
+ AttachmentIDs: []string{},
+ TagIDs: []string{},
+ MentionIDs: []string{},
+ EmojiIDs: []string{},
+ CreatedAt: theDistantFuture,
+ UpdatedAt: theDistantFuture,
+ Local: testrig.TrueBool(),
+ AccountURI: "http://localhost:8080/users/admin",
+ AccountID: "01F8MH17FWEB39HZJ76B6VXSKF",
+ InReplyToID: "",
+ BoostOfID: "",
+ ContentWarning: "",
+ Visibility: gtsmodel.VisibilityPublic,
+ Sensitive: testrig.FalseBool(),
+ Language: "en",
+ CreatedWithApplicationID: "01F8MGXQRHYF5QPMTMXP78QC2F",
+ Federated: testrig.TrueBool(),
+ Boostable: testrig.TrueBool(),
+ Replyable: testrig.TrueBool(),
+ Likeable: testrig.TrueBool(),
+ ActivityStreamsType: ap.ObjectNote,
+ }
+}
+
+func (suite *TimelineTestSuite) publicCount() int {
+ var publicCount int
for _, status := range suite.testStatuses {
if status.Visibility == gtsmodel.VisibilityPublic &&
status.BoostOfID == "" {
- count++
+ publicCount++
}
}
- ctx := context.Background()
- s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
- suite.NoError(err)
-
- suite.Len(s, count)
+ return publicCount
}
-func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
- var count int
+func (suite *TimelineTestSuite) checkStatuses(statuses []*gtsmodel.Status, maxID string, minID string, expectedLength int) {
+ if l := len(statuses); l != expectedLength {
+ suite.FailNow("", "expected %d statuses in slice, got %d", expectedLength, l)
+ } else if l == 0 {
+ // Can't test empty slice.
+ return
+ }
- for _, status := range suite.testStatuses {
- if status.Visibility == gtsmodel.VisibilityPublic &&
- status.BoostOfID == "" {
- count++
+ // Check ordering + bounds of statuses.
+ highest := statuses[0].ID
+ for _, status := range statuses {
+ id := status.ID
+
+ if id >= maxID {
+ suite.FailNow("", "%s greater than maxID %s", id, maxID)
+ }
+
+ if id <= minID {
+ suite.FailNow("", "%s smaller than minID %s", id, minID)
+ }
+
+ if id > highest {
+ suite.FailNow("", "statuses in slice were not ordered highest -> lowest ID")
}
+
+ highest = id
}
+}
+func (suite *TimelineTestSuite) TestGetPublicTimeline() {
ctx := context.Background()
+ s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.checkStatuses(s, id.Highest, id.Lowest, suite.publicCount())
+}
+
+func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
+ ctx := context.Background()
+
+ // Insert a status set far in the
+ // future, it shouldn't be retrieved.
futureStatus := getFutureStatus()
- err := suite.db.PutStatus(ctx, futureStatus)
- suite.NoError(err)
+ if err := suite.db.PutStatus(ctx, futureStatus); err != nil {
+ suite.FailNow(err.Error())
+ }
s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
- suite.NoError(err)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
suite.NotContains(s, futureStatus)
- suite.Len(s, count)
+ suite.checkStatuses(s, id.Highest, id.Lowest, suite.publicCount())
}
func (suite *TimelineTestSuite) TestGetHomeTimeline() {
- ctx := context.Background()
-
- viewingAccount := suite.testAccounts["local_account_1"]
+ var (
+ ctx = context.Background()
+ viewingAccount = suite.testAccounts["local_account_1"]
+ )
s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false)
- suite.NoError(err)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
- suite.Len(s, 16)
+ suite.checkStatuses(s, id.Highest, id.Lowest, 16)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() {
- ctx := context.Background()
-
- viewingAccount := suite.testAccounts["local_account_1"]
+ var (
+ ctx = context.Background()
+ viewingAccount = suite.testAccounts["local_account_1"]
+ )
+ // Insert a status set far in the
+ // future, it shouldn't be retrieved.
futureStatus := getFutureStatus()
- err := suite.db.PutStatus(ctx, futureStatus)
- suite.NoError(err)
+ if err := suite.db.PutStatus(ctx, futureStatus); err != nil {
+ suite.FailNow(err.Error())
+ }
- s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false)
- suite.NoError(err)
+ s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
suite.NotContains(s, futureStatus)
- suite.Len(s, 16)
+ suite.checkStatuses(s, id.Highest, id.Lowest, 16)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineBackToFront() {
- ctx := context.Background()
-
- viewingAccount := suite.testAccounts["local_account_1"]
+ var (
+ ctx = context.Background()
+ viewingAccount = suite.testAccounts["local_account_1"]
+ )
s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", id.Lowest, 5, false)
- suite.NoError(err)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
- suite.Len(s, 5)
+ suite.checkStatuses(s, id.Highest, id.Lowest, 5)
suite.Equal("01F8MHAYFKS4KMXF8K5Y1C0KRN", s[0].ID)
suite.Equal("01F8MH75CBF9JFX4ZAD54N0W0R", s[len(s)-1].ID)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineFromHighest() {
- ctx := context.Background()
-
- viewingAccount := suite.testAccounts["local_account_1"]
+ var (
+ ctx = context.Background()
+ viewingAccount = suite.testAccounts["local_account_1"]
+ )
s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, id.Highest, "", "", 5, false)
- suite.NoError(err)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
- suite.Len(s, 5)
+ suite.checkStatuses(s, id.Highest, id.Lowest, 5)
suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID)
suite.Equal("01FCTA44PW9H1TB328S9AQXKDS", s[len(s)-1].ID)
}
-func getFutureStatus() *gtsmodel.Status {
- theDistantFuture := time.Now().Add(876600 * time.Hour)
- id, err := id.NewULIDFromTime(theDistantFuture)
+func (suite *TimelineTestSuite) TestGetListTimelineNoParams() {
+ var (
+ ctx = context.Background()
+ list = suite.testLists["local_account_1_list_1"]
+ )
+
+ s, err := suite.db.GetListTimeline(ctx, list.ID, "", "", "", 20)
if err != nil {
- panic(err)
+ suite.FailNow(err.Error())
}
- return &gtsmodel.Status{
- ID: id,
- URI: "http://localhost:8080/users/admin/statuses/" + id,
- URL: "http://localhost:8080/@admin/statuses/" + id,
- Content: "it's the future, wooooooooooooooooooooooooooooooooo",
- Text: "it's the future, wooooooooooooooooooooooooooooooooo",
- AttachmentIDs: []string{},
- TagIDs: []string{},
- MentionIDs: []string{},
- EmojiIDs: []string{},
- CreatedAt: theDistantFuture,
- UpdatedAt: theDistantFuture,
- Local: testrig.TrueBool(),
- AccountURI: "http://localhost:8080/users/admin",
- AccountID: "01F8MH17FWEB39HZJ76B6VXSKF",
- InReplyToID: "",
- BoostOfID: "",
- ContentWarning: "",
- Visibility: gtsmodel.VisibilityPublic,
- Sensitive: testrig.FalseBool(),
- Language: "en",
- CreatedWithApplicationID: "01F8MGXQRHYF5QPMTMXP78QC2F",
- Federated: testrig.TrueBool(),
- Boostable: testrig.TrueBool(),
- Replyable: testrig.TrueBool(),
- Likeable: testrig.TrueBool(),
- ActivityStreamsType: ap.ObjectNote,
+ suite.checkStatuses(s, id.Highest, id.Lowest, 11)
+}
+
+func (suite *TimelineTestSuite) TestGetListTimelineMaxID() {
+ var (
+ ctx = context.Background()
+ list = suite.testLists["local_account_1_list_1"]
+ )
+
+ s, err := suite.db.GetListTimeline(ctx, list.ID, id.Highest, "", "", 5)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.checkStatuses(s, id.Highest, id.Lowest, 5)
+ suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID)
+ suite.Equal("01FCQSQ667XHJ9AV9T27SJJSX5", s[len(s)-1].ID)
+}
+
+func (suite *TimelineTestSuite) TestGetListTimelineMinID() {
+ var (
+ ctx = context.Background()
+ list = suite.testLists["local_account_1_list_1"]
+ )
+
+ s, err := suite.db.GetListTimeline(ctx, list.ID, "", "", id.Lowest, 5)
+ if err != nil {
+ suite.FailNow(err.Error())
}
+
+ suite.checkStatuses(s, id.Highest, id.Lowest, 5)
+ suite.Equal("01F8MHC8VWDRBQR0N1BATDDEM5", s[0].ID)
+ suite.Equal("01F8MH75CBF9JFX4ZAD54N0W0R", s[len(s)-1].ID)
+}
+
+func (suite *TimelineTestSuite) TestGetListTimelineMinIDPagingUp() {
+ var (
+ ctx = context.Background()
+ list = suite.testLists["local_account_1_list_1"]
+ )
+
+ s, err := suite.db.GetListTimeline(ctx, list.ID, "", "", "01F8MHC8VWDRBQR0N1BATDDEM5", 5)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.checkStatuses(s, id.Highest, "01F8MHC8VWDRBQR0N1BATDDEM5", 5)
+ suite.Equal("01G20ZM733MGN8J344T4ZDDFY1", s[0].ID)
+ suite.Equal("01F8MHCP5P2NWYQ416SBA0XSEV", s[len(s)-1].ID)
}
func TestTimelineTestSuite(t *testing.T) {
diff --git a/internal/db/db.go b/internal/db/db.go
index 7b25b3dae..f47a35bb3 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -36,6 +36,7 @@ type DB interface {
Domain
Emoji
Instance
+ List
Media
Mention
Notification
diff --git a/internal/db/list.go b/internal/db/list.go
new file mode 100644
index 000000000..4472589dc
--- /dev/null
+++ b/internal/db/list.go
@@ -0,0 +1,67 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package db
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type List interface {
+ // GetListByID gets one list with the given id.
+ GetListByID(ctx context.Context, id string) (*gtsmodel.List, error)
+
+ // GetListsForAccountID gets all lists owned by the given accountID.
+ GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error)
+
+ // PopulateList ensures that the list's struct fields are populated.
+ PopulateList(ctx context.Context, list *gtsmodel.List) error
+
+ // PutList puts a new list in the database.
+ PutList(ctx context.Context, list *gtsmodel.List) error
+
+ // UpdateList updates the given list.
+ // Columns is optional, if not specified all will be updated.
+ UpdateList(ctx context.Context, list *gtsmodel.List, columns ...string) error
+
+ // DeleteListByID deletes one list with the given ID.
+ DeleteListByID(ctx context.Context, id string) error
+
+ // GetListEntryByID gets one list entry with the given ID.
+ GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error)
+
+ // GetListEntries gets list entries from the given listID, using the given parameters.
+ GetListEntries(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.ListEntry, error)
+
+ // GetListEntriesForFollowID returns all listEntries that pertain to the given followID.
+ GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error)
+
+ // PopulateListEntry ensures that the listEntry's struct fields are populated.
+ PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error
+
+ // PutListEntries inserts a slice of listEntries into the database.
+ // It uses a transaction to ensure no partial updates.
+ PutListEntries(ctx context.Context, listEntries []*gtsmodel.ListEntry) error
+
+ // DeleteListEntry deletes one list entry with the given id.
+ DeleteListEntry(ctx context.Context, id string) error
+
+ // DeleteListEntryForFollowID deletes all list entries with the given followID.
+ DeleteListEntriesForFollowID(ctx context.Context, followID string) error
+}
diff --git a/internal/db/relationship.go b/internal/db/relationship.go
index ae879b5d2..99093591c 100644
--- a/internal/db/relationship.go
+++ b/internal/db/relationship.go
@@ -64,6 +64,9 @@ type Relationship interface {
// GetFollow retrieves a follow if it exists between source and target accounts.
GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error)
+ // PopulateFollow populates the struct pointers on the given follow.
+ PopulateFollow(ctx context.Context, follow *gtsmodel.Follow) error
+
// GetFollowRequestByID fetches follow request with given ID from the database.
GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error)
diff --git a/internal/db/timeline.go b/internal/db/timeline.go
index 10149cc09..2635bece2 100644
--- a/internal/db/timeline.go
+++ b/internal/db/timeline.go
@@ -44,4 +44,8 @@ type Timeline interface {
//
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
+
+ // GetListTimeline returns a slice of statuses from followed accounts collected within the list with the given listID.
+ // Statuses should be returned in descending order of when they were created (newest first).
+ GetListTimeline(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Status, error)
}