diff options
Diffstat (limited to 'internal/processing')
| -rw-r--r-- | internal/processing/search/get.go | 98 | 
1 files changed, 89 insertions, 9 deletions
| diff --git a/internal/processing/search/get.go b/internal/processing/search/get.go index e4839a179..d1462cf53 100644 --- a/internal/processing/search/get.go +++ b/internal/processing/search/get.go @@ -62,14 +62,15 @@ func (p *Processor) Get(  	req *apimodel.SearchRequest,  ) (*apimodel.SearchResult, gtserror.WithCode) {  	var ( -		maxID     = req.MaxID -		minID     = req.MinID -		limit     = req.Limit -		offset    = req.Offset -		query     = strings.TrimSpace(req.Query)                      // Trim trailing/leading whitespace. -		queryType = strings.TrimSpace(strings.ToLower(req.QueryType)) // Trim trailing/leading whitespace; convert to lowercase. -		resolve   = req.Resolve -		following = req.Following +		maxID         = req.MaxID +		minID         = req.MinID +		limit         = req.Limit +		offset        = req.Offset +		query         = strings.TrimSpace(req.Query)                      // Trim trailing/leading whitespace. +		queryType     = strings.TrimSpace(strings.ToLower(req.QueryType)) // Trim trailing/leading whitespace; convert to lowercase. +		resolve       = req.Resolve +		following     = req.Following +		fromAccountID = req.AccountID  		// Include instance accounts in the first  		// parts of this search. This will be @@ -114,6 +115,7 @@ func (p *Processor) Get(  			{"queryType", queryType},  			{"resolve", resolve},  			{"following", following}, +			{"fromAccountID", fromAccountID},  		}...).  		Debugf("beginning search") @@ -309,6 +311,7 @@ func (p *Processor) Get(  		query,  		queryType,  		following, +		fromAccountID,  		appendAccount,  		appendStatus,  	); err != nil && !errors.Is(err, db.ErrNoEntries) { @@ -743,6 +746,7 @@ func (p *Processor) byText(  	query string,  	queryType string,  	following bool, +	fromAccountID string,  	appendAccount func(*gtsmodel.Account),  	appendStatus func(*gtsmodel.Status),  ) error { @@ -779,6 +783,7 @@ func (p *Processor) byText(  			limit,  			offset,  			query, +			fromAccountID,  			appendStatus,  		); err != nil {  			return err @@ -826,12 +831,30 @@ func (p *Processor) statusesByText(  	limit int,  	offset int,  	query string, +	fromAccountID string,  	appendStatus func(*gtsmodel.Status),  ) error { +	parsed, err := p.parseQuery(ctx, query) +	if err != nil { +		return err +	} +	query = parsed.query +	// If the owning account for statuses was not provided as the account_id query parameter, +	// it may still have been provided as a search operator in the query string. +	if fromAccountID == "" { +		fromAccountID = parsed.fromAccountID +	} +  	statuses, err := p.state.DB.SearchForStatuses(  		ctx,  		requestingAccountID, -		query, maxID, minID, limit, offset) +		query, +		fromAccountID, +		maxID, +		minID, +		limit, +		offset, +	)  	if err != nil && !errors.Is(err, db.ErrNoEntries) {  		return gtserror.Newf("error checking database for statuses using text %s: %w", query, err)  	} @@ -842,3 +865,60 @@ func (p *Processor) statusesByText(  	return nil  } + +// parsedQuery represents the results of parsing the search operator terms within a query. +type parsedQuery struct { +	// query is the original search query text with operator terms removed. +	query string +	// fromAccountID is the account from a successfully resolved `from:` operator, if present. +	fromAccountID string +} + +// parseQuery parses query text and handles any search operator terms present. +func (p *Processor) parseQuery(ctx context.Context, query string) (parsed parsedQuery, err error) { +	queryPartSeparator := " " +	queryParts := strings.Split(query, queryPartSeparator) +	nonOperatorQueryParts := make([]string, 0, len(queryParts)) +	for _, queryPart := range queryParts { +		if arg, hasPrefix := strings.CutPrefix(queryPart, "from:"); hasPrefix { +			parsed.fromAccountID, err = p.parseFromOperatorArg(ctx, arg) +			if err != nil { +				return +			} +		} else { +			nonOperatorQueryParts = append(nonOperatorQueryParts, queryPart) +		} +	} +	parsed.query = strings.Join(nonOperatorQueryParts, queryPartSeparator) +	return +} + +// parseFromOperatorArg attempts to parse the from: operator's argument as an account name, +// and returns the account ID if possible. Allows specifying an account name with or without a leading @. +func (p *Processor) parseFromOperatorArg(ctx context.Context, namestring string) (string, error) { +	if namestring == "" { +		return "", gtserror.New( +			"the 'from:' search operator requires an account name, but it wasn't provided", +		) +	} +	if namestring[0] != '@' { +		namestring = "@" + namestring +	} + +	username, domain, err := util.ExtractNamestringParts(namestring) +	if err != nil { +		return "", gtserror.Newf( +			"the 'from:' search operator couldn't parse its argument as an account name: %w", +			err, +		) +	} +	account, err := p.state.DB.GetAccountByUsernameDomain(gtscontext.SetBarebones(ctx), username, domain) +	if err != nil { +		return "", gtserror.Newf( +			"the 'from:' search operator couldn't find the requested account name: %w", +			err, +		) +	} + +	return account.ID, nil +} | 
