diff options
Diffstat (limited to 'internal/db/bundb/account.go')
| -rw-r--r-- | internal/db/bundb/account.go | 148 |
1 files changed, 85 insertions, 63 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index aacfcd247..88a923ecf 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -121,18 +121,46 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel. ) } -func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, error) { - return a.getAccount( - ctx, - "URL", - func(account *gtsmodel.Account) error { - return a.db.NewSelect(). - Model(account). - Where("? = ?", bun.Ident("account.url"), url). - Scan(ctx) - }, - url, - ) +func (a *accountDB) GetOneAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, error) { + // Select IDs of all + // accounts with this url. + var ids []string + if err := a.db.NewSelect(). + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Column("account.id"). + Where("? = ?", bun.Ident("account.url"), url). + Scan(ctx, &ids); err != nil { + return nil, err + } + + // Ensure exactly one account. + if len(ids) == 0 { + return nil, db.ErrNoEntries + } + if len(ids) > 1 { + return nil, db.ErrMultipleEntries + } + + return a.GetAccountByID(ctx, ids[0]) +} + +func (a *accountDB) GetAccountsByURL(ctx context.Context, url string) ([]*gtsmodel.Account, error) { + // Select IDs of all + // accounts with this url. + var ids []string + if err := a.db.NewSelect(). + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Column("account.id"). + Where("? = ?", bun.Ident("account.url"), url). + Scan(ctx, &ids); err != nil { + return nil, err + } + + if len(ids) == 0 { + return nil, db.ErrNoEntries + } + + return a.GetAccountsByIDs(ctx, ids) } func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, error) { @@ -184,60 +212,50 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo ) } -func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error) { - return a.getAccount( - ctx, - "InboxURI", - func(account *gtsmodel.Account) error { - return a.db.NewSelect(). - Model(account). - Where("? = ?", bun.Ident("account.inbox_uri"), uri). - Scan(ctx) - }, - uri, - ) -} +func (a *accountDB) GetOneAccountByInboxURI(ctx context.Context, inboxURI string) (*gtsmodel.Account, error) { + // Select IDs of all accounts + // with this inbox_uri. + var ids []string + if err := a.db.NewSelect(). + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Column("account.id"). + Where("? = ?", bun.Ident("account.inbox_uri"), inboxURI). + Scan(ctx, &ids); err != nil { + return nil, err + } -func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error) { - return a.getAccount( - ctx, - "OutboxURI", - func(account *gtsmodel.Account) error { - return a.db.NewSelect(). - Model(account). - Where("? = ?", bun.Ident("account.outbox_uri"), uri). - Scan(ctx) - }, - uri, - ) -} + // Ensure exactly one account. + if len(ids) == 0 { + return nil, db.ErrNoEntries + } + if len(ids) > 1 { + return nil, db.ErrMultipleEntries + } -func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, error) { - return a.getAccount( - ctx, - "FollowersURI", - func(account *gtsmodel.Account) error { - return a.db.NewSelect(). - Model(account). - Where("? = ?", bun.Ident("account.followers_uri"), uri). - Scan(ctx) - }, - uri, - ) + return a.GetAccountByID(ctx, ids[0]) } -func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, error) { - return a.getAccount( - ctx, - "FollowingURI", - func(account *gtsmodel.Account) error { - return a.db.NewSelect(). - Model(account). - Where("? = ?", bun.Ident("account.following_uri"), uri). - Scan(ctx) - }, - uri, - ) +func (a *accountDB) GetOneAccountByOutboxURI(ctx context.Context, outboxURI string) (*gtsmodel.Account, error) { + // Select IDs of all accounts + // with this outbox_uri. + var ids []string + if err := a.db.NewSelect(). + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Column("account.id"). + Where("? = ?", bun.Ident("account.outbox_uri"), outboxURI). + Scan(ctx, &ids); err != nil { + return nil, err + } + + // Ensure exactly one account. + if len(ids) == 0 { + return nil, db.ErrNoEntries + } + if len(ids) > 1 { + return nil, db.ErrMultipleEntries + } + + return a.GetAccountByID(ctx, ids[0]) } func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, error) { @@ -587,7 +605,11 @@ func (a *accountDB) GetAccounts( return a.state.DB.GetAccountsByIDs(ctx, accountIDs) } -func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, error) { +func (a *accountDB) getAccount( + ctx context.Context, + lookup string, + dbQuery func(*gtsmodel.Account) error, keyParts ...any, +) (*gtsmodel.Account, error) { // Fetch account from database cache with loader callback account, err := a.state.Caches.DB.Account.LoadOne(lookup, func() (*gtsmodel.Account, error) { var account gtsmodel.Account |
