diff options
Diffstat (limited to 'internal/db')
| -rw-r--r-- | internal/db/bundb/search.go | 10 | ||||
| -rw-r--r-- | internal/db/bundb/search_test.go | 13 | ||||
| -rw-r--r-- | internal/db/search.go | 5 | 
3 files changed, 22 insertions, 6 deletions
| diff --git a/internal/db/bundb/search.go b/internal/db/bundb/search.go index f8ae529f7..e54cb78e7 100644 --- a/internal/db/bundb/search.go +++ b/internal/db/bundb/search.go @@ -266,8 +266,9 @@ func (s *searchDB) accountText(following bool) *bun.SelectQuery {  //	ORDER BY "status"."id" DESC LIMIT 10  func (s *searchDB) SearchForStatuses(  	ctx context.Context, -	accountID string, +	requestingAccountID string,  	query string, +	fromAccountID string,  	maxID string,  	minID string,  	limit int, @@ -295,9 +296,12 @@ func (s *searchDB) SearchForStatuses(  		// accountID or replying to accountID.  		WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {  			return q. -				Where("? = ?", bun.Ident("status.account_id"), accountID). -				WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID) +				Where("? = ?", bun.Ident("status.account_id"), requestingAccountID). +				WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), requestingAccountID)  		}) +	if fromAccountID != "" { +		q = q.Where("? = ?", bun.Ident("status.account_id"), fromAccountID) +	}  	// Return only items with a LOWER id than maxID.  	if maxID == "" { diff --git a/internal/db/bundb/search_test.go b/internal/db/bundb/search_test.go index 75a2d8c8e..cf24b2881 100644 --- a/internal/db/bundb/search_test.go +++ b/internal/db/bundb/search_test.go @@ -107,11 +107,22 @@ func (suite *SearchTestSuite) TestSearchAccountsFossAny() {  func (suite *SearchTestSuite) TestSearchStatuses() {  	testAccount := suite.testAccounts["local_account_1"] -	statuses, err := suite.db.SearchForStatuses(context.Background(), testAccount.ID, "hello", "", "", 10, 0) +	statuses, err := suite.db.SearchForStatuses(context.Background(), testAccount.ID, "hello", "", "", "", 10, 0)  	suite.NoError(err)  	suite.Len(statuses, 1)  } +func (suite *SearchTestSuite) TestSearchStatusesFromAccount() { +	testAccount := suite.testAccounts["local_account_1"] +	fromAccount := suite.testAccounts["local_account_2"] + +	statuses, err := suite.db.SearchForStatuses(context.Background(), testAccount.ID, "hi", fromAccount.ID, "", "", 10, 0) +	suite.NoError(err) +	if suite.Len(statuses, 1) { +		suite.Equal(fromAccount.ID, statuses[0].AccountID) +	} +} +  func (suite *SearchTestSuite) TestSearchTags() {  	// Search with full tag string.  	tags, err := suite.db.SearchForTags(context.Background(), "welcome", "", "", 10, 0) diff --git a/internal/db/search.go b/internal/db/search.go index d2ffe4ad5..bdfd3a8e6 100644 --- a/internal/db/search.go +++ b/internal/db/search.go @@ -27,8 +27,9 @@ type Search interface {  	// SearchForAccounts uses the given query text to search for accounts that accountID follows.  	SearchForAccounts(ctx context.Context, accountID string, query string, maxID string, minID string, limit int, following bool, offset int) ([]*gtsmodel.Account, error) -	// SearchForStatuses uses the given query text to search for statuses created by accountID, or in reply to accountID. -	SearchForStatuses(ctx context.Context, accountID string, query string, maxID string, minID string, limit int, offset int) ([]*gtsmodel.Status, error) +	// SearchForStatuses uses the given query text to search for statuses created by requestingAccountID, or in reply to requestingAccountID. +	// If fromAccountID is used, the results are restricted to statuses created by fromAccountID. +	SearchForStatuses(ctx context.Context, requestingAccountID string, query string, fromAccountID string, maxID string, minID string, limit int, offset int) ([]*gtsmodel.Status, error)  	// SearchForTags searches for tags that start with the given query text (case insensitive).  	SearchForTags(ctx context.Context, query string, maxID string, minID string, limit int, offset int) ([]*gtsmodel.Tag, error) | 
