summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/account.go3
-rw-r--r--internal/db/bundb/account.go20
-rw-r--r--internal/db/bundb/account_test.go12
3 files changed, 35 insertions, 0 deletions
diff --git a/internal/db/account.go b/internal/db/account.go
index 79e7c01a5..155bd666c 100644
--- a/internal/db/account.go
+++ b/internal/db/account.go
@@ -36,6 +36,9 @@ type Account interface {
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error)
+ // GetAccountByUsernameDomain returns one account with the given username and domain, or an error if something goes wrong.
+ GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, Error)
+
// UpdateAccount updates one account by ID.
UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index 201de6f02..95c3d80d8 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -84,6 +84,26 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
)
}
+func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) {
+ return a.getAccount(
+ ctx,
+ func() (*gtsmodel.Account, bool) {
+ return a.cache.GetByUsernameDomain(username, domain)
+ },
+ func(account *gtsmodel.Account) error {
+ q := a.newAccountQ(account).Where("account.username = ?", username)
+
+ if domain != "" {
+ q = q.Where("account.domain = ?", domain)
+ } else {
+ q = q.Where("account.domain IS NULL")
+ }
+
+ return q.Scan(ctx)
+ },
+ )
+}
+
func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) {
// Attempt to fetch cached account
account, cached := cacheGet()
diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go
index 59b51386d..3c19e84d9 100644
--- a/internal/db/bundb/account_test.go
+++ b/internal/db/bundb/account_test.go
@@ -58,6 +58,18 @@ func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
suite.NotEmpty(account.HeaderMediaAttachment.URL)
}
+func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() {
+ testAccount1 := suite.testAccounts["local_account_1"]
+ account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount1.Username, testAccount1.Domain)
+ suite.NoError(err)
+ suite.NotNil(account1)
+
+ testAccount2 := suite.testAccounts["remote_account_1"]
+ account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount2.Username, testAccount2.Domain)
+ suite.NoError(err)
+ suite.NotNil(account2)
+}
+
func (suite *AccountTestSuite) TestUpdateAccount() {
testAccount := suite.testAccounts["local_account_1"]