diff options
Diffstat (limited to 'internal/db/bundb/account.go')
-rw-r--r-- | internal/db/bundb/account.go | 120 |
1 files changed, 60 insertions, 60 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 179db6bb3..2ef1618db 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -38,16 +38,16 @@ import ( ) type accountDB struct { - conn *DBConn + db *WrappedDB state *state.State } -func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, error) { return a.getAccount( ctx, "ID", func(account *gtsmodel.Account) error { - return a.conn.NewSelect(). + return a.db.NewSelect(). Model(account). Where("? = ?", bun.Ident("account.id"), id). Scan(ctx) @@ -77,12 +77,12 @@ func (a *accountDB) GetAccountsByIDs(ctx context.Context, ids []string) ([]*gtsm return accounts, nil } -func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, error) { return a.getAccount( ctx, "URI", func(account *gtsmodel.Account) error { - return a.conn.NewSelect(). + return a.db.NewSelect(). Model(account). Where("? = ?", bun.Ident("account.uri"), uri). Scan(ctx) @@ -91,12 +91,12 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel. ) } -func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, error) { return a.getAccount( ctx, "URL", func(account *gtsmodel.Account) error { - return a.conn.NewSelect(). + return a.db.NewSelect(). Model(account). Where("? = ?", bun.Ident("account.url"), url). Scan(ctx) @@ -105,7 +105,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel. ) } -func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, error) { if domain != "" { // Normalize the domain as punycode var err error @@ -119,7 +119,7 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str ctx, "Username.Domain", func(account *gtsmodel.Account) error { - q := a.conn.NewSelect(). + q := a.db.NewSelect(). Model(account) if domain != "" { @@ -139,12 +139,12 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str ) } -func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, error) { return a.getAccount( ctx, "PublicKeyURI", func(account *gtsmodel.Account) error { - return a.conn.NewSelect(). + return a.db.NewSelect(). Model(account). Where("? = ?", bun.Ident("account.public_key_uri"), id). Scan(ctx) @@ -153,12 +153,12 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo ) } -func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error) { return a.getAccount( ctx, "InboxURI", func(account *gtsmodel.Account) error { - return a.conn.NewSelect(). + return a.db.NewSelect(). Model(account). Where("? = ?", bun.Ident("account.inbox_uri"), uri). Scan(ctx) @@ -167,12 +167,12 @@ func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsm ) } -func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error) { return a.getAccount( ctx, "OutboxURI", func(account *gtsmodel.Account) error { - return a.conn.NewSelect(). + return a.db.NewSelect(). Model(account). Where("? = ?", bun.Ident("account.outbox_uri"), uri). Scan(ctx) @@ -181,12 +181,12 @@ func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gts ) } -func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, error) { return a.getAccount( ctx, "FollowersURI", func(account *gtsmodel.Account) error { - return a.conn.NewSelect(). + return a.db.NewSelect(). Model(account). Where("? = ?", bun.Ident("account.followers_uri"), uri). Scan(ctx) @@ -195,12 +195,12 @@ func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (* ) } -func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, error) { return a.getAccount( ctx, "FollowingURI", func(account *gtsmodel.Account) error { - return a.conn.NewSelect(). + return a.db.NewSelect(). Model(account). Where("? = ?", bun.Ident("account.following_uri"), uri). Scan(ctx) @@ -209,7 +209,7 @@ func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (* ) } -func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { +func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, error) { var username string if domain == "" { @@ -223,14 +223,14 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts return a.GetAccountByUsernameDomain(ctx, username, domain) } -func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, db.Error) { +func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, error) { // Fetch account from database cache with loader callback account, err := a.state.Caches.GTS.Account().Load(lookup, func() (*gtsmodel.Account, error) { var account gtsmodel.Account // Not cached! Perform database query if err := dbQuery(&account); err != nil { - return nil, a.conn.ProcessError(err) + return nil, a.db.ProcessError(err) } return &account, nil @@ -294,12 +294,12 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou return errs.Combine() } -func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error { +func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) error { return a.state.Caches.GTS.Account().Store(account, func() error { // It is safe to run this database transaction within cache.Store // as the cache does not attempt a mutex lock until AFTER hook. // - return a.conn.RunInTx(ctx, func(tx bun.Tx) error { + return a.db.RunInTx(ctx, func(tx bun.Tx) error { // create links between this account and any emojis it uses for _, i := range account.EmojiIDs { if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ @@ -317,7 +317,7 @@ func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) d }) } -func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account, columns ...string) db.Error { +func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account, columns ...string) error { account.UpdatedAt = time.Now() if len(columns) > 0 { // If we're updating by column, ensure "updated_at" is included. @@ -328,7 +328,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account // It is safe to run this database transaction within cache.Store // as the cache does not attempt a mutex lock until AFTER hook. // - return a.conn.RunInTx(ctx, func(tx bun.Tx) error { + return a.db.RunInTx(ctx, func(tx bun.Tx) error { // create links between this account and any emojis it uses // first clear out any old emoji links if _, err := tx. @@ -362,7 +362,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account }) } -func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { +func (a *accountDB) DeleteAccount(ctx context.Context, id string) error { defer a.state.Caches.GTS.Account().Invalidate("ID", id) // Load account into cache before attempting a delete, @@ -376,7 +376,7 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { return err } - return a.conn.RunInTx(ctx, func(tx bun.Tx) error { + return a.db.RunInTx(ctx, func(tx bun.Tx) error { // clear out any emoji links if _, err := tx. NewDelete(). @@ -396,10 +396,10 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { }) } -func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, db.Error) { +func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, error) { createdAt := time.Time{} - q := a.conn. + q := a.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). Column("status.created_at"). @@ -416,12 +416,12 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, } if err := q.Scan(ctx, &createdAt); err != nil { - return time.Time{}, a.conn.ProcessError(err) + return time.Time{}, a.db.ProcessError(err) } return createdAt, nil } -func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error { +func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error { if *mediaAttachment.Avatar && *mediaAttachment.Header { return errors.New("one media attachment cannot be both header and avatar") } @@ -437,26 +437,26 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen } // TODO: there are probably more side effects here that need to be handled - if _, err := a.conn. + if _, err := a.db. NewInsert(). Model(mediaAttachment). Exec(ctx); err != nil { - return a.conn.ProcessError(err) + return a.db.ProcessError(err) } - if _, err := a.conn. + if _, err := a.db. NewUpdate(). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). Set("? = ?", column, mediaAttachment.ID). Where("? = ?", bun.Ident("account.id"), accountID). Exec(ctx); err != nil { - return a.conn.ProcessError(err) + return a.db.ProcessError(err) } return nil } -func (a *accountDB) GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, db.Error) { +func (a *accountDB) GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, error) { account, err := a.GetAccountByUsernameDomain(ctx, username, "") if err != nil { return "", err @@ -469,7 +469,7 @@ func (a *accountDB) GetAccountsUsingEmoji(ctx context.Context, emojiID string) ( var accountIDs []string // Create SELECT account query. - q := a.conn.NewSelect(). + q := a.db.NewSelect(). Table("accounts"). Column("id") @@ -486,37 +486,37 @@ func (a *accountDB) GetAccountsUsingEmoji(ctx context.Context, emojiID string) ( // Execute the query, scanning destination into accountIDs. if _, err := q.Exec(ctx, &accountIDs); err != nil { - return nil, a.conn.ProcessError(err) + return nil, a.db.ProcessError(err) } // Convert account IDs into account objects. return a.GetAccountsByIDs(ctx, accountIDs) } -func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) { +func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, error) { faves := new([]*gtsmodel.StatusFave) - if err := a.conn. + if err := a.db. NewSelect(). Model(faves). Where("? = ?", bun.Ident("status_fave.account_id"), accountID). Scan(ctx); err != nil { - return nil, a.conn.ProcessError(err) + return nil, a.db.ProcessError(err) } return *faves, nil } -func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) { - return a.conn. +func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, error) { + return a.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). Where("? = ?", bun.Ident("status.account_id"), accountID). Count(ctx) } -func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, db.Error) { - return a.conn. +func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, error) { + return a.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). Where("? = ?", bun.Ident("status.account_id"), accountID). @@ -524,7 +524,7 @@ func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (i Count(ctx) } -func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, db.Error) { +func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error) { // Ensure reasonable if limit < 0 { limit = 0 @@ -536,7 +536,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li frontToBack = true ) - q := a.conn. + q := a.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). // Select only IDs from table @@ -562,7 +562,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li // implementation differs between SQLite and Postgres, // so we have to be thorough to cover all eventualities q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery { - switch a.conn.Dialect().Name() { + switch a.db.Dialect().Name() { case dialect.PG: return q. Where("? IS NOT NULL", bun.Ident("status.attachments")). @@ -613,7 +613,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li } if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, a.conn.ProcessError(err) + return nil, a.db.ProcessError(err) } // If we're paging up, we still want statuses @@ -628,10 +628,10 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li return a.statusesFromIDs(ctx, statusIDs) } -func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, db.Error) { +func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, error) { statusIDs := []string{} - q := a.conn. + q := a.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). Column("status.id"). @@ -640,13 +640,13 @@ func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID stri Order("status.pinned_at DESC") if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, a.conn.ProcessError(err) + return nil, a.db.ProcessError(err) } return a.statusesFromIDs(ctx, statusIDs) } -func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, db.Error) { +func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) { // Ensure reasonable if limit < 0 { limit = 0 @@ -655,7 +655,7 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, // Make educated guess for slice size statusIDs := make([]string, 0, limit) - q := a.conn. + q := a.db. NewSelect(). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). // Select only IDs from table @@ -688,16 +688,16 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, q = q.Order("status.id DESC") if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, a.conn.ProcessError(err) + return nil, a.db.ProcessError(err) } return a.statusesFromIDs(ctx, statusIDs) } -func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) { +func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) { blocks := []*gtsmodel.Block{} - fq := a.conn. + fq := a.db. NewSelect(). Model(&blocks). Where("? = ?", bun.Ident("block.account_id"), accountID). @@ -717,7 +717,7 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI } if err := fq.Scan(ctx); err != nil { - return nil, "", "", a.conn.ProcessError(err) + return nil, "", "", a.db.ProcessError(err) } if len(blocks) == 0 { @@ -734,7 +734,7 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI return accounts, nextMaxID, prevMinID, nil } -func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, db.Error) { +func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, error) { // Catch case of no statuses early if len(statusIDs) == 0 { return nil, db.ErrNoEntries |