summaryrefslogtreecommitdiff
path: root/internal/db/bundb/account.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/account.go')
-rw-r--r--internal/db/bundb/account.go148
1 files changed, 85 insertions, 63 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index aacfcd247..88a923ecf 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -121,18 +121,46 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
)
}
-func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, error) {
- return a.getAccount(
- ctx,
- "URL",
- func(account *gtsmodel.Account) error {
- return a.db.NewSelect().
- Model(account).
- Where("? = ?", bun.Ident("account.url"), url).
- Scan(ctx)
- },
- url,
- )
+func (a *accountDB) GetOneAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, error) {
+ // Select IDs of all
+ // accounts with this url.
+ var ids []string
+ if err := a.db.NewSelect().
+ TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
+ Column("account.id").
+ Where("? = ?", bun.Ident("account.url"), url).
+ Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+
+ // Ensure exactly one account.
+ if len(ids) == 0 {
+ return nil, db.ErrNoEntries
+ }
+ if len(ids) > 1 {
+ return nil, db.ErrMultipleEntries
+ }
+
+ return a.GetAccountByID(ctx, ids[0])
+}
+
+func (a *accountDB) GetAccountsByURL(ctx context.Context, url string) ([]*gtsmodel.Account, error) {
+ // Select IDs of all
+ // accounts with this url.
+ var ids []string
+ if err := a.db.NewSelect().
+ TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
+ Column("account.id").
+ Where("? = ?", bun.Ident("account.url"), url).
+ Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+
+ if len(ids) == 0 {
+ return nil, db.ErrNoEntries
+ }
+
+ return a.GetAccountsByIDs(ctx, ids)
}
func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, error) {
@@ -184,60 +212,50 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo
)
}
-func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
- return a.getAccount(
- ctx,
- "InboxURI",
- func(account *gtsmodel.Account) error {
- return a.db.NewSelect().
- Model(account).
- Where("? = ?", bun.Ident("account.inbox_uri"), uri).
- Scan(ctx)
- },
- uri,
- )
-}
+func (a *accountDB) GetOneAccountByInboxURI(ctx context.Context, inboxURI string) (*gtsmodel.Account, error) {
+ // Select IDs of all accounts
+ // with this inbox_uri.
+ var ids []string
+ if err := a.db.NewSelect().
+ TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
+ Column("account.id").
+ Where("? = ?", bun.Ident("account.inbox_uri"), inboxURI).
+ Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
-func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
- return a.getAccount(
- ctx,
- "OutboxURI",
- func(account *gtsmodel.Account) error {
- return a.db.NewSelect().
- Model(account).
- Where("? = ?", bun.Ident("account.outbox_uri"), uri).
- Scan(ctx)
- },
- uri,
- )
-}
+ // Ensure exactly one account.
+ if len(ids) == 0 {
+ return nil, db.ErrNoEntries
+ }
+ if len(ids) > 1 {
+ return nil, db.ErrMultipleEntries
+ }
-func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
- return a.getAccount(
- ctx,
- "FollowersURI",
- func(account *gtsmodel.Account) error {
- return a.db.NewSelect().
- Model(account).
- Where("? = ?", bun.Ident("account.followers_uri"), uri).
- Scan(ctx)
- },
- uri,
- )
+ return a.GetAccountByID(ctx, ids[0])
}
-func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
- return a.getAccount(
- ctx,
- "FollowingURI",
- func(account *gtsmodel.Account) error {
- return a.db.NewSelect().
- Model(account).
- Where("? = ?", bun.Ident("account.following_uri"), uri).
- Scan(ctx)
- },
- uri,
- )
+func (a *accountDB) GetOneAccountByOutboxURI(ctx context.Context, outboxURI string) (*gtsmodel.Account, error) {
+ // Select IDs of all accounts
+ // with this outbox_uri.
+ var ids []string
+ if err := a.db.NewSelect().
+ TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
+ Column("account.id").
+ Where("? = ?", bun.Ident("account.outbox_uri"), outboxURI).
+ Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+
+ // Ensure exactly one account.
+ if len(ids) == 0 {
+ return nil, db.ErrNoEntries
+ }
+ if len(ids) > 1 {
+ return nil, db.ErrMultipleEntries
+ }
+
+ return a.GetAccountByID(ctx, ids[0])
}
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, error) {
@@ -587,7 +605,11 @@ func (a *accountDB) GetAccounts(
return a.state.DB.GetAccountsByIDs(ctx, accountIDs)
}
-func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, error) {
+func (a *accountDB) getAccount(
+ ctx context.Context,
+ lookup string,
+ dbQuery func(*gtsmodel.Account) error, keyParts ...any,
+) (*gtsmodel.Account, error) {
// Fetch account from database cache with loader callback
account, err := a.state.Caches.DB.Account.LoadOne(lookup, func() (*gtsmodel.Account, error) {
var account gtsmodel.Account