diff options
Diffstat (limited to 'internal/db/bundb/list_test.go')
-rw-r--r-- | internal/db/bundb/list_test.go | 145 |
1 files changed, 49 insertions, 96 deletions
diff --git a/internal/db/bundb/list_test.go b/internal/db/bundb/list_test.go index 9c5fb2c76..3952a87c0 100644 --- a/internal/db/bundb/list_test.go +++ b/internal/db/bundb/list_test.go @@ -24,7 +24,6 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) @@ -32,7 +31,7 @@ type ListTestSuite struct { BunDBStandardTestSuite } -func (suite *ListTestSuite) testStructs() (*gtsmodel.List, *gtsmodel.Account) { +func (suite *ListTestSuite) testStructs() (*gtsmodel.List, []*gtsmodel.ListEntry, *gtsmodel.Account) { testList := >smodel.List{} *testList = *suite.testLists["local_account_1_list_1"] @@ -55,12 +54,10 @@ func (suite *ListTestSuite) testStructs() (*gtsmodel.List, *gtsmodel.Account) { } }) - testList.ListEntries = entries - testAccount := >smodel.Account{} *testAccount = *suite.testAccounts["local_account_1"] - return testList, testAccount + return testList, entries, testAccount } func (suite *ListTestSuite) checkList(expected *gtsmodel.List, actual *gtsmodel.List) { @@ -103,7 +100,7 @@ func (suite *ListTestSuite) checkListEntries(expected []*gtsmodel.ListEntry, act } func (suite *ListTestSuite) TestGetListByID() { - testList, _ := suite.testStructs() + testList, _, _ := suite.testStructs() dbList, err := suite.db.GetListByID(context.Background(), testList.ID) if err != nil { @@ -111,13 +108,12 @@ func (suite *ListTestSuite) TestGetListByID() { } suite.checkList(testList, dbList) - suite.checkListEntries(testList.ListEntries, dbList.ListEntries) } func (suite *ListTestSuite) TestGetListsForAccountID() { - testList, testAccount := suite.testStructs() + testList, _, testAccount := suite.testStructs() - dbLists, err := suite.db.GetListsForAccountID(context.Background(), testAccount.ID) + dbLists, err := suite.db.GetListsByAccountID(context.Background(), testAccount.ID) if err != nil { suite.FailNow(err.Error()) } @@ -129,20 +125,9 @@ func (suite *ListTestSuite) TestGetListsForAccountID() { suite.checkList(testList, dbLists[0]) } -func (suite *ListTestSuite) TestGetListEntries() { - testList, _ := suite.testStructs() - - dbListEntries, err := suite.db.GetListEntries(context.Background(), testList.ID, "", "", "", 0) - if err != nil { - suite.FailNow(err.Error()) - } - - suite.checkListEntries(testList.ListEntries, dbListEntries) -} - func (suite *ListTestSuite) TestPutList() { ctx := context.Background() - _, testAccount := suite.testStructs() + _, _, testAccount := suite.testStructs() testList := >smodel.List{ ID: "01H0J2PMYM54618VCV8Y8QYAT4", @@ -166,7 +151,7 @@ func (suite *ListTestSuite) TestPutList() { func (suite *ListTestSuite) TestUpdateList() { ctx := context.Background() - testList, _ := suite.testStructs() + testList, _, _ := suite.testStructs() // Get List in the cache first. dbList, err := suite.db.GetListByID(ctx, testList.ID) @@ -192,7 +177,7 @@ func (suite *ListTestSuite) TestUpdateList() { func (suite *ListTestSuite) TestDeleteList() { ctx := context.Background() - testList, _ := suite.testStructs() + testList, _, _ := suite.testStructs() // Get List in the cache first. if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil { @@ -209,18 +194,19 @@ func (suite *ListTestSuite) TestDeleteList() { _, err := suite.db.GetListByID(ctx, testList.ID) suite.ErrorIs(err, db.ErrNoEntries) - // All entries belonging to this - // list should now be deleted. - listEntries, err := suite.db.GetListEntries(ctx, testList.ID, "", "", "", 0) - if err != nil { - suite.FailNow(err.Error()) - } - suite.Empty(listEntries) + // All accounts / follows attached to this + // list should now be return empty values. + listAccounts, err1 := suite.db.GetAccountsInList(ctx, testList.ID, nil) + listFollows, err2 := suite.db.GetFollowsInList(ctx, testList.ID, nil) + suite.NoError(err1) + suite.NoError(err2) + suite.Empty(listAccounts) + suite.Empty(listFollows) } func (suite *ListTestSuite) TestPutListEntries() { ctx := context.Background() - testList, _ := suite.testStructs() + testList, testEntries, _ := suite.testStructs() listEntries := []*gtsmodel.ListEntry{ { @@ -244,91 +230,58 @@ func (suite *ListTestSuite) TestPutListEntries() { suite.FailNow(err.Error()) } - // Add these entries to the test list, sort it again - // to reflect what we'd expect to get from the db. - testList.ListEntries = append(testList.ListEntries, listEntries...) - slices.SortFunc(testList.ListEntries, func(a, b *gtsmodel.ListEntry) int { - const k = -1 - switch { - case a.ID > b.ID: - return +k - case a.ID < b.ID: - return -k - default: - return 0 - } - }) - - // Now get all list entries from the db. - // Use barebones for this because the ones - // we just added will fail if we try to get - // the nonexistent follows. - dbListEntries, err := suite.db.GetListEntries( - gtscontext.SetBarebones(ctx), - testList.ID, - "", "", "", 0) - if err != nil { - suite.FailNow(err.Error()) - } - - suite.checkListEntries(testList.ListEntries, dbListEntries) + // Get all follows stored under this list ID, to ensure + // the newly added list entry follows are among these. + followIDs, err := suite.db.GetFollowIDsInList(ctx, testList.ID, nil) + suite.NoError(err) + suite.Len(followIDs, len(testEntries)+len(listEntries)) + suite.Contains(followIDs, "01H0MKNFRFZS8R9WV6DBX31Y03") + suite.Contains(followIDs, "01H0MKP6RR8VEHN3GVWFBP2H30") + suite.Contains(followIDs, "01H0MKQ0KA29C6NFJ27GTZD16J") } func (suite *ListTestSuite) TestDeleteListEntry() { ctx := context.Background() - testList, _ := suite.testStructs() - - // Get List in the cache first. - if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil { - suite.FailNow(err.Error()) - } + testList, testEntries, _ := suite.testStructs() // Delete the first entry. - if err := suite.db.DeleteListEntry(ctx, testList.ListEntries[0].ID); err != nil { - suite.FailNow(err.Error()) - } - - // Get list from the db again. - dbList, err := suite.db.GetListByID(ctx, testList.ID) - if err != nil { + if err := suite.db.DeleteListEntry(ctx, + testEntries[0].ListID, + testEntries[0].FollowID, + ); err != nil { suite.FailNow(err.Error()) } - // Bodge the testlist as though - // we'd removed the first entry. - testList.ListEntries = testList.ListEntries[1:] - suite.checkList(testList, dbList) + // Get all follows stored under this list ID, to ensure + // the newly removed list entry follow is now missing. + followIDs, err := suite.db.GetFollowIDsInList(ctx, testList.ID, nil) + suite.NoError(err) + suite.Len(followIDs, len(testEntries)-1) + suite.NotContains(followIDs, testEntries[0].FollowID) } -func (suite *ListTestSuite) TestDeleteListEntriesForFollowID() { +func (suite *ListTestSuite) TestDeleteAllListEntriesByFollows() { ctx := context.Background() - testList, _ := suite.testStructs() - - // Get List in the cache first. - if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil { - suite.FailNow(err.Error()) - } + testList, testEntries, _ := suite.testStructs() // Delete the first entry. - if err := suite.db.DeleteListEntriesForFollowID(ctx, testList.ListEntries[0].FollowID); err != nil { + if err := suite.db.DeleteAllListEntriesByFollows(ctx, + testEntries[0].FollowID, + ); err != nil { suite.FailNow(err.Error()) } - // Get list from the db again. - dbList, err := suite.db.GetListByID(ctx, testList.ID) - if err != nil { - suite.FailNow(err.Error()) - } - - // Bodge the testlist as though - // we'd removed the first entry. - testList.ListEntries = testList.ListEntries[1:] - suite.checkList(testList, dbList) + // Get all follows stored under this list ID, to ensure + // the newly removed list entry follow is now missing. + followIDs, err := suite.db.GetFollowIDsInList(ctx, testList.ID, nil) + suite.NoError(err) + suite.Len(followIDs, len(testEntries)-1) + suite.NotContains(followIDs, testEntries[0].FollowID) } func (suite *ListTestSuite) TestListIncludesAccount() { ctx := context.Background() - testList, _ := suite.testStructs() + testList, _, _ := suite.testStructs() for accountID, expected := range map[string]bool{ suite.testAccounts["admin_account"].ID: true, @@ -336,7 +289,7 @@ func (suite *ListTestSuite) TestListIncludesAccount() { suite.testAccounts["local_account_2"].ID: true, "01H7074GEZJ56J5C86PFB0V2CT": false, } { - includes, err := suite.db.ListIncludesAccount(ctx, testList.ID, accountID) + includes, err := suite.db.IsAccountInList(ctx, testList.ID, accountID) if err != nil { suite.FailNow(err.Error()) } |