diff options
Diffstat (limited to 'internal/processing/search.go')
-rw-r--r-- | internal/processing/search.go | 192 |
1 files changed, 104 insertions, 88 deletions
diff --git a/internal/processing/search.go b/internal/processing/search.go index bc2bc93d4..ca6cc42ce 100644 --- a/internal/processing/search.go +++ b/internal/processing/search.go @@ -27,8 +27,6 @@ import ( "codeberg.org/gruf/go-kv" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -38,11 +36,18 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/util" ) +// Implementation note: in this function, we tend to log errors +// at debug level rather than return them. This is because the +// search has a sort of fallthrough logic: if we can't get a result +// with x search, we should try with y search rather than returning. +// +// If we get to the end and still haven't found anything, even then +// we shouldn't return an error, just return an empty search result. +// +// The only exception to this is when we get a malformed query, in +// which case we return a bad request error so the user knows they +// did something funky. func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode) { - l := log.WithFields(kv.Fields{ - {"query", search.Query}, - }...) - // tidy up the query and make sure it wasn't just spaces query := strings.TrimSpace(search.Query) if query == "" { @@ -50,6 +55,8 @@ func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a return nil, gtserror.NewErrorBadRequest(err, err.Error()) } + l := log.WithFields(kv.Fields{{"query", query}}...) + searchResult := &apimodel.SearchResult{ Accounts: []apimodel.Account{}, Statuses: []apimodel.Status{}, @@ -77,14 +84,20 @@ func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a } if username, domain, err := util.ExtractNamestringParts(maybeNamestring); err == nil { - l.Debugf("search term %s is a mention, looking it up...", maybeNamestring) - if foundAccount, err := p.searchAccountByMention(ctx, authed, username, domain, search.Resolve); err == nil && foundAccount != nil { - foundAccounts = append(foundAccounts, foundAccount) - foundOne = true - l.Debug("got an account by searching by mention") - } else if err != nil { - l.Debugf("error looking up account %s: %s", maybeNamestring, err) + l.Trace("search term is a mention, looking it up...") + foundAccount, err := p.searchAccountByMention(ctx, authed, username, domain, search.Resolve) + if err != nil { + var errNotRetrievable *dereferencing.ErrNotRetrievable + if !errors.As(err, &errNotRetrievable) { + // return a proper error only if it wasn't just not retrievable + return nil, gtserror.NewErrorInternalError(fmt.Errorf("error looking up account: %w", err)) + } + return searchResult, nil } + + foundAccounts = append(foundAccounts, foundAccount) + foundOne = true + l.Trace("got an account by searching by mention") } /* @@ -92,46 +105,95 @@ func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a check if the query is a URI with a recognizable scheme and dereference it */ if !foundOne { - if uri, err := url.Parse(query); err == nil && (uri.Scheme == "https" || uri.Scheme == "http") { - // don't attempt to resolve (ie., dereference) local accounts/statuses - resolve := search.Resolve - if uri.Host == config.GetHost() || uri.Host == config.GetAccountDomain() { - resolve = false - } - - // check if it's a status or an account - if foundStatus, err := p.searchStatusByURI(ctx, authed, uri, resolve); err == nil && foundStatus != nil { - foundStatuses = append(foundStatuses, foundStatus) - l.Debug("got a status by searching by URI") - } else if foundAccount, err := p.searchAccountByURI(ctx, authed, uri, resolve); err == nil && foundAccount != nil { - foundAccounts = append(foundAccounts, foundAccount) - l.Debug("got an account by searching by URI") + if uri, err := url.Parse(query); err == nil { + if uri.Scheme == "https" || uri.Scheme == "http" { + l.Trace("search term is a uri, looking it up...") + // check if it's a status... + foundStatus, err := p.searchStatusByURI(ctx, authed, uri) + if err != nil { + var ( + errNotRetrievable *dereferencing.ErrNotRetrievable + errWrongType *dereferencing.ErrWrongType + ) + if !errors.As(err, &errNotRetrievable) && !errors.As(err, &errWrongType) { + return nil, gtserror.NewErrorInternalError(fmt.Errorf("error looking up status: %w", err)) + } + } else { + foundStatuses = append(foundStatuses, foundStatus) + foundOne = true + l.Trace("got a status by searching by URI") + } + + // ... or an account + if !foundOne { + foundAccount, err := p.searchAccountByURI(ctx, authed, uri, search.Resolve) + if err != nil { + var ( + errNotRetrievable *dereferencing.ErrNotRetrievable + errWrongType *dereferencing.ErrWrongType + ) + if !errors.As(err, &errNotRetrievable) && !errors.As(err, &errWrongType) { + return nil, gtserror.NewErrorInternalError(fmt.Errorf("error looking up account: %w", err)) + } + } else { + foundAccounts = append(foundAccounts, foundAccount) + foundOne = true + l.Trace("got an account by searching by URI") + } + } } } } + if !foundOne { + // we got nothing, we can return early + l.Trace("found nothing, returning") + return searchResult, nil + } + /* FROM HERE ON we have our search results, it's just a matter of filtering them according to what this user is allowed to see, and then converting them into our frontend format. */ for _, foundAccount := range foundAccounts { // make sure there's no block in either direction between the account and the requester - if blocked, err := p.db.IsBlocked(ctx, authed.Account.ID, foundAccount.ID, true); err == nil && !blocked { - // all good, convert it and add it to the results - if apiAcct, err := p.tc.AccountToAPIAccountPublic(ctx, foundAccount); err == nil && apiAcct != nil { - searchResult.Accounts = append(searchResult.Accounts, *apiAcct) - } + blocked, err := p.db.IsBlocked(ctx, authed.Account.ID, foundAccount.ID, true) + if err != nil { + err = fmt.Errorf("SearchGet: error checking block between %s and %s: %s", authed.Account.ID, foundAccount.ID, err) + return nil, gtserror.NewErrorInternalError(err) + } + + if blocked { + l.Tracef("block exists between %s and %s, skipping this result", authed.Account.ID, foundAccount.ID) + continue + } + + apiAcct, err := p.tc.AccountToAPIAccountPublic(ctx, foundAccount) + if err != nil { + err = fmt.Errorf("SearchGet: error converting account %s to api account: %s", foundAccount.ID, err) + return nil, gtserror.NewErrorInternalError(err) } + + searchResult.Accounts = append(searchResult.Accounts, *apiAcct) } for _, foundStatus := range foundStatuses { - if visible, err := p.filter.StatusVisible(ctx, foundStatus, authed.Account); !visible || err != nil { + // make sure each found status is visible to the requester + visible, err := p.filter.StatusVisible(ctx, foundStatus, authed.Account) + if err != nil { + err = fmt.Errorf("SearchGet: error checking visibility of status %s for account %s: %s", foundStatus.ID, authed.Account.ID, err) + return nil, gtserror.NewErrorInternalError(err) + } + + if !visible { + l.Tracef("status %s is not visible to account %s, skipping this result", foundStatus.ID, authed.Account.ID) continue } apiStatus, err := p.tc.StatusToAPIStatus(ctx, foundStatus, authed.Account) if err != nil { - continue + err = fmt.Errorf("SearchGet: error converting status %s to api status: %s", foundStatus.ID, err) + return nil, gtserror.NewErrorInternalError(err) } searchResult.Statuses = append(searchResult.Statuses, *apiStatus) @@ -140,58 +202,22 @@ func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a return searchResult, nil } -func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Status, error) { - // Calculate URI string once - uriStr := uri.String() - - // Look for status locally (by URI), we only accept "not found" errors. - status, err := p.db.GetStatusByURI(ctx, uriStr) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, fmt.Errorf("searchStatusByURI: error fetching status %q: %v", uriStr, err) - } else if err == nil { - return status, nil - } - - // Again, look for status locally (by URL), we only accept "not found" errors. - status, err = p.db.GetStatusByURL(ctx, uriStr) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, fmt.Errorf("searchStatusByURI: error fetching status %q: %v", uriStr, err) - } else if err == nil { - return status, nil +func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL) (*gtsmodel.Status, error) { + status, statusable, err := p.federator.GetStatus(transport.WithFastfail(ctx), authed.Account.Username, uri, true, true) + if err != nil { + return nil, err } - if resolve { - // This is a non-local status and we're allowed to resolve, so dereference it - status, statusable, err := p.federator.GetRemoteStatus(transport.WithFastfail(ctx), authed.Account.Username, uri, true, true) - if err != nil { - return nil, fmt.Errorf("searchStatusByURI: error fetching remote status %q: %v", uriStr, err) - } - + if !*status.Local && statusable != nil { // Attempt to dereference the status thread while we are here p.federator.DereferenceRemoteThread(transport.WithFastfail(ctx), authed.Account.Username, uri, status, statusable) } - return nil, nil + return status, nil } func (p *processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Account, error) { - // it might be a web url like http://example.org/@user instead - // of an AP uri like http://example.org/users/user, check first - if maybeAccount, err := p.db.GetAccountByURL(ctx, uri.String()); err == nil { - return maybeAccount, nil - } - - if uri.Host == config.GetHost() || uri.Host == config.GetAccountDomain() { - // this is a local account; if we don't have it now then - // we should just bail instead of trying to get it remote - if maybeAccount, err := p.db.GetAccountByURI(ctx, uri.String()); err == nil { - return maybeAccount, nil - } - return nil, nil - } - - // we don't have it yet, try to find it remotely - return p.federator.GetRemoteAccount(transport.WithFastfail(ctx), dereferencing.GetRemoteAccountParams{ + return p.federator.GetAccount(transport.WithFastfail(ctx), dereferencing.GetAccountParams{ RequestingUsername: authed.Account.Username, RemoteAccountID: uri, Blocking: true, @@ -200,17 +226,7 @@ func (p *processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth, } func (p *processor) searchAccountByMention(ctx context.Context, authed *oauth.Auth, username string, domain string, resolve bool) (*gtsmodel.Account, error) { - // if it's a local account we can skip a whole bunch of stuff - if domain == config.GetHost() || domain == config.GetAccountDomain() || domain == "" { - maybeAcct, err := p.db.GetAccountByUsernameDomain(ctx, username, "") - if err == nil || err == db.ErrNoEntries { - return maybeAcct, nil - } - return nil, fmt.Errorf("searchAccountByMention: error getting local account by username: %s", err) - } - - // we don't have it yet, try to find it remotely - return p.federator.GetRemoteAccount(transport.WithFastfail(ctx), dereferencing.GetRemoteAccountParams{ + return p.federator.GetAccount(transport.WithFastfail(ctx), dereferencing.GetAccountParams{ RequestingUsername: authed.Account.Username, RemoteAccountUsername: username, RemoteAccountHost: domain, |