diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/bundb/bundb.go | 5 | ||||
-rw-r--r-- | internal/db/bundb/bundb_test.go | 6 | ||||
-rw-r--r-- | internal/db/bundb/list.go | 467 | ||||
-rw-r--r-- | internal/db/bundb/list_test.go | 315 | ||||
-rw-r--r-- | internal/db/bundb/migrations/20230515173919_lists.go | 92 | ||||
-rw-r--r-- | internal/db/bundb/relationship_follow.go | 101 | ||||
-rw-r--r-- | internal/db/bundb/relationship_test.go | 15 | ||||
-rw-r--r-- | internal/db/bundb/timeline.go | 129 | ||||
-rw-r--r-- | internal/db/bundb/timeline_test.go | 253 | ||||
-rw-r--r-- | internal/db/db.go | 1 | ||||
-rw-r--r-- | internal/db/list.go | 67 | ||||
-rw-r--r-- | internal/db/relationship.go | 3 | ||||
-rw-r--r-- | internal/db/timeline.go | 4 |
13 files changed, 1347 insertions, 111 deletions
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index f095d1728..f0329e898 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -65,6 +65,7 @@ type DBService struct { db.Domain db.Emoji db.Instance + db.List db.Media db.Mention db.Notification @@ -179,6 +180,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { Instance: &instanceDB{ conn: conn, }, + List: &listDB{ + conn: conn, + state: state, + }, Media: &mediaDB{ conn: conn, state: state, diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go index 2566be2ba..84e11447a 100644 --- a/internal/db/bundb/bundb_test.go +++ b/internal/db/bundb/bundb_test.go @@ -22,6 +22,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -46,6 +47,8 @@ type BunDBStandardTestSuite struct { testReports map[string]*gtsmodel.Report testBookmarks map[string]*gtsmodel.StatusBookmark testFaves map[string]*gtsmodel.StatusFave + testLists map[string]*gtsmodel.List + testListEntries map[string]*gtsmodel.ListEntry } func (suite *BunDBStandardTestSuite) SetupSuite() { @@ -63,6 +66,8 @@ func (suite *BunDBStandardTestSuite) SetupSuite() { suite.testReports = testrig.NewTestReports() suite.testBookmarks = testrig.NewTestBookmarks() suite.testFaves = testrig.NewTestFaves() + suite.testLists = testrig.NewTestLists() + suite.testListEntries = testrig.NewTestListEntries() } func (suite *BunDBStandardTestSuite) SetupTest() { @@ -70,6 +75,7 @@ func (suite *BunDBStandardTestSuite) SetupTest() { testrig.InitTestLog() suite.state.Caches.Init() suite.db = testrig.NewTestDB(&suite.state) + testrig.StartTimelines(&suite.state, visibility.NewFilter(&suite.state), testrig.NewTestTypeConverter(suite.db)) testrig.StandardDBSetup(suite.db, suite.testAccounts) } diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go new file mode 100644 index 000000000..38701cc07 --- /dev/null +++ b/internal/db/bundb/list.go @@ -0,0 +1,467 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/uptrace/bun" +) + +type listDB struct { + conn *DBConn + state *state.State +} + +/* + LIST FUNCTIONS +*/ + +func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmodel.List) error, keyParts ...any) (*gtsmodel.List, error) { + list, err := l.state.Caches.GTS.List().Load(lookup, func() (*gtsmodel.List, error) { + var list gtsmodel.List + + // Not cached! Perform database query. + if err := dbQuery(&list); err != nil { + return nil, l.conn.ProcessError(err) + } + + return &list, nil + }, keyParts...) + if err != nil { + return nil, err // already processed + } + + if gtscontext.Barebones(ctx) { + // Only a barebones model was requested. + return list, nil + } + + if err := l.state.DB.PopulateList(ctx, list); err != nil { + return nil, err + } + + return list, nil +} + +func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, error) { + return l.getList( + ctx, + "ID", + func(list *gtsmodel.List) error { + return l.conn.NewSelect(). + Model(list). + Where("? = ?", bun.Ident("list.id"), id). + Scan(ctx) + }, + id, + ) +} + +func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) { + // Fetch IDs of all lists owned by this account. + var listIDs []string + if err := l.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("lists"), bun.Ident("list")). + Column("list.id"). + Where("? = ?", bun.Ident("list.account_id"), accountID). + Order("list.id DESC"). + Scan(ctx, &listIDs); err != nil { + return nil, l.conn.ProcessError(err) + } + + if len(listIDs) == 0 { + return nil, nil + } + + // Select each list using its ID to ensure cache used. + lists := make([]*gtsmodel.List, 0, len(listIDs)) + for _, id := range listIDs { + list, err := l.state.DB.GetListByID(ctx, id) + if err != nil { + log.Errorf(ctx, "error fetching list %q: %v", id, err) + continue + } + + // Append list. + lists = append(lists, list) + } + + return lists, nil +} + +func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error { + var ( + err error + errs = make(gtserror.MultiError, 0, 2) + ) + + if list.Account == nil { + // List account is not set, fetch from the database. + list.Account, err = l.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + list.AccountID, + ) + if err != nil { + errs.Append(fmt.Errorf("error populating list account: %w", err)) + } + } + + if list.ListEntries == nil { + // List entries are not set, fetch from the database. + list.ListEntries, err = l.state.DB.GetListEntries( + gtscontext.SetBarebones(ctx), + list.ID, + "", "", "", 0, + ) + if err != nil { + errs.Append(fmt.Errorf("error populating list entries: %w", err)) + } + } + + return errs.Combine() +} + +func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error { + return l.state.Caches.GTS.List().Store(list, func() error { + _, err := l.conn.NewInsert().Model(list).Exec(ctx) + return l.conn.ProcessError(err) + }) +} + +func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ...string) error { + list.UpdatedAt = time.Now() + if len(columns) > 0 { + // If we're updating by column, ensure "updated_at" is included. + columns = append(columns, "updated_at") + } + + return l.state.Caches.GTS.List().Store(list, func() error { + if _, err := l.conn.NewUpdate(). + Model(list). + Where("? = ?", bun.Ident("list.id"), list.ID). + Column(columns...). + Exec(ctx); err != nil { + return l.conn.ProcessError(err) + } + + return nil + }) +} + +func (l *listDB) DeleteListByID(ctx context.Context, id string) error { + defer l.state.Caches.GTS.List().Invalidate("ID", id) + + // Select all entries that belong to this list. + listEntries, err := l.state.DB.GetListEntries(ctx, id, "", "", "", 0) + if err != nil { + return fmt.Errorf("error selecting entries from list %q: %w", id, err) + } + + // Delete each list entry. This will + // invalidate the list timeline too. + for _, listEntry := range listEntries { + err := l.state.DB.DeleteListEntry(ctx, listEntry.ID) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return err + } + } + + // Finally delete list itself from DB. + _, err = l.conn.NewDelete(). + Table("lists"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx) + return l.conn.ProcessError(err) +} + +/* + LIST ENTRY functions +*/ + +func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) { + listEntry, err := l.state.Caches.GTS.ListEntry().Load(lookup, func() (*gtsmodel.ListEntry, error) { + var listEntry gtsmodel.ListEntry + + // Not cached! Perform database query. + if err := dbQuery(&listEntry); err != nil { + return nil, l.conn.ProcessError(err) + } + + return &listEntry, nil + }, keyParts...) + if err != nil { + return nil, err // already processed + } + + if gtscontext.Barebones(ctx) { + // Only a barebones model was requested. + return listEntry, nil + } + + // Further populate the list entry fields where applicable. + if err := l.state.DB.PopulateListEntry(ctx, listEntry); err != nil { + return nil, err + } + + return listEntry, nil +} + +func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error) { + return l.getListEntry( + ctx, + "ID", + func(listEntry *gtsmodel.ListEntry) error { + return l.conn.NewSelect(). + Model(listEntry). + Where("? = ?", bun.Ident("list_entry.id"), id). + Scan(ctx) + }, + id, + ) +} + +func (l *listDB) GetListEntries(ctx context.Context, + listID string, + maxID string, + sinceID string, + minID string, + limit int, +) ([]*gtsmodel.ListEntry, error) { + // Ensure reasonable + if limit < 0 { + limit = 0 + } + + // Make educated guess for slice size + var ( + entryIDs = make([]string, 0, limit) + frontToBack = true + ) + + q := l.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")). + // Select only IDs from table + Column("entry.id"). + // Select only entries belonging to listID. + Where("? = ?", bun.Ident("entry.list_id"), listID) + + if maxID != "" { + // return only entries LOWER (ie., older) than maxID + q = q.Where("? < ?", bun.Ident("entry.id"), maxID) + } + + if sinceID != "" { + // return only entries HIGHER (ie., newer) than sinceID + q = q.Where("? > ?", bun.Ident("entry.id"), sinceID) + } + + if minID != "" { + // return only entries HIGHER (ie., newer) than minID + q = q.Where("? > ?", bun.Ident("entry.id"), minID) + + // page up + frontToBack = false + } + + if limit > 0 { + // limit amount of entries returned + q = q.Limit(limit) + } + + if frontToBack { + // Page down. + q = q.Order("entry.id DESC") + } else { + // Page up. + q = q.Order("entry.id ASC") + } + + if err := q.Scan(ctx, &entryIDs); err != nil { + return nil, l.conn.ProcessError(err) + } + + if len(entryIDs) == 0 { + return nil, nil + } + + // If we're paging up, we still want entries + // to be sorted by ID desc, so reverse ids slice. + // https://zchee.github.io/golang-wiki/SliceTricks/#reversing + if !frontToBack { + for l, r := 0, len(entryIDs)-1; l < r; l, r = l+1, r-1 { + entryIDs[l], entryIDs[r] = entryIDs[r], entryIDs[l] + } + } + + // Select each list entry using its ID to ensure cache used. + listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs)) + for _, id := range entryIDs { + listEntry, err := l.state.DB.GetListEntryByID(ctx, id) + if err != nil { + log.Errorf(ctx, "error fetching list entry %q: %v", id, err) + continue + } + + // Append list entries. + listEntries = append(listEntries, listEntry) + } + + return listEntries, nil +} + +func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) { + entryIDs := []string{} + + if err := l.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")). + // Select only IDs from table + Column("entry.id"). + // Select only entries belonging with given followID. + Where("? = ?", bun.Ident("entry.follow_id"), followID). + Scan(ctx, &entryIDs); err != nil { + return nil, l.conn.ProcessError(err) + } + + if len(entryIDs) == 0 { + return nil, nil + } + + // Select each list entry using its ID to ensure cache used. + listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs)) + for _, id := range entryIDs { + listEntry, err := l.state.DB.GetListEntryByID(ctx, id) + if err != nil { + log.Errorf(ctx, "error fetching list entry %q: %v", id, err) + continue + } + + // Append list entries. + listEntries = append(listEntries, listEntry) + } + + return listEntries, nil +} + +func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error { + var err error + + if listEntry.Follow == nil { + // ListEntry follow is not set, fetch from the database. + listEntry.Follow, err = l.state.DB.GetFollowByID( + gtscontext.SetBarebones(ctx), + listEntry.FollowID, + ) + if err != nil { + return fmt.Errorf("error populating listEntry follow: %w", err) + } + } + + return nil +} + +func (l *listDB) PutListEntries(ctx context.Context, listEntries []*gtsmodel.ListEntry) error { + return l.conn.RunInTx(ctx, func(tx bun.Tx) error { + for _, listEntry := range listEntries { + if _, err := tx. + NewInsert(). + Model(listEntry). + Exec(ctx); err != nil { + return err + } + + // Invalidate the timeline for the list this entry belongs to. + if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil { + log.Errorf(ctx, "PutListEntries: error invalidating list timeline: %q", err) + } + } + + return nil + }) +} + +func (l *listDB) DeleteListEntry(ctx context.Context, id string) error { + defer l.state.Caches.GTS.ListEntry().Invalidate("ID", id) + + // Load list entry into cache before attempting a delete, + // as we need the followID from it in order to trigger + // timeline invalidation. + listEntry, err := l.GetListEntryByID( + // Don't populate the entry; + // we only want the list ID. + gtscontext.SetBarebones(ctx), + id, + ) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + // Already gone. + return nil + } + return err + } + + defer func() { + // Invalidate the timeline for the list this entry belongs to. + if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil { + log.Errorf(ctx, "DeleteListEntry: error invalidating list timeline: %q", err) + } + }() + + if _, err := l.conn.NewDelete(). + Table("list_entries"). + Where("? = ?", bun.Ident("id"), listEntry.ID). + Exec(ctx); err != nil { + return l.conn.ProcessError(err) + } + + return nil +} + +func (l *listDB) DeleteListEntriesForFollowID(ctx context.Context, followID string) error { + // Fetch IDs of all entries that pertain to this follow. + var listEntryIDs []string + if err := l.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("list_entry")). + Column("list_entry.id"). + Where("? = ?", bun.Ident("list_entry.follow_id"), followID). + Order("list_entry.id DESC"). + Scan(ctx, &listEntryIDs); err != nil { + return l.conn.ProcessError(err) + } + + for _, id := range listEntryIDs { + if err := l.DeleteListEntry(ctx, id); err != nil { + return err + } + } + + return nil +} diff --git a/internal/db/bundb/list_test.go b/internal/db/bundb/list_test.go new file mode 100644 index 000000000..296ab7c1a --- /dev/null +++ b/internal/db/bundb/list_test.go @@ -0,0 +1,315 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "golang.org/x/exp/slices" +) + +type ListTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *ListTestSuite) testStructs() (*gtsmodel.List, *gtsmodel.Account) { + testList := >smodel.List{} + *testList = *suite.testLists["local_account_1_list_1"] + + // Populate entries on this list as we'd expect them back from the db. + entries := make([]*gtsmodel.ListEntry, 0, len(suite.testListEntries)) + for _, entry := range suite.testListEntries { + entries = append(entries, entry) + } + + // Sort by ID descending (again, as we'd expect from the db). + slices.SortFunc(entries, func(a, b *gtsmodel.ListEntry) bool { + return b.ID < a.ID + }) + + testList.ListEntries = entries + + testAccount := >smodel.Account{} + *testAccount = *suite.testAccounts["local_account_1"] + + return testList, testAccount +} + +func (suite *ListTestSuite) checkList(expected *gtsmodel.List, actual *gtsmodel.List) { + suite.Equal(expected.ID, actual.ID) + suite.Equal(expected.Title, actual.Title) + suite.Equal(expected.AccountID, actual.AccountID) + suite.Equal(expected.RepliesPolicy, actual.RepliesPolicy) + suite.NotNil(actual.Account) +} + +func (suite *ListTestSuite) checkListEntry(expected *gtsmodel.ListEntry, actual *gtsmodel.ListEntry) { + suite.Equal(expected.ID, actual.ID) + suite.Equal(expected.ListID, actual.ListID) + suite.Equal(expected.FollowID, actual.FollowID) +} + +func (suite *ListTestSuite) checkListEntries(expected []*gtsmodel.ListEntry, actual []*gtsmodel.ListEntry) { + var ( + lExpected = len(expected) + lActual = len(actual) + ) + + if lExpected != lActual { + suite.FailNow("", "expected %d list entries, got %d", lExpected, lActual) + } + + var topID string + for i, expectedEntry := range expected { + actualEntry := actual[i] + + // Ensure ID descending. + if topID == "" { + topID = actualEntry.ID + } else { + suite.Less(actualEntry.ID, topID) + } + + suite.checkListEntry(expectedEntry, actualEntry) + } +} + +func (suite *ListTestSuite) TestGetListByID() { + testList, _ := suite.testStructs() + + dbList, err := suite.db.GetListByID(context.Background(), testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkList(testList, dbList) + suite.checkListEntries(testList.ListEntries, dbList.ListEntries) +} + +func (suite *ListTestSuite) TestGetListsForAccountID() { + testList, testAccount := suite.testStructs() + + dbLists, err := suite.db.GetListsForAccountID(context.Background(), testAccount.ID) + if err != nil { + suite.FailNow(err.Error()) + } + + if l := len(dbLists); l != 1 { + suite.FailNow("", "expected %d lists, got %d", 1, l) + } + + suite.checkList(testList, dbLists[0]) +} + +func (suite *ListTestSuite) TestGetListEntries() { + testList, _ := suite.testStructs() + + dbListEntries, err := suite.db.GetListEntries(context.Background(), testList.ID, "", "", "", 0) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkListEntries(testList.ListEntries, dbListEntries) +} + +func (suite *ListTestSuite) TestPutList() { + ctx := context.Background() + _, testAccount := suite.testStructs() + + testList := >smodel.List{ + ID: "01H0J2PMYM54618VCV8Y8QYAT4", + Title: "Test List!", + AccountID: testAccount.ID, + } + + if err := suite.db.PutList(ctx, testList); err != nil { + suite.FailNow(err.Error()) + } + + dbList, err := suite.db.GetListByID(ctx, testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } + + // Bodge testlist as though default had been set. + testList.RepliesPolicy = gtsmodel.RepliesPolicyFollowed + suite.checkList(testList, dbList) +} + +func (suite *ListTestSuite) TestUpdateList() { + ctx := context.Background() + testList, _ := suite.testStructs() + + // Get List in the cache first. + dbList, err := suite.db.GetListByID(ctx, testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } + + // Now do the update. + testList.Title = "New Title!" + if err := suite.db.UpdateList(ctx, testList, "title"); err != nil { + suite.FailNow(err.Error()) + } + + // Cache should be invalidated + // + we should have updated list. + dbList, err = suite.db.GetListByID(ctx, testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkList(testList, dbList) +} + +func (suite *ListTestSuite) TestDeleteList() { + ctx := context.Background() + testList, _ := suite.testStructs() + + // Get List in the cache first. + if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil { + suite.FailNow(err.Error()) + } + + // Now do the delete. + if err := suite.db.DeleteListByID(ctx, testList.ID); err != nil { + suite.FailNow(err.Error()) + } + + // Cache should be invalidated + // + we should have no list. + _, err := suite.db.GetListByID(ctx, testList.ID) + suite.ErrorIs(err, db.ErrNoEntries) + + // All entries belonging to this + // list should now be deleted. + listEntries, err := suite.db.GetListEntries(ctx, testList.ID, "", "", "", 0) + if err != nil { + suite.FailNow(err.Error()) + } + suite.Empty(listEntries) +} + +func (suite *ListTestSuite) TestPutListEntries() { + ctx := context.Background() + testList, _ := suite.testStructs() + + listEntries := []*gtsmodel.ListEntry{ + { + ID: "01H0MKMQY69HWDSDR2SWGA17R4", + ListID: testList.ID, + FollowID: "01H0MKNFRFZS8R9WV6DBX31Y03", // random id, doesn't exist + }, + { + ID: "01H0MKPGQF0E7QAVW5BKTHZ630", + ListID: testList.ID, + FollowID: "01H0MKP6RR8VEHN3GVWFBP2H30", // random id, doesn't exist + }, + { + ID: "01H0MKPPP2DT68FRBMR1FJM32T", + ListID: testList.ID, + FollowID: "01H0MKQ0KA29C6NFJ27GTZD16J", // random id, doesn't exist + }, + } + + if err := suite.db.PutListEntries(ctx, listEntries); err != nil { + suite.FailNow(err.Error()) + } + + // Add these entries to the test list, sort it again + // to reflect what we'd expect to get from the db. + testList.ListEntries = append(testList.ListEntries, listEntries...) + slices.SortFunc(testList.ListEntries, func(a, b *gtsmodel.ListEntry) bool { + return b.ID < a.ID + }) + + // Now get all list entries from the db. + // Use barebones for this because the ones + // we just added will fail if we try to get + // the nonexistent follows. + dbListEntries, err := suite.db.GetListEntries( + gtscontext.SetBarebones(ctx), + testList.ID, + "", "", "", 0) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkListEntries(testList.ListEntries, dbListEntries) +} + +func (suite *ListTestSuite) TestDeleteListEntry() { + ctx := context.Background() + testList, _ := suite.testStructs() + + // Get List in the cache first. + if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil { + suite.FailNow(err.Error()) + } + + // Delete the first entry. + if err := suite.db.DeleteListEntry(ctx, testList.ListEntries[0].ID); err != nil { + suite.FailNow(err.Error()) + } + + // Get list from the db again. + dbList, err := suite.db.GetListByID(ctx, testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } + + // Bodge the testlist as though + // we'd removed the first entry. + testList.ListEntries = testList.ListEntries[1:] + suite.checkList(testList, dbList) +} + +func (suite *ListTestSuite) TestDeleteListEntriesForFollowID() { + ctx := context.Background() + testList, _ := suite.testStructs() + + // Get List in the cache first. + if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil { + suite.FailNow(err.Error()) + } + + // Delete the first entry. + if err := suite.db.DeleteListEntriesForFollowID(ctx, testList.ListEntries[0].FollowID); err != nil { + suite.FailNow(err.Error()) + } + + // Get list from the db again. + dbList, err := suite.db.GetListByID(ctx, testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } + + // Bodge the testlist as though + // we'd removed the first entry. + testList.ListEntries = testList.ListEntries[1:] + suite.checkList(testList, dbList) +} + +func TestListTestSuite(t *testing.T) { + suite.Run(t, new(ListTestSuite)) +} diff --git a/internal/db/bundb/migrations/20230515173919_lists.go b/internal/db/bundb/migrations/20230515173919_lists.go new file mode 100644 index 000000000..e0ea5c7b6 --- /dev/null +++ b/internal/db/bundb/migrations/20230515173919_lists.go @@ -0,0 +1,92 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( + "context" + + gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + // List table. + if _, err := tx. + NewCreateTable(). + Model(>smodel.List{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + + // Add indexes to the List table. + for index, columns := range map[string][]string{ + "lists_id_idx": {"id"}, + "lists_account_id_idx": {"account_id"}, + } { + if _, err := tx. + NewCreateIndex(). + Table("lists"). + Index(index). + Column(columns...). + Exec(ctx); err != nil { + return err + } + } + + // List entry table. + if _, err := tx. + NewCreateTable(). + Model(>smodel.ListEntry{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + + // Add indexes to the List entry table. + for index, columns := range map[string][]string{ + "list_entries_id_idx": {"id"}, + "list_entries_list_id_idx": {"list_id"}, + "list_entries_follow_id_idx": {"follow_id"}, + } { + if _, err := tx. + NewCreateIndex(). + Table("list_entries"). + Index(index). + Column(columns...). + Exec(ctx); err != nil { + return err + } + } + + return nil + }) + } + + down := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + return nil + }) + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go index fe1f26bf1..39b85075c 100644 --- a/internal/db/bundb/relationship_follow.go +++ b/internal/db/bundb/relationship_follow.go @@ -25,6 +25,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/uptrace/bun" @@ -149,25 +150,42 @@ func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery f return follow, nil } - // Set the follow source account - follow.Account, err = r.state.DB.GetAccountByID( - gtscontext.SetBarebones(ctx), - follow.AccountID, - ) - if err != nil { - return nil, fmt.Errorf("error getting follow source account: %w", err) + if err := r.state.DB.PopulateFollow(ctx, follow); err != nil { + return nil, err } - // Set the follow target account - follow.TargetAccount, err = r.state.DB.GetAccountByID( - gtscontext.SetBarebones(ctx), - follow.TargetAccountID, + return follow, nil +} + +func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Follow) error { + var ( + err error + errs = make(gtserror.MultiError, 0, 2) ) - if err != nil { - return nil, fmt.Errorf("error getting follow target account: %w", err) + + if follow.Account == nil { + // Follow account is not set, fetch from the database. + follow.Account, err = r.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + follow.AccountID, + ) + if err != nil { + errs.Append(fmt.Errorf("error populating follow account: %w", err)) + } } - return follow, nil + if follow.TargetAccount == nil { + // Follow target account is not set, fetch from the database. + follow.TargetAccount, err = r.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + follow.TargetAccountID, + ) + if err != nil { + errs.Append(fmt.Errorf("error populating follow target account: %w", err)) + } + } + + return errs.Combine() } func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error { @@ -197,27 +215,40 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll }) } +func (r *relationshipDB) deleteFollow(ctx context.Context, id string) error { + // Delete the follow itself using the given ID. + if _, err := r.conn.NewDelete(). + Table("follows"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return r.conn.ProcessError(err) + } + + // Delete every list entry that used this followID. + if err := r.state.DB.DeleteListEntriesForFollowID(ctx, id); err != nil { + return fmt.Errorf("deleteFollow: error deleting list entries: %w", err) + } + + return nil +} + func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error { defer r.state.Caches.GTS.Follow().Invalidate("ID", id) // Load follow into cache before attempting a delete, // as we need it cached in order to trigger the invalidate // callback. This in turn invalidates others. - _, err := r.GetFollowByID(gtscontext.SetBarebones(ctx), id) + follow, err := r.GetFollowByID(gtscontext.SetBarebones(ctx), id) if err != nil { if errors.Is(err, db.ErrNoEntries) { - // not an issue. - err = nil + // Already gone. + return nil } return err } // Finally delete follow from DB. - _, err = r.conn.NewDelete(). - Table("follows"). - Where("? = ?", bun.Ident("id"), id). - Exec(ctx) - return r.conn.ProcessError(err) + return r.deleteFollow(ctx, follow.ID) } func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) error { @@ -226,21 +257,17 @@ func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) erro // Load follow into cache before attempting a delete, // as we need it cached in order to trigger the invalidate // callback. This in turn invalidates others. - _, err := r.GetFollowByURI(gtscontext.SetBarebones(ctx), uri) + follow, err := r.GetFollowByURI(gtscontext.SetBarebones(ctx), uri) if err != nil { if errors.Is(err, db.ErrNoEntries) { - // not an issue. - err = nil + // Already gone. + return nil } return err } // Finally delete follow from DB. - _, err = r.conn.NewDelete(). - Table("follows"). - Where("? = ?", bun.Ident("uri"), uri). - Exec(ctx) - return r.conn.ProcessError(err) + return r.deleteFollow(ctx, follow.ID) } func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID string) error { @@ -272,16 +299,16 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str // but it is the only way we can ensure we invalidate all // related caches correctly (e.g. visibility). for _, id := range followIDs { - _, err := r.GetFollowByID(ctx, id) + follow, err := r.GetFollowByID(ctx, id) if err != nil && !errors.Is(err, db.ErrNoEntries) { return err } + + // Delete each follow from DB. + if err := r.deleteFollow(ctx, follow.ID); err != nil && !errors.Is(err, db.ErrNoEntries) { + return err + } } - // Finally delete all from DB. - _, err := r.conn.NewDelete(). - Table("follows"). - Where("? IN (?)", bun.Ident("id"), bun.In(followIDs)). - Exec(ctx) - return r.conn.ProcessError(err) + return nil } diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go index 0e38d19fe..63fdb9632 100644 --- a/internal/db/bundb/relationship_test.go +++ b/internal/db/bundb/relationship_test.go @@ -807,16 +807,27 @@ func (suite *RelationshipTestSuite) TestUnfollowExisting() { follow, err := suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID) suite.NoError(err) suite.NotNil(follow) + followID := follow.ID - err = suite.db.DeleteFollowByID(context.Background(), follow.ID) + // We should have list entries for this follow. + listEntries, err := suite.db.GetListEntriesForFollowID(context.Background(), followID) + suite.NoError(err) + suite.NotEmpty(listEntries) + + err = suite.db.DeleteFollowByID(context.Background(), followID) suite.NoError(err) follow, err = suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID) suite.EqualError(err, db.ErrNoEntries.Error()) suite.Nil(follow) + + // ListEntries pertaining to this follow should be deleted too. + listEntries, err = suite.db.GetListEntriesForFollowID(context.Background(), followID) + suite.NoError(err) + suite.Empty(listEntries) } -func (suite *RelationshipTestSuite) TestUnfollowNotExisting() { +func (suite *RelationshipTestSuite) TestGetFollowNotExisting() { originAccount := suite.testAccounts["local_account_1"] targetAccountID := "01GTVD9N484CZ6AM90PGGNY7GQ" diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go index 87e7751d2..d33840a7b 100644 --- a/internal/db/bundb/timeline.go +++ b/internal/db/bundb/timeline.go @@ -19,9 +19,11 @@ package bundb import ( "context" + "fmt" "time" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/log" @@ -281,3 +283,130 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max prevMinID := faves[0].ID return statuses, nextMaxID, prevMinID, nil } + +func (t *timelineDB) GetListTimeline( + ctx context.Context, + listID string, + maxID string, + sinceID string, + minID string, + limit int, +) ([]*gtsmodel.Status, error) { + // Ensure reasonable + if limit < 0 { + limit = 0 + } + + // Make educated guess for slice size + var ( + statusIDs = make([]string, 0, limit) + frontToBack = true + ) + + // Fetch all listEntries entries from the database. + listEntries, err := t.state.DB.GetListEntries( + // Don't need actual follows + // for this, just the IDs. + gtscontext.SetBarebones(ctx), + listID, + "", "", "", 0, + ) + if err != nil { + return nil, fmt.Errorf("error getting entries for list %s: %w", listID, err) + } + + // Extract just the IDs of each follow. + followIDs := make([]string, 0, len(listEntries)) + for _, listEntry := range listEntries { + followIDs = append(followIDs, listEntry.FollowID) + } + + // Select target account IDs from follows. + subQ := t.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). + Column("follow.target_account_id"). + Where("? IN (?)", bun.Ident("follow.id"), bun.In(followIDs)) + + // Select only status IDs created + // by one of the followed accounts. + q := t.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + // Select only IDs from table + Column("status.id"). + Where("? IN (?)", bun.Ident("status.account_id"), subQ) + + if maxID == "" || maxID >= id.Highest { + const future = 24 * time.Hour + + var err error + + // don't return statuses more than 24hr in the future + maxID, err = id.NewULIDFromTime(time.Now().Add(future)) + if err != nil { + return nil, err + } + } + + // return only statuses LOWER (ie., older) than maxID + q = q.Where("? < ?", bun.Ident("status.id"), maxID) + + if sinceID != "" { + // return only statuses HIGHER (ie., newer) than sinceID + q = q.Where("? > ?", bun.Ident("status.id"), sinceID) + } + + if minID != "" { + // return only statuses HIGHER (ie., newer) than minID + q = q.Where("? > ?", bun.Ident("status.id"), minID) + + // page up + frontToBack = false + } + + if limit > 0 { + // limit amount of statuses returned + q = q.Limit(limit) + } + + if frontToBack { + // Page down. + q = q.Order("status.id DESC") + } else { + // Page up. + q = q.Order("status.id ASC") + } + + if err := q.Scan(ctx, &statusIDs); err != nil { + return nil, t.conn.ProcessError(err) + } + + if len(statusIDs) == 0 { + return nil, nil + } + + // If we're paging up, we still want statuses + // to be sorted by ID desc, so reverse ids slice. + // https://zchee.github.io/golang-wiki/SliceTricks/#reversing + if !frontToBack { + for l, r := 0, len(statusIDs)-1; l < r; l, r = l+1, r-1 { + statusIDs[l], statusIDs[r] = statusIDs[r], statusIDs[l] + } + } + + statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) + for _, id := range statusIDs { + // Fetch status from db for ID + status, err := t.state.DB.GetStatusByID(ctx, id) + if err != nil { + log.Errorf(ctx, "error fetching status %q: %v", id, err) + continue + } + + // Append status to slice + statuses = append(statuses, status) + } + + return statuses, nil +} diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go index f954c78dd..7e8fd0838 100644 --- a/internal/db/bundb/timeline_test.go +++ b/internal/db/bundb/timeline_test.go @@ -33,134 +33,243 @@ type TimelineTestSuite struct { BunDBStandardTestSuite } -func (suite *TimelineTestSuite) TestGetPublicTimeline() { - var count int +func getFutureStatus() *gtsmodel.Status { + theDistantFuture := time.Now().Add(876600 * time.Hour) + id, err := id.NewULIDFromTime(theDistantFuture) + if err != nil { + panic(err) + } + + return >smodel.Status{ + ID: id, + URI: "http://localhost:8080/users/admin/statuses/" + id, + URL: "http://localhost:8080/@admin/statuses/" + id, + Content: "it's the future, wooooooooooooooooooooooooooooooooo", + Text: "it's the future, wooooooooooooooooooooooooooooooooo", + AttachmentIDs: []string{}, + TagIDs: []string{}, + MentionIDs: []string{}, + EmojiIDs: []string{}, + CreatedAt: theDistantFuture, + UpdatedAt: theDistantFuture, + Local: testrig.TrueBool(), + AccountURI: "http://localhost:8080/users/admin", + AccountID: "01F8MH17FWEB39HZJ76B6VXSKF", + InReplyToID: "", + BoostOfID: "", + ContentWarning: "", + Visibility: gtsmodel.VisibilityPublic, + Sensitive: testrig.FalseBool(), + Language: "en", + CreatedWithApplicationID: "01F8MGXQRHYF5QPMTMXP78QC2F", + Federated: testrig.TrueBool(), + Boostable: testrig.TrueBool(), + Replyable: testrig.TrueBool(), + Likeable: testrig.TrueBool(), + ActivityStreamsType: ap.ObjectNote, + } +} + +func (suite *TimelineTestSuite) publicCount() int { + var publicCount int for _, status := range suite.testStatuses { if status.Visibility == gtsmodel.VisibilityPublic && status.BoostOfID == "" { - count++ + publicCount++ } } - ctx := context.Background() - s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false) - suite.NoError(err) - - suite.Len(s, count) + return publicCount } -func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() { - var count int +func (suite *TimelineTestSuite) checkStatuses(statuses []*gtsmodel.Status, maxID string, minID string, expectedLength int) { + if l := len(statuses); l != expectedLength { + suite.FailNow("", "expected %d statuses in slice, got %d", expectedLength, l) + } else if l == 0 { + // Can't test empty slice. + return + } - for _, status := range suite.testStatuses { - if status.Visibility == gtsmodel.VisibilityPublic && - status.BoostOfID == "" { - count++ + // Check ordering + bounds of statuses. + highest := statuses[0].ID + for _, status := range statuses { + id := status.ID + + if id >= maxID { + suite.FailNow("", "%s greater than maxID %s", id, maxID) + } + + if id <= minID { + suite.FailNow("", "%s smaller than minID %s", id, minID) + } + + if id > highest { + suite.FailNow("", "statuses in slice were not ordered highest -> lowest ID") } + + highest = id } +} +func (suite *TimelineTestSuite) TestGetPublicTimeline() { ctx := context.Background() + s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkStatuses(s, id.Highest, id.Lowest, suite.publicCount()) +} + +func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() { + ctx := context.Background() + + // Insert a status set far in the + // future, it shouldn't be retrieved. futureStatus := getFutureStatus() - err := suite.db.PutStatus(ctx, futureStatus) - suite.NoError(err) + if err := suite.db.PutStatus(ctx, futureStatus); err != nil { + suite.FailNow(err.Error()) + } s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false) - suite.NoError(err) + if err != nil { + suite.FailNow(err.Error()) + } suite.NotContains(s, futureStatus) - suite.Len(s, count) + suite.checkStatuses(s, id.Highest, id.Lowest, suite.publicCount()) } func (suite *TimelineTestSuite) TestGetHomeTimeline() { - ctx := context.Background() - - viewingAccount := suite.testAccounts["local_account_1"] + var ( + ctx = context.Background() + viewingAccount = suite.testAccounts["local_account_1"] + ) s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false) - suite.NoError(err) + if err != nil { + suite.FailNow(err.Error()) + } - suite.Len(s, 16) + suite.checkStatuses(s, id.Highest, id.Lowest, 16) } func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() { - ctx := context.Background() - - viewingAccount := suite.testAccounts["local_account_1"] + var ( + ctx = context.Background() + viewingAccount = suite.testAccounts["local_account_1"] + ) + // Insert a status set far in the + // future, it shouldn't be retrieved. futureStatus := getFutureStatus() - err := suite.db.PutStatus(ctx, futureStatus) - suite.NoError(err) + if err := suite.db.PutStatus(ctx, futureStatus); err != nil { + suite.FailNow(err.Error()) + } - s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false) - suite.NoError(err) + s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false) + if err != nil { + suite.FailNow(err.Error()) + } suite.NotContains(s, futureStatus) - suite.Len(s, 16) + suite.checkStatuses(s, id.Highest, id.Lowest, 16) } func (suite *TimelineTestSuite) TestGetHomeTimelineBackToFront() { - ctx := context.Background() - - viewingAccount := suite.testAccounts["local_account_1"] + var ( + ctx = context.Background() + viewingAccount = suite.testAccounts["local_account_1"] + ) s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", id.Lowest, 5, false) - suite.NoError(err) + if err != nil { + suite.FailNow(err.Error()) + } - suite.Len(s, 5) + suite.checkStatuses(s, id.Highest, id.Lowest, 5) suite.Equal("01F8MHAYFKS4KMXF8K5Y1C0KRN", s[0].ID) suite.Equal("01F8MH75CBF9JFX4ZAD54N0W0R", s[len(s)-1].ID) } func (suite *TimelineTestSuite) TestGetHomeTimelineFromHighest() { - ctx := context.Background() - - viewingAccount := suite.testAccounts["local_account_1"] + var ( + ctx = context.Background() + viewingAccount = suite.testAccounts["local_account_1"] + ) s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, id.Highest, "", "", 5, false) - suite.NoError(err) + if err != nil { + suite.FailNow(err.Error()) + } - suite.Len(s, 5) + suite.checkStatuses(s, id.Highest, id.Lowest, 5) suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID) suite.Equal("01FCTA44PW9H1TB328S9AQXKDS", s[len(s)-1].ID) } -func getFutureStatus() *gtsmodel.Status { - theDistantFuture := time.Now().Add(876600 * time.Hour) - id, err := id.NewULIDFromTime(theDistantFuture) +func (suite *TimelineTestSuite) TestGetListTimelineNoParams() { + var ( + ctx = context.Background() + list = suite.testLists["local_account_1_list_1"] + ) + + s, err := suite.db.GetListTimeline(ctx, list.ID, "", "", "", 20) if err != nil { - panic(err) + suite.FailNow(err.Error()) } - return >smodel.Status{ - ID: id, - URI: "http://localhost:8080/users/admin/statuses/" + id, - URL: "http://localhost:8080/@admin/statuses/" + id, - Content: "it's the future, wooooooooooooooooooooooooooooooooo", - Text: "it's the future, wooooooooooooooooooooooooooooooooo", - AttachmentIDs: []string{}, - TagIDs: []string{}, - MentionIDs: []string{}, - EmojiIDs: []string{}, - CreatedAt: theDistantFuture, - UpdatedAt: theDistantFuture, - Local: testrig.TrueBool(), - AccountURI: "http://localhost:8080/users/admin", - AccountID: "01F8MH17FWEB39HZJ76B6VXSKF", - InReplyToID: "", - BoostOfID: "", - ContentWarning: "", - Visibility: gtsmodel.VisibilityPublic, - Sensitive: testrig.FalseBool(), - Language: "en", - CreatedWithApplicationID: "01F8MGXQRHYF5QPMTMXP78QC2F", - Federated: testrig.TrueBool(), - Boostable: testrig.TrueBool(), - Replyable: testrig.TrueBool(), - Likeable: testrig.TrueBool(), - ActivityStreamsType: ap.ObjectNote, + suite.checkStatuses(s, id.Highest, id.Lowest, 11) +} + +func (suite *TimelineTestSuite) TestGetListTimelineMaxID() { + var ( + ctx = context.Background() + list = suite.testLists["local_account_1_list_1"] + ) + + s, err := suite.db.GetListTimeline(ctx, list.ID, id.Highest, "", "", 5) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkStatuses(s, id.Highest, id.Lowest, 5) + suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID) + suite.Equal("01FCQSQ667XHJ9AV9T27SJJSX5", s[len(s)-1].ID) +} + +func (suite *TimelineTestSuite) TestGetListTimelineMinID() { + var ( + ctx = context.Background() + list = suite.testLists["local_account_1_list_1"] + ) + + s, err := suite.db.GetListTimeline(ctx, list.ID, "", "", id.Lowest, 5) + if err != nil { + suite.FailNow(err.Error()) } + + suite.checkStatuses(s, id.Highest, id.Lowest, 5) + suite.Equal("01F8MHC8VWDRBQR0N1BATDDEM5", s[0].ID) + suite.Equal("01F8MH75CBF9JFX4ZAD54N0W0R", s[len(s)-1].ID) +} + +func (suite *TimelineTestSuite) TestGetListTimelineMinIDPagingUp() { + var ( + ctx = context.Background() + list = suite.testLists["local_account_1_list_1"] + ) + + s, err := suite.db.GetListTimeline(ctx, list.ID, "", "", "01F8MHC8VWDRBQR0N1BATDDEM5", 5) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkStatuses(s, id.Highest, "01F8MHC8VWDRBQR0N1BATDDEM5", 5) + suite.Equal("01G20ZM733MGN8J344T4ZDDFY1", s[0].ID) + suite.Equal("01F8MHCP5P2NWYQ416SBA0XSEV", s[len(s)-1].ID) } func TestTimelineTestSuite(t *testing.T) { diff --git a/internal/db/db.go b/internal/db/db.go index 7b25b3dae..f47a35bb3 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -36,6 +36,7 @@ type DB interface { Domain Emoji Instance + List Media Mention Notification diff --git a/internal/db/list.go b/internal/db/list.go new file mode 100644 index 000000000..4472589dc --- /dev/null +++ b/internal/db/list.go @@ -0,0 +1,67 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package db + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type List interface { + // GetListByID gets one list with the given id. + GetListByID(ctx context.Context, id string) (*gtsmodel.List, error) + + // GetListsForAccountID gets all lists owned by the given accountID. + GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) + + // PopulateList ensures that the list's struct fields are populated. + PopulateList(ctx context.Context, list *gtsmodel.List) error + + // PutList puts a new list in the database. + PutList(ctx context.Context, list *gtsmodel.List) error + + // UpdateList updates the given list. + // Columns is optional, if not specified all will be updated. + UpdateList(ctx context.Context, list *gtsmodel.List, columns ...string) error + + // DeleteListByID deletes one list with the given ID. + DeleteListByID(ctx context.Context, id string) error + + // GetListEntryByID gets one list entry with the given ID. + GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error) + + // GetListEntries gets list entries from the given listID, using the given parameters. + GetListEntries(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.ListEntry, error) + + // GetListEntriesForFollowID returns all listEntries that pertain to the given followID. + GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) + + // PopulateListEntry ensures that the listEntry's struct fields are populated. + PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error + + // PutListEntries inserts a slice of listEntries into the database. + // It uses a transaction to ensure no partial updates. + PutListEntries(ctx context.Context, listEntries []*gtsmodel.ListEntry) error + + // DeleteListEntry deletes one list entry with the given id. + DeleteListEntry(ctx context.Context, id string) error + + // DeleteListEntryForFollowID deletes all list entries with the given followID. + DeleteListEntriesForFollowID(ctx context.Context, followID string) error +} diff --git a/internal/db/relationship.go b/internal/db/relationship.go index ae879b5d2..99093591c 100644 --- a/internal/db/relationship.go +++ b/internal/db/relationship.go @@ -64,6 +64,9 @@ type Relationship interface { // GetFollow retrieves a follow if it exists between source and target accounts. GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) + // PopulateFollow populates the struct pointers on the given follow. + PopulateFollow(ctx context.Context, follow *gtsmodel.Follow) error + // GetFollowRequestByID fetches follow request with given ID from the database. GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error) diff --git a/internal/db/timeline.go b/internal/db/timeline.go index 10149cc09..2635bece2 100644 --- a/internal/db/timeline.go +++ b/internal/db/timeline.go @@ -44,4 +44,8 @@ type Timeline interface { // // Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers. GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error) + + // GetListTimeline returns a slice of statuses from followed accounts collected within the list with the given listID. + // Statuses should be returned in descending order of when they were created (newest first). + GetListTimeline(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Status, error) } |