diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/bundb/search.go | 10 | ||||
-rw-r--r-- | internal/db/bundb/search_test.go | 13 | ||||
-rw-r--r-- | internal/db/search.go | 5 |
3 files changed, 22 insertions, 6 deletions
diff --git a/internal/db/bundb/search.go b/internal/db/bundb/search.go index f8ae529f7..e54cb78e7 100644 --- a/internal/db/bundb/search.go +++ b/internal/db/bundb/search.go @@ -266,8 +266,9 @@ func (s *searchDB) accountText(following bool) *bun.SelectQuery { // ORDER BY "status"."id" DESC LIMIT 10 func (s *searchDB) SearchForStatuses( ctx context.Context, - accountID string, + requestingAccountID string, query string, + fromAccountID string, maxID string, minID string, limit int, @@ -295,9 +296,12 @@ func (s *searchDB) SearchForStatuses( // accountID or replying to accountID. WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery { return q. - Where("? = ?", bun.Ident("status.account_id"), accountID). - WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID) + Where("? = ?", bun.Ident("status.account_id"), requestingAccountID). + WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), requestingAccountID) }) + if fromAccountID != "" { + q = q.Where("? = ?", bun.Ident("status.account_id"), fromAccountID) + } // Return only items with a LOWER id than maxID. if maxID == "" { diff --git a/internal/db/bundb/search_test.go b/internal/db/bundb/search_test.go index 75a2d8c8e..cf24b2881 100644 --- a/internal/db/bundb/search_test.go +++ b/internal/db/bundb/search_test.go @@ -107,11 +107,22 @@ func (suite *SearchTestSuite) TestSearchAccountsFossAny() { func (suite *SearchTestSuite) TestSearchStatuses() { testAccount := suite.testAccounts["local_account_1"] - statuses, err := suite.db.SearchForStatuses(context.Background(), testAccount.ID, "hello", "", "", 10, 0) + statuses, err := suite.db.SearchForStatuses(context.Background(), testAccount.ID, "hello", "", "", "", 10, 0) suite.NoError(err) suite.Len(statuses, 1) } +func (suite *SearchTestSuite) TestSearchStatusesFromAccount() { + testAccount := suite.testAccounts["local_account_1"] + fromAccount := suite.testAccounts["local_account_2"] + + statuses, err := suite.db.SearchForStatuses(context.Background(), testAccount.ID, "hi", fromAccount.ID, "", "", 10, 0) + suite.NoError(err) + if suite.Len(statuses, 1) { + suite.Equal(fromAccount.ID, statuses[0].AccountID) + } +} + func (suite *SearchTestSuite) TestSearchTags() { // Search with full tag string. tags, err := suite.db.SearchForTags(context.Background(), "welcome", "", "", 10, 0) diff --git a/internal/db/search.go b/internal/db/search.go index d2ffe4ad5..bdfd3a8e6 100644 --- a/internal/db/search.go +++ b/internal/db/search.go @@ -27,8 +27,9 @@ type Search interface { // SearchForAccounts uses the given query text to search for accounts that accountID follows. SearchForAccounts(ctx context.Context, accountID string, query string, maxID string, minID string, limit int, following bool, offset int) ([]*gtsmodel.Account, error) - // SearchForStatuses uses the given query text to search for statuses created by accountID, or in reply to accountID. - SearchForStatuses(ctx context.Context, accountID string, query string, maxID string, minID string, limit int, offset int) ([]*gtsmodel.Status, error) + // SearchForStatuses uses the given query text to search for statuses created by requestingAccountID, or in reply to requestingAccountID. + // If fromAccountID is used, the results are restricted to statuses created by fromAccountID. + SearchForStatuses(ctx context.Context, requestingAccountID string, query string, fromAccountID string, maxID string, minID string, limit int, offset int) ([]*gtsmodel.Status, error) // SearchForTags searches for tags that start with the given query text (case insensitive). SearchForTags(ctx context.Context, query string, maxID string, minID string, limit int, offset int) ([]*gtsmodel.Tag, error) |