summaryrefslogtreecommitdiff
path: root/internal/db/bundb/list_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/list_test.go')
-rw-r--r--internal/db/bundb/list_test.go145
1 files changed, 49 insertions, 96 deletions
diff --git a/internal/db/bundb/list_test.go b/internal/db/bundb/list_test.go
index 9c5fb2c76..3952a87c0 100644
--- a/internal/db/bundb/list_test.go
+++ b/internal/db/bundb/list_test.go
@@ -24,7 +24,6 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@@ -32,7 +31,7 @@ type ListTestSuite struct {
BunDBStandardTestSuite
}
-func (suite *ListTestSuite) testStructs() (*gtsmodel.List, *gtsmodel.Account) {
+func (suite *ListTestSuite) testStructs() (*gtsmodel.List, []*gtsmodel.ListEntry, *gtsmodel.Account) {
testList := &gtsmodel.List{}
*testList = *suite.testLists["local_account_1_list_1"]
@@ -55,12 +54,10 @@ func (suite *ListTestSuite) testStructs() (*gtsmodel.List, *gtsmodel.Account) {
}
})
- testList.ListEntries = entries
-
testAccount := &gtsmodel.Account{}
*testAccount = *suite.testAccounts["local_account_1"]
- return testList, testAccount
+ return testList, entries, testAccount
}
func (suite *ListTestSuite) checkList(expected *gtsmodel.List, actual *gtsmodel.List) {
@@ -103,7 +100,7 @@ func (suite *ListTestSuite) checkListEntries(expected []*gtsmodel.ListEntry, act
}
func (suite *ListTestSuite) TestGetListByID() {
- testList, _ := suite.testStructs()
+ testList, _, _ := suite.testStructs()
dbList, err := suite.db.GetListByID(context.Background(), testList.ID)
if err != nil {
@@ -111,13 +108,12 @@ func (suite *ListTestSuite) TestGetListByID() {
}
suite.checkList(testList, dbList)
- suite.checkListEntries(testList.ListEntries, dbList.ListEntries)
}
func (suite *ListTestSuite) TestGetListsForAccountID() {
- testList, testAccount := suite.testStructs()
+ testList, _, testAccount := suite.testStructs()
- dbLists, err := suite.db.GetListsForAccountID(context.Background(), testAccount.ID)
+ dbLists, err := suite.db.GetListsByAccountID(context.Background(), testAccount.ID)
if err != nil {
suite.FailNow(err.Error())
}
@@ -129,20 +125,9 @@ func (suite *ListTestSuite) TestGetListsForAccountID() {
suite.checkList(testList, dbLists[0])
}
-func (suite *ListTestSuite) TestGetListEntries() {
- testList, _ := suite.testStructs()
-
- dbListEntries, err := suite.db.GetListEntries(context.Background(), testList.ID, "", "", "", 0)
- if err != nil {
- suite.FailNow(err.Error())
- }
-
- suite.checkListEntries(testList.ListEntries, dbListEntries)
-}
-
func (suite *ListTestSuite) TestPutList() {
ctx := context.Background()
- _, testAccount := suite.testStructs()
+ _, _, testAccount := suite.testStructs()
testList := &gtsmodel.List{
ID: "01H0J2PMYM54618VCV8Y8QYAT4",
@@ -166,7 +151,7 @@ func (suite *ListTestSuite) TestPutList() {
func (suite *ListTestSuite) TestUpdateList() {
ctx := context.Background()
- testList, _ := suite.testStructs()
+ testList, _, _ := suite.testStructs()
// Get List in the cache first.
dbList, err := suite.db.GetListByID(ctx, testList.ID)
@@ -192,7 +177,7 @@ func (suite *ListTestSuite) TestUpdateList() {
func (suite *ListTestSuite) TestDeleteList() {
ctx := context.Background()
- testList, _ := suite.testStructs()
+ testList, _, _ := suite.testStructs()
// Get List in the cache first.
if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil {
@@ -209,18 +194,19 @@ func (suite *ListTestSuite) TestDeleteList() {
_, err := suite.db.GetListByID(ctx, testList.ID)
suite.ErrorIs(err, db.ErrNoEntries)
- // All entries belonging to this
- // list should now be deleted.
- listEntries, err := suite.db.GetListEntries(ctx, testList.ID, "", "", "", 0)
- if err != nil {
- suite.FailNow(err.Error())
- }
- suite.Empty(listEntries)
+ // All accounts / follows attached to this
+ // list should now be return empty values.
+ listAccounts, err1 := suite.db.GetAccountsInList(ctx, testList.ID, nil)
+ listFollows, err2 := suite.db.GetFollowsInList(ctx, testList.ID, nil)
+ suite.NoError(err1)
+ suite.NoError(err2)
+ suite.Empty(listAccounts)
+ suite.Empty(listFollows)
}
func (suite *ListTestSuite) TestPutListEntries() {
ctx := context.Background()
- testList, _ := suite.testStructs()
+ testList, testEntries, _ := suite.testStructs()
listEntries := []*gtsmodel.ListEntry{
{
@@ -244,91 +230,58 @@ func (suite *ListTestSuite) TestPutListEntries() {
suite.FailNow(err.Error())
}
- // Add these entries to the test list, sort it again
- // to reflect what we'd expect to get from the db.
- testList.ListEntries = append(testList.ListEntries, listEntries...)
- slices.SortFunc(testList.ListEntries, func(a, b *gtsmodel.ListEntry) int {
- const k = -1
- switch {
- case a.ID > b.ID:
- return +k
- case a.ID < b.ID:
- return -k
- default:
- return 0
- }
- })
-
- // Now get all list entries from the db.
- // Use barebones for this because the ones
- // we just added will fail if we try to get
- // the nonexistent follows.
- dbListEntries, err := suite.db.GetListEntries(
- gtscontext.SetBarebones(ctx),
- testList.ID,
- "", "", "", 0)
- if err != nil {
- suite.FailNow(err.Error())
- }
-
- suite.checkListEntries(testList.ListEntries, dbListEntries)
+ // Get all follows stored under this list ID, to ensure
+ // the newly added list entry follows are among these.
+ followIDs, err := suite.db.GetFollowIDsInList(ctx, testList.ID, nil)
+ suite.NoError(err)
+ suite.Len(followIDs, len(testEntries)+len(listEntries))
+ suite.Contains(followIDs, "01H0MKNFRFZS8R9WV6DBX31Y03")
+ suite.Contains(followIDs, "01H0MKP6RR8VEHN3GVWFBP2H30")
+ suite.Contains(followIDs, "01H0MKQ0KA29C6NFJ27GTZD16J")
}
func (suite *ListTestSuite) TestDeleteListEntry() {
ctx := context.Background()
- testList, _ := suite.testStructs()
-
- // Get List in the cache first.
- if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil {
- suite.FailNow(err.Error())
- }
+ testList, testEntries, _ := suite.testStructs()
// Delete the first entry.
- if err := suite.db.DeleteListEntry(ctx, testList.ListEntries[0].ID); err != nil {
- suite.FailNow(err.Error())
- }
-
- // Get list from the db again.
- dbList, err := suite.db.GetListByID(ctx, testList.ID)
- if err != nil {
+ if err := suite.db.DeleteListEntry(ctx,
+ testEntries[0].ListID,
+ testEntries[0].FollowID,
+ ); err != nil {
suite.FailNow(err.Error())
}
- // Bodge the testlist as though
- // we'd removed the first entry.
- testList.ListEntries = testList.ListEntries[1:]
- suite.checkList(testList, dbList)
+ // Get all follows stored under this list ID, to ensure
+ // the newly removed list entry follow is now missing.
+ followIDs, err := suite.db.GetFollowIDsInList(ctx, testList.ID, nil)
+ suite.NoError(err)
+ suite.Len(followIDs, len(testEntries)-1)
+ suite.NotContains(followIDs, testEntries[0].FollowID)
}
-func (suite *ListTestSuite) TestDeleteListEntriesForFollowID() {
+func (suite *ListTestSuite) TestDeleteAllListEntriesByFollows() {
ctx := context.Background()
- testList, _ := suite.testStructs()
-
- // Get List in the cache first.
- if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil {
- suite.FailNow(err.Error())
- }
+ testList, testEntries, _ := suite.testStructs()
// Delete the first entry.
- if err := suite.db.DeleteListEntriesForFollowID(ctx, testList.ListEntries[0].FollowID); err != nil {
+ if err := suite.db.DeleteAllListEntriesByFollows(ctx,
+ testEntries[0].FollowID,
+ ); err != nil {
suite.FailNow(err.Error())
}
- // Get list from the db again.
- dbList, err := suite.db.GetListByID(ctx, testList.ID)
- if err != nil {
- suite.FailNow(err.Error())
- }
-
- // Bodge the testlist as though
- // we'd removed the first entry.
- testList.ListEntries = testList.ListEntries[1:]
- suite.checkList(testList, dbList)
+ // Get all follows stored under this list ID, to ensure
+ // the newly removed list entry follow is now missing.
+ followIDs, err := suite.db.GetFollowIDsInList(ctx, testList.ID, nil)
+ suite.NoError(err)
+ suite.Len(followIDs, len(testEntries)-1)
+ suite.NotContains(followIDs, testEntries[0].FollowID)
}
func (suite *ListTestSuite) TestListIncludesAccount() {
ctx := context.Background()
- testList, _ := suite.testStructs()
+ testList, _, _ := suite.testStructs()
for accountID, expected := range map[string]bool{
suite.testAccounts["admin_account"].ID: true,
@@ -336,7 +289,7 @@ func (suite *ListTestSuite) TestListIncludesAccount() {
suite.testAccounts["local_account_2"].ID: true,
"01H7074GEZJ56J5C86PFB0V2CT": false,
} {
- includes, err := suite.db.ListIncludesAccount(ctx, testList.ID, accountID)
+ includes, err := suite.db.IsAccountInList(ctx, testList.ID, accountID)
if err != nil {
suite.FailNow(err.Error())
}