summaryrefslogtreecommitdiff
path: root/internal/processing/search/get.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/processing/search/get.go')
-rw-r--r--internal/processing/search/get.go98
1 files changed, 89 insertions, 9 deletions
diff --git a/internal/processing/search/get.go b/internal/processing/search/get.go
index e4839a179..d1462cf53 100644
--- a/internal/processing/search/get.go
+++ b/internal/processing/search/get.go
@@ -62,14 +62,15 @@ func (p *Processor) Get(
req *apimodel.SearchRequest,
) (*apimodel.SearchResult, gtserror.WithCode) {
var (
- maxID = req.MaxID
- minID = req.MinID
- limit = req.Limit
- offset = req.Offset
- query = strings.TrimSpace(req.Query) // Trim trailing/leading whitespace.
- queryType = strings.TrimSpace(strings.ToLower(req.QueryType)) // Trim trailing/leading whitespace; convert to lowercase.
- resolve = req.Resolve
- following = req.Following
+ maxID = req.MaxID
+ minID = req.MinID
+ limit = req.Limit
+ offset = req.Offset
+ query = strings.TrimSpace(req.Query) // Trim trailing/leading whitespace.
+ queryType = strings.TrimSpace(strings.ToLower(req.QueryType)) // Trim trailing/leading whitespace; convert to lowercase.
+ resolve = req.Resolve
+ following = req.Following
+ fromAccountID = req.AccountID
// Include instance accounts in the first
// parts of this search. This will be
@@ -114,6 +115,7 @@ func (p *Processor) Get(
{"queryType", queryType},
{"resolve", resolve},
{"following", following},
+ {"fromAccountID", fromAccountID},
}...).
Debugf("beginning search")
@@ -309,6 +311,7 @@ func (p *Processor) Get(
query,
queryType,
following,
+ fromAccountID,
appendAccount,
appendStatus,
); err != nil && !errors.Is(err, db.ErrNoEntries) {
@@ -743,6 +746,7 @@ func (p *Processor) byText(
query string,
queryType string,
following bool,
+ fromAccountID string,
appendAccount func(*gtsmodel.Account),
appendStatus func(*gtsmodel.Status),
) error {
@@ -779,6 +783,7 @@ func (p *Processor) byText(
limit,
offset,
query,
+ fromAccountID,
appendStatus,
); err != nil {
return err
@@ -826,12 +831,30 @@ func (p *Processor) statusesByText(
limit int,
offset int,
query string,
+ fromAccountID string,
appendStatus func(*gtsmodel.Status),
) error {
+ parsed, err := p.parseQuery(ctx, query)
+ if err != nil {
+ return err
+ }
+ query = parsed.query
+ // If the owning account for statuses was not provided as the account_id query parameter,
+ // it may still have been provided as a search operator in the query string.
+ if fromAccountID == "" {
+ fromAccountID = parsed.fromAccountID
+ }
+
statuses, err := p.state.DB.SearchForStatuses(
ctx,
requestingAccountID,
- query, maxID, minID, limit, offset)
+ query,
+ fromAccountID,
+ maxID,
+ minID,
+ limit,
+ offset,
+ )
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return gtserror.Newf("error checking database for statuses using text %s: %w", query, err)
}
@@ -842,3 +865,60 @@ func (p *Processor) statusesByText(
return nil
}
+
+// parsedQuery represents the results of parsing the search operator terms within a query.
+type parsedQuery struct {
+ // query is the original search query text with operator terms removed.
+ query string
+ // fromAccountID is the account from a successfully resolved `from:` operator, if present.
+ fromAccountID string
+}
+
+// parseQuery parses query text and handles any search operator terms present.
+func (p *Processor) parseQuery(ctx context.Context, query string) (parsed parsedQuery, err error) {
+ queryPartSeparator := " "
+ queryParts := strings.Split(query, queryPartSeparator)
+ nonOperatorQueryParts := make([]string, 0, len(queryParts))
+ for _, queryPart := range queryParts {
+ if arg, hasPrefix := strings.CutPrefix(queryPart, "from:"); hasPrefix {
+ parsed.fromAccountID, err = p.parseFromOperatorArg(ctx, arg)
+ if err != nil {
+ return
+ }
+ } else {
+ nonOperatorQueryParts = append(nonOperatorQueryParts, queryPart)
+ }
+ }
+ parsed.query = strings.Join(nonOperatorQueryParts, queryPartSeparator)
+ return
+}
+
+// parseFromOperatorArg attempts to parse the from: operator's argument as an account name,
+// and returns the account ID if possible. Allows specifying an account name with or without a leading @.
+func (p *Processor) parseFromOperatorArg(ctx context.Context, namestring string) (string, error) {
+ if namestring == "" {
+ return "", gtserror.New(
+ "the 'from:' search operator requires an account name, but it wasn't provided",
+ )
+ }
+ if namestring[0] != '@' {
+ namestring = "@" + namestring
+ }
+
+ username, domain, err := util.ExtractNamestringParts(namestring)
+ if err != nil {
+ return "", gtserror.Newf(
+ "the 'from:' search operator couldn't parse its argument as an account name: %w",
+ err,
+ )
+ }
+ account, err := p.state.DB.GetAccountByUsernameDomain(gtscontext.SetBarebones(ctx), username, domain)
+ if err != nil {
+ return "", gtserror.Newf(
+ "the 'from:' search operator couldn't find the requested account name: %w",
+ err,
+ )
+ }
+
+ return account.ID, nil
+}