diff options
| author | 2022-08-20 22:47:19 +0200 | |
|---|---|---|
| committer | 2022-08-20 21:47:19 +0100 | |
| commit | 570fa7c3598118ded6df7ced0a5326f54e7a43e2 (patch) | |
| tree | 9575a6f3016c73b7109c88f68a2a512981cf19e4 /internal | |
| parent | [docs] Textual updates on markdown files (#756) (diff) | |
| download | gotosocial-570fa7c3598118ded6df7ced0a5326f54e7a43e2.tar.xz | |
[bugfix] Fix potential dereference of accounts on own instance (#757)
* add GetAccountByUsernameDomain
* simplify search
* add escape to not deref accounts on own domain
* check if local + we have account by ap uri
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/cache/account.go | 15 | ||||
| -rw-r--r-- | internal/cache/account_test.go | 4 | ||||
| -rw-r--r-- | internal/db/account.go | 3 | ||||
| -rw-r--r-- | internal/db/bundb/account.go | 20 | ||||
| -rw-r--r-- | internal/db/bundb/account_test.go | 12 | ||||
| -rw-r--r-- | internal/federation/dereferencing/account.go | 96 | ||||
| -rw-r--r-- | internal/federation/dereferencing/account_test.go | 102 | ||||
| -rw-r--r-- | internal/processing/search.go | 83 | 
8 files changed, 243 insertions, 92 deletions
diff --git a/internal/cache/account.go b/internal/cache/account.go index ac67b5d07..1f958ebb8 100644 --- a/internal/cache/account.go +++ b/internal/cache/account.go @@ -37,6 +37,7 @@ func NewAccountCache() *AccountCache {  		RegisterLookups: func(lm *cache.LookupMap[string, string]) {  			lm.RegisterLookup("uri")  			lm.RegisterLookup("url") +			lm.RegisterLookup("usernamedomain")  		},  		AddLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) { @@ -46,6 +47,7 @@ func NewAccountCache() *AccountCache {  			if url := acc.URL; url != "" {  				lm.Set("url", url, acc.ID)  			} +			lm.Set("usernamedomain", usernameDomainKey(acc.Username, acc.Domain), acc.ID)  		},  		DeleteLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) { @@ -55,6 +57,7 @@ func NewAccountCache() *AccountCache {  			if url := acc.URL; url != "" {  				lm.Delete("url", url)  			} +			lm.Delete("usernamedomain", usernameDomainKey(acc.Username, acc.Domain))  		},  	})  	c.cache.SetTTL(time.Minute*5, false) @@ -77,6 +80,10 @@ func (c *AccountCache) GetByURI(uri string) (*gtsmodel.Account, bool) {  	return c.cache.GetBy("uri", uri)  } +func (c *AccountCache) GetByUsernameDomain(username string, domain string) (*gtsmodel.Account, bool) { +	return c.cache.GetBy("usernamedomain", usernameDomainKey(username, domain)) +} +  // Put places a account in the cache, ensuring that the object place is a copy for thread-safety  func (c *AccountCache) Put(account *gtsmodel.Account) {  	if account == nil || account.ID == "" { @@ -135,3 +142,11 @@ func copyAccount(account *gtsmodel.Account) *gtsmodel.Account {  		SuspensionOrigin:        account.SuspensionOrigin,  	}  } + +func usernameDomainKey(username string, domain string) string { +	u := "@" + username +	if domain != "" { +		return u + "@" + domain +	} +	return u +} diff --git a/internal/cache/account_test.go b/internal/cache/account_test.go index ff882cc3d..a6d3c6b7d 100644 --- a/internal/cache/account_test.go +++ b/internal/cache/account_test.go @@ -69,6 +69,10 @@ func (suite *AccountCacheTestSuite) TestAccountCache() {  		if account.URL != "" && !ok && !accountIs(account, check) {  			suite.Fail("Failed to fetch expected account with URL: %s", account.URL)  		} +		check, ok = suite.cache.GetByUsernameDomain(account.Username, account.Domain) +		if !ok && !accountIs(account, check) { +			suite.Fail("Failed to fetch expected account with username/domain: %s/%s", account.Username, account.Domain) +		}  	}  } diff --git a/internal/db/account.go b/internal/db/account.go index 79e7c01a5..155bd666c 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -36,6 +36,9 @@ type Account interface {  	// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.  	GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error) +	// GetAccountByUsernameDomain returns one account with the given username and domain, or an error if something goes wrong. +	GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, Error) +  	// UpdateAccount updates one account by ID.  	UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 201de6f02..95c3d80d8 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -84,6 +84,26 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.  	)  } +func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) { +	return a.getAccount( +		ctx, +		func() (*gtsmodel.Account, bool) { +			return a.cache.GetByUsernameDomain(username, domain) +		}, +		func(account *gtsmodel.Account) error { +			q := a.newAccountQ(account).Where("account.username = ?", username) + +			if domain != "" { +				q = q.Where("account.domain = ?", domain) +			} else { +				q = q.Where("account.domain IS NULL") +			} + +			return q.Scan(ctx) +		}, +	) +} +  func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) {  	// Attempt to fetch cached account  	account, cached := cacheGet() diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index 59b51386d..3c19e84d9 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -58,6 +58,18 @@ func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {  	suite.NotEmpty(account.HeaderMediaAttachment.URL)  } +func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() { +	testAccount1 := suite.testAccounts["local_account_1"] +	account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount1.Username, testAccount1.Domain) +	suite.NoError(err) +	suite.NotNil(account1) + +	testAccount2 := suite.testAccounts["remote_account_1"] +	account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount2.Username, testAccount2.Domain) +	suite.NoError(err) +	suite.NotNil(account2) +} +  func (suite *AccountTestSuite) TestUpdateAccount() {  	testAccount := suite.testAccounts["local_account_1"] diff --git a/internal/federation/dereferencing/account.go b/internal/federation/dereferencing/account.go index a0e2b87ae..cbb9466ff 100644 --- a/internal/federation/dereferencing/account.go +++ b/internal/federation/dereferencing/account.go @@ -32,6 +32,7 @@ import (  	"github.com/superseriousbusiness/activity/streams"  	"github.com/superseriousbusiness/activity/streams/vocab"  	"github.com/superseriousbusiness/gotosocial/internal/ap" +	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/id" @@ -78,7 +79,10 @@ type GetRemoteAccountParams struct {  // GetRemoteAccount completely dereferences a remote account, converts it to a GtS model account,  // puts or updates it in the database (if necessary), and returns it to a caller. -func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountParams) (remoteAccount *gtsmodel.Account, err error) { +// +// If a local account is passed into this function for whatever reason (hey, it happens!), then it +// will be returned from the database without making any remote calls. +func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountParams) (foundAccount *gtsmodel.Account, err error) {  	/*  		In this function we want to retrieve a gtsmodel representation of a remote account, with its proper  		accountDomain set, while making as few calls to remote instances as possible to save time and bandwidth. @@ -99,23 +103,40 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar  					from that.  	*/ -	// first check if we can retrieve the account locally just with what we've been given +	skipResolve := params.SkipResolve + +	// this first step checks if we have the +	// account in the database somewhere already  	switch {  	case params.RemoteAccountID != nil: -		// try with uri -		if a, dbErr := d.db.GetAccountByURI(ctx, params.RemoteAccountID.String()); dbErr == nil { -			remoteAccount = a +		uri := params.RemoteAccountID +		host := uri.Host +		if host == config.GetHost() || host == config.GetAccountDomain() { +			// this is actually a local account, +			// make sure we don't try to resolve +			skipResolve = true +		} + +		if a, dbErr := d.db.GetAccountByURI(ctx, uri.String()); dbErr == nil { +			foundAccount = a +		} else if dbErr != db.ErrNoEntries { +			err = fmt.Errorf("GetRemoteAccount: database error looking for account with uri %s: %s", uri, err) +		} +	case params.RemoteAccountUsername != "" && (params.RemoteAccountHost == "" || params.RemoteAccountHost == config.GetHost() || params.RemoteAccountHost == config.GetAccountDomain()): +		// either no domain is provided or this seems +		// to be a local account, so don't resolve +		skipResolve = true + +		if a, dbErr := d.db.GetLocalAccountByUsername(ctx, params.RemoteAccountUsername); dbErr == nil { +			foundAccount = a  		} else if dbErr != db.ErrNoEntries { -			err = fmt.Errorf("GetRemoteAccount: database error looking for account %s: %s", params.RemoteAccountID, err) +			err = fmt.Errorf("GetRemoteAccount: database error looking for local account with username %s: %s", params.RemoteAccountUsername, err)  		}  	case params.RemoteAccountUsername != "" && params.RemoteAccountHost != "": -		// try with username/host -		a := >smodel.Account{} -		where := []db.Where{{Key: "username", Value: params.RemoteAccountUsername}, {Key: "domain", Value: params.RemoteAccountHost}} -		if dbErr := d.db.GetWhere(ctx, where, a); dbErr == nil { -			remoteAccount = a +		if a, dbErr := d.db.GetAccountByUsernameDomain(ctx, params.RemoteAccountUsername, params.RemoteAccountHost); dbErr == nil { +			foundAccount = a  		} else if dbErr != db.ErrNoEntries { -			err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and host %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err) +			err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and domain %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err)  		}  	default:  		err = errors.New("GetRemoteAccount: no identifying parameters were set so we cannot get account") @@ -125,10 +146,11 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar  		return  	} -	if params.SkipResolve { -		// if we can't resolve, return already since there's nothing more we can do -		if remoteAccount == nil { -			err = errors.New("GetRemoteAccount: error retrieving account with skipResolve set true") +	if skipResolve { +		// if we can't resolve, return already +		// since there's nothing more we can do +		if foundAccount == nil { +			err = errors.New("GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it")  		}  		return  	} @@ -141,8 +163,8 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar  		// ... but we still need the username so we can do a finger for the accountDomain  		// check if we had the account stored already and got it earlier -		if remoteAccount != nil { -			params.RemoteAccountUsername = remoteAccount.Username +		if foundAccount != nil { +			params.RemoteAccountUsername = foundAccount.Username  		} else {  			// if we didn't already have it, we have dereference it from remote and just...  			accountable, err = d.dereferenceAccountable(ctx, params.RequestingUsername, params.RemoteAccountID) @@ -167,8 +189,8 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar  	// already about what the account domain might be; this var will be overwritten later if necessary  	var accountDomain string  	switch { -	case remoteAccount != nil: -		accountDomain = remoteAccount.Domain +	case foundAccount != nil: +		accountDomain = foundAccount.Domain  	case params.RemoteAccountID != nil:  		accountDomain = params.RemoteAccountID.Host  	default: @@ -178,7 +200,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar  	// to save on remote calls: only webfinger if we don't have a remoteAccount yet, or if we haven't  	// fingered the remote account for at least 2 days; don't finger instance accounts  	var fingered time.Time -	if remoteAccount == nil || (remoteAccount.LastWebfingeredAt.Before(time.Now().Add(webfingerInterval)) && !instanceAccount(remoteAccount)) { +	if foundAccount == nil || (foundAccount.LastWebfingeredAt.Before(time.Now().Add(webfingerInterval)) && !instanceAccount(foundAccount)) {  		accountDomain, params.RemoteAccountID, err = d.fingerRemoteAccount(ctx, params.RequestingUsername, params.RemoteAccountUsername, params.RemoteAccountHost)  		if err != nil {  			err = fmt.Errorf("GetRemoteAccount: error while fingering: %s", err) @@ -187,14 +209,14 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar  		fingered = time.Now()  	} -	if !fingered.IsZero() && remoteAccount == nil { +	if !fingered.IsZero() && foundAccount == nil {  		// if we just fingered and now have a discovered account domain but still no account,  		// we should do a final lookup in the database with the discovered username + accountDomain  		// to make absolutely sure we don't already have this account  		a := >smodel.Account{}  		where := []db.Where{{Key: "username", Value: params.RemoteAccountUsername}, {Key: "domain", Value: accountDomain}}  		if dbErr := d.db.GetWhere(ctx, where, a); dbErr == nil { -			remoteAccount = a +			foundAccount = a  		} else if dbErr != db.ErrNoEntries {  			err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and host %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err)  			return @@ -203,7 +225,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar  	// we may also have some extra information already, like the account we had in the db, or the  	// accountable representation that we dereferenced from remote -	if remoteAccount == nil { +	if foundAccount == nil {  		// we still don't have the account, so deference it if we didn't earlier  		if accountable == nil {  			accountable, err = d.dereferenceAccountable(ctx, params.RequestingUsername, params.RemoteAccountID) @@ -214,7 +236,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar  		}  		// then convert -		remoteAccount, err = d.typeConverter.ASRepresentationToAccount(ctx, accountable, accountDomain, false) +		foundAccount, err = d.typeConverter.ASRepresentationToAccount(ctx, accountable, accountDomain, false)  		if err != nil {  			err = fmt.Errorf("GetRemoteAccount: error converting accountable to account: %s", err)  			return @@ -227,18 +249,18 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar  			err = fmt.Errorf("GetRemoteAccount: error generating new id for account: %s", err)  			return  		} -		remoteAccount.ID = ulid +		foundAccount.ID = ulid -		_, err = d.populateAccountFields(ctx, remoteAccount, params.RequestingUsername, params.Blocking) +		_, err = d.populateAccountFields(ctx, foundAccount, params.RequestingUsername, params.Blocking)  		if err != nil {  			err = fmt.Errorf("GetRemoteAccount: error populating further account fields: %s", err)  			return  		} -		remoteAccount.LastWebfingeredAt = fingered -		remoteAccount.UpdatedAt = time.Now() +		foundAccount.LastWebfingeredAt = fingered +		foundAccount.UpdatedAt = time.Now() -		err = d.db.Put(ctx, remoteAccount) +		err = d.db.Put(ctx, foundAccount)  		if err != nil {  			err = fmt.Errorf("GetRemoteAccount: error putting new account: %s", err)  			return @@ -248,9 +270,9 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar  	}  	// we had the account already, but now we know the account domain, so update it if it's different -	if !strings.EqualFold(remoteAccount.Domain, accountDomain) { -		remoteAccount.Domain = accountDomain -		remoteAccount, err = d.db.UpdateAccount(ctx, remoteAccount) +	if !strings.EqualFold(foundAccount.Domain, accountDomain) { +		foundAccount.Domain = accountDomain +		foundAccount, err = d.db.UpdateAccount(ctx, foundAccount)  		if err != nil {  			err = fmt.Errorf("GetRemoteAccount: error updating account: %s", err)  			return @@ -260,7 +282,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar  	// make sure the account fields are populated before returning:  	// the caller might want to block until everything is loaded  	var fieldsChanged bool -	fieldsChanged, err = d.populateAccountFields(ctx, remoteAccount, params.RequestingUsername, params.Blocking) +	fieldsChanged, err = d.populateAccountFields(ctx, foundAccount, params.RequestingUsername, params.Blocking)  	if err != nil {  		return nil, fmt.Errorf("GetRemoteAccount: error populating remoteAccount fields: %s", err)  	} @@ -268,12 +290,12 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar  	var fingeredChanged bool  	if !fingered.IsZero() {  		fingeredChanged = true -		remoteAccount.LastWebfingeredAt = fingered +		foundAccount.LastWebfingeredAt = fingered  	}  	if fieldsChanged || fingeredChanged { -		remoteAccount.UpdatedAt = time.Now() -		remoteAccount, err = d.db.UpdateAccount(ctx, remoteAccount) +		foundAccount.UpdatedAt = time.Now() +		foundAccount, err = d.db.UpdateAccount(ctx, foundAccount)  		if err != nil {  			return nil, fmt.Errorf("GetRemoteAccount: error updating remoteAccount: %s", err)  		} diff --git a/internal/federation/dereferencing/account_test.go b/internal/federation/dereferencing/account_test.go index 72092951b..77ebb7cac 100644 --- a/internal/federation/dereferencing/account_test.go +++ b/internal/federation/dereferencing/account_test.go @@ -21,9 +21,11 @@ package dereferencing_test  import (  	"context"  	"testing" +	"time"  	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/ap" +	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing"  	"github.com/superseriousbusiness/gotosocial/testrig"  ) @@ -42,11 +44,11 @@ func (suite *AccountTestSuite) TestDereferenceGroup() {  	})  	suite.NoError(err)  	suite.NotNil(group) -	suite.NotNil(group)  	// group values should be set  	suite.Equal("https://unknown-instance.com/groups/some_group", group.URI)  	suite.Equal("https://unknown-instance.com/@some_group", group.URL) +	suite.WithinDuration(time.Now(), group.LastWebfingeredAt, 5*time.Second)  	// group should be in the database  	dbGroup, err := suite.db.GetAccountByURI(context.Background(), group.URI) @@ -65,11 +67,11 @@ func (suite *AccountTestSuite) TestDereferenceService() {  	})  	suite.NoError(err)  	suite.NotNil(service) -	suite.NotNil(service)  	// service values should be set  	suite.Equal("https://owncast.example.org/federation/user/rgh", service.URI)  	suite.Equal("https://owncast.example.org/federation/user/rgh", service.URL) +	suite.WithinDuration(time.Now(), service.LastWebfingeredAt, 5*time.Second)  	// service should be in the database  	dbService, err := suite.db.GetAccountByURI(context.Background(), service.URI) @@ -79,6 +81,102 @@ func (suite *AccountTestSuite) TestDereferenceService() {  	suite.Equal("example.org", dbService.Domain)  } +/* +	We shouldn't try webfingering or making http calls to dereference local accounts +	that might be passed into GetRemoteAccount for whatever reason, so these tests are +	here to make sure that such cases are (basically) short-circuit evaluated and given +	back as-is without trying to make any calls to one's own instance. +*/ + +func (suite *AccountTestSuite) TestDereferenceLocalAccountAsRemoteURL() { +	fetchingAccount := suite.testAccounts["local_account_1"] +	targetAccount := suite.testAccounts["local_account_2"] + +	fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{ +		RequestingUsername: fetchingAccount.Username, +		RemoteAccountID:    testrig.URLMustParse(targetAccount.URI), +	}) +	suite.NoError(err) +	suite.NotNil(fetchedAccount) +	suite.Empty(fetchedAccount.Domain) +} + +func (suite *AccountTestSuite) TestDereferenceLocalAccountAsUsername() { +	fetchingAccount := suite.testAccounts["local_account_1"] +	targetAccount := suite.testAccounts["local_account_2"] + +	fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{ +		RequestingUsername:    fetchingAccount.Username, +		RemoteAccountUsername: targetAccount.Username, +	}) +	suite.NoError(err) +	suite.NotNil(fetchedAccount) +	suite.Empty(fetchedAccount.Domain) +} + +func (suite *AccountTestSuite) TestDereferenceLocalAccountAsUsernameDomain() { +	fetchingAccount := suite.testAccounts["local_account_1"] +	targetAccount := suite.testAccounts["local_account_2"] + +	fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{ +		RequestingUsername:    fetchingAccount.Username, +		RemoteAccountUsername: targetAccount.Username, +		RemoteAccountHost:     config.GetHost(), +	}) +	suite.NoError(err) +	suite.NotNil(fetchedAccount) +	suite.Empty(fetchedAccount.Domain) +} + +func (suite *AccountTestSuite) TestDereferenceLocalAccountAsUsernameDomainAndURL() { +	fetchingAccount := suite.testAccounts["local_account_1"] +	targetAccount := suite.testAccounts["local_account_2"] + +	fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{ +		RequestingUsername:    fetchingAccount.Username, +		RemoteAccountID:       testrig.URLMustParse(targetAccount.URI), +		RemoteAccountUsername: targetAccount.Username, +		RemoteAccountHost:     config.GetHost(), +	}) +	suite.NoError(err) +	suite.NotNil(fetchedAccount) +	suite.Empty(fetchedAccount.Domain) +} + +func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUsername() { +	fetchingAccount := suite.testAccounts["local_account_1"] + +	fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{ +		RequestingUsername:    fetchingAccount.Username, +		RemoteAccountUsername: "thisaccountdoesnotexist", +	}) +	suite.EqualError(err, "GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it") +	suite.Nil(fetchedAccount) +} + +func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUsernameDomain() { +	fetchingAccount := suite.testAccounts["local_account_1"] + +	fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{ +		RequestingUsername:    fetchingAccount.Username, +		RemoteAccountUsername: "thisaccountdoesnotexist", +		RemoteAccountHost:     "localhost:8080", +	}) +	suite.EqualError(err, "GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it") +	suite.Nil(fetchedAccount) +} + +func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUserURI() { +	fetchingAccount := suite.testAccounts["local_account_1"] + +	fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{ +		RequestingUsername: fetchingAccount.Username, +		RemoteAccountID:    testrig.URLMustParse("http://localhost:8080/users/thisaccountdoesnotexist"), +	}) +	suite.EqualError(err, "GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it") +	suite.Nil(fetchedAccount) +} +  func TestAccountTestSuite(t *testing.T) {  	suite.Run(t, new(AccountTestSuite))  } diff --git a/internal/processing/search.go b/internal/processing/search.go index d25bee2ae..b766b4ba2 100644 --- a/internal/processing/search.go +++ b/internal/processing/search.go @@ -39,7 +39,6 @@ import (  func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode) {  	l := log.WithFields(kv.Fields{ -  		{"query", search.Query},  	}...) @@ -62,7 +61,7 @@ func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a  	/*  		SEARCH BY MENTION -		check if the query is something like @whatever_username@example.org -- this means it's a remote account +		check if the query is something like @whatever_username@example.org -- this means it's likely a remote account  	*/  	maybeNamestring := query  	if maybeNamestring[0] != '@' { @@ -135,7 +134,6 @@ func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a  func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Status, error) {  	l := log.WithFields(kv.Fields{ -  		{"uri", uri.String()},  		{"resolve", resolve},  	}...) @@ -161,67 +159,46 @@ func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, u  }  func (p *processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Account, error) { -	if maybeAccount, err := p.db.GetAccountByURI(ctx, uri.String()); err == nil { -		return maybeAccount, nil -	} else if maybeAccount, err := p.db.GetAccountByURL(ctx, uri.String()); err == nil { +	// it might be a web url like http://example.org/@user instead +	// of an AP uri like http://example.org/users/user, check first +	if maybeAccount, err := p.db.GetAccountByURL(ctx, uri.String()); err == nil {  		return maybeAccount, nil  	} -	if resolve { -		// we don't have it locally so try and dereference it -		account, err := p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{ -			RequestingUsername: authed.Account.Username, -			RemoteAccountID:    uri, -		}) -		if err != nil { -			return nil, fmt.Errorf("searchAccountByURI: error dereferencing account with uri %s: %s", uri.String(), err) +	if uri.Host == config.GetHost() || uri.Host == config.GetAccountDomain() { +		// this is a local account; if we don't have it now then +		// we should just bail instead of trying to get it remote +		if maybeAccount, err := p.db.GetAccountByURI(ctx, uri.String()); err == nil { +			return maybeAccount, nil  		} -		return account, nil +		return nil, nil  	} -	return nil, nil + +	// we don't have it yet, try to find it remotely +	return p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{ +		RequestingUsername: authed.Account.Username, +		RemoteAccountID:    uri, +		Blocking:           true, +		SkipResolve:        !resolve, +	})  }  func (p *processor) searchAccountByMention(ctx context.Context, authed *oauth.Auth, username string, domain string, resolve bool) (*gtsmodel.Account, error) { -	maybeAcct := >smodel.Account{} -	var err error -  	// if it's a local account we can skip a whole bunch of stuff  	if domain == config.GetHost() || domain == config.GetAccountDomain() || domain == "" { -		maybeAcct, err = p.db.GetLocalAccountByUsername(ctx, username) -		if err != nil && err != db.ErrNoEntries { -			return nil, fmt.Errorf("searchAccountByMention: error getting local account by username: %s", err) +		maybeAcct, err := p.db.GetLocalAccountByUsername(ctx, username) +		if err == nil || err == db.ErrNoEntries { +			return maybeAcct, nil  		} -		return maybeAcct, nil +		return nil, fmt.Errorf("searchAccountByMention: error getting local account by username: %s", err)  	} -	// it's not a local account so first we'll check if it's in the database already... -	where := []db.Where{ -		{Key: "username", Value: username, CaseInsensitive: true}, -		{Key: "domain", Value: domain, CaseInsensitive: true}, -	} -	err = p.db.GetWhere(ctx, where, maybeAcct) -	if err == nil { -		// we've got it stored locally already! -		return maybeAcct, nil -	} - -	if err != db.ErrNoEntries { -		// if it's  not errNoEntries there's been a real database error so bail at this point -		return nil, fmt.Errorf("searchAccountByMention: database error: %s", err) -	} - -	// we got a db.ErrNoEntries, so we just don't have the account locally stored -- check if we can dereference it -	if resolve { -		maybeAcct, err = p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{ -			RequestingUsername:    authed.Account.Username, -			RemoteAccountUsername: username, -			RemoteAccountHost:     domain, -		}) -		if err != nil { -			return nil, fmt.Errorf("searchAccountByMention: error getting remote account: %s", err) -		} -		return maybeAcct, nil -	} - -	return nil, nil +	// we don't have it yet, try to find it remotely +	return p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{ +		RequestingUsername:    authed.Account.Username, +		RemoteAccountUsername: username, +		RemoteAccountHost:     domain, +		Blocking:              true, +		SkipResolve:           !resolve, +	})  }  | 
