diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/account.go | 3 | ||||
-rw-r--r-- | internal/db/bundb/account.go | 20 | ||||
-rw-r--r-- | internal/db/bundb/account_test.go | 12 |
3 files changed, 35 insertions, 0 deletions
diff --git a/internal/db/account.go b/internal/db/account.go index 79e7c01a5..155bd666c 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -36,6 +36,9 @@ type Account interface { // GetAccountByURL returns one account with the given URL, or an error if something goes wrong. GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error) + // GetAccountByUsernameDomain returns one account with the given username and domain, or an error if something goes wrong. + GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, Error) + // UpdateAccount updates one account by ID. UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 201de6f02..95c3d80d8 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -84,6 +84,26 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel. ) } +func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) { + return a.getAccount( + ctx, + func() (*gtsmodel.Account, bool) { + return a.cache.GetByUsernameDomain(username, domain) + }, + func(account *gtsmodel.Account) error { + q := a.newAccountQ(account).Where("account.username = ?", username) + + if domain != "" { + q = q.Where("account.domain = ?", domain) + } else { + q = q.Where("account.domain IS NULL") + } + + return q.Scan(ctx) + }, + ) +} + func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) { // Attempt to fetch cached account account, cached := cacheGet() diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index 59b51386d..3c19e84d9 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -58,6 +58,18 @@ func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() { suite.NotEmpty(account.HeaderMediaAttachment.URL) } +func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() { + testAccount1 := suite.testAccounts["local_account_1"] + account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount1.Username, testAccount1.Domain) + suite.NoError(err) + suite.NotNil(account1) + + testAccount2 := suite.testAccounts["remote_account_1"] + account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount2.Username, testAccount2.Domain) + suite.NoError(err) + suite.NotNil(account2) +} + func (suite *AccountTestSuite) TestUpdateAccount() { testAccount := suite.testAccounts["local_account_1"] |