summaryrefslogtreecommitdiff
path: root/internal/db/bundb/account.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/account.go')
-rw-r--r--internal/db/bundb/account.go120
1 files changed, 60 insertions, 60 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index 179db6bb3..2ef1618db 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -38,16 +38,16 @@ import (
)
type accountDB struct {
- conn *DBConn
+ db *WrappedDB
state *state.State
}
-func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, error) {
return a.getAccount(
ctx,
"ID",
func(account *gtsmodel.Account) error {
- return a.conn.NewSelect().
+ return a.db.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.id"), id).
Scan(ctx)
@@ -77,12 +77,12 @@ func (a *accountDB) GetAccountsByIDs(ctx context.Context, ids []string) ([]*gtsm
return accounts, nil
}
-func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
return a.getAccount(
ctx,
"URI",
func(account *gtsmodel.Account) error {
- return a.conn.NewSelect().
+ return a.db.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.uri"), uri).
Scan(ctx)
@@ -91,12 +91,12 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
)
}
-func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, error) {
return a.getAccount(
ctx,
"URL",
func(account *gtsmodel.Account) error {
- return a.conn.NewSelect().
+ return a.db.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.url"), url).
Scan(ctx)
@@ -105,7 +105,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
)
}
-func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, error) {
if domain != "" {
// Normalize the domain as punycode
var err error
@@ -119,7 +119,7 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
ctx,
"Username.Domain",
func(account *gtsmodel.Account) error {
- q := a.conn.NewSelect().
+ q := a.db.NewSelect().
Model(account)
if domain != "" {
@@ -139,12 +139,12 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
)
}
-func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, error) {
return a.getAccount(
ctx,
"PublicKeyURI",
func(account *gtsmodel.Account) error {
- return a.conn.NewSelect().
+ return a.db.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.public_key_uri"), id).
Scan(ctx)
@@ -153,12 +153,12 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo
)
}
-func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
return a.getAccount(
ctx,
"InboxURI",
func(account *gtsmodel.Account) error {
- return a.conn.NewSelect().
+ return a.db.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.inbox_uri"), uri).
Scan(ctx)
@@ -167,12 +167,12 @@ func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsm
)
}
-func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
return a.getAccount(
ctx,
"OutboxURI",
func(account *gtsmodel.Account) error {
- return a.conn.NewSelect().
+ return a.db.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.outbox_uri"), uri).
Scan(ctx)
@@ -181,12 +181,12 @@ func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gts
)
}
-func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
return a.getAccount(
ctx,
"FollowersURI",
func(account *gtsmodel.Account) error {
- return a.conn.NewSelect().
+ return a.db.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.followers_uri"), uri).
Scan(ctx)
@@ -195,12 +195,12 @@ func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*
)
}
-func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
return a.getAccount(
ctx,
"FollowingURI",
func(account *gtsmodel.Account) error {
- return a.conn.NewSelect().
+ return a.db.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.following_uri"), uri).
Scan(ctx)
@@ -209,7 +209,7 @@ func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*
)
}
-func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, error) {
var username string
if domain == "" {
@@ -223,14 +223,14 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
return a.GetAccountByUsernameDomain(ctx, username, domain)
}
-func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, db.Error) {
+func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, error) {
// Fetch account from database cache with loader callback
account, err := a.state.Caches.GTS.Account().Load(lookup, func() (*gtsmodel.Account, error) {
var account gtsmodel.Account
// Not cached! Perform database query
if err := dbQuery(&account); err != nil {
- return nil, a.conn.ProcessError(err)
+ return nil, a.db.ProcessError(err)
}
return &account, nil
@@ -294,12 +294,12 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou
return errs.Combine()
}
-func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error {
+func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) error {
return a.state.Caches.GTS.Account().Store(account, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
- return a.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ return a.db.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this account and any emojis it uses
for _, i := range account.EmojiIDs {
if _, err := tx.NewInsert().Model(&gtsmodel.AccountToEmoji{
@@ -317,7 +317,7 @@ func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) d
})
}
-func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account, columns ...string) db.Error {
+func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account, columns ...string) error {
account.UpdatedAt = time.Now()
if len(columns) > 0 {
// If we're updating by column, ensure "updated_at" is included.
@@ -328,7 +328,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
- return a.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ return a.db.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this account and any emojis it uses
// first clear out any old emoji links
if _, err := tx.
@@ -362,7 +362,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
})
}
-func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
+func (a *accountDB) DeleteAccount(ctx context.Context, id string) error {
defer a.state.Caches.GTS.Account().Invalidate("ID", id)
// Load account into cache before attempting a delete,
@@ -376,7 +376,7 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
return err
}
- return a.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ return a.db.RunInTx(ctx, func(tx bun.Tx) error {
// clear out any emoji links
if _, err := tx.
NewDelete().
@@ -396,10 +396,10 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
})
}
-func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, db.Error) {
+func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, error) {
createdAt := time.Time{}
- q := a.conn.
+ q := a.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("status.created_at").
@@ -416,12 +416,12 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string,
}
if err := q.Scan(ctx, &createdAt); err != nil {
- return time.Time{}, a.conn.ProcessError(err)
+ return time.Time{}, a.db.ProcessError(err)
}
return createdAt, nil
}
-func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
+func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error {
if *mediaAttachment.Avatar && *mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar")
}
@@ -437,26 +437,26 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
}
// TODO: there are probably more side effects here that need to be handled
- if _, err := a.conn.
+ if _, err := a.db.
NewInsert().
Model(mediaAttachment).
Exec(ctx); err != nil {
- return a.conn.ProcessError(err)
+ return a.db.ProcessError(err)
}
- if _, err := a.conn.
+ if _, err := a.db.
NewUpdate().
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
Set("? = ?", column, mediaAttachment.ID).
Where("? = ?", bun.Ident("account.id"), accountID).
Exec(ctx); err != nil {
- return a.conn.ProcessError(err)
+ return a.db.ProcessError(err)
}
return nil
}
-func (a *accountDB) GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, db.Error) {
+func (a *accountDB) GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, error) {
account, err := a.GetAccountByUsernameDomain(ctx, username, "")
if err != nil {
return "", err
@@ -469,7 +469,7 @@ func (a *accountDB) GetAccountsUsingEmoji(ctx context.Context, emojiID string) (
var accountIDs []string
// Create SELECT account query.
- q := a.conn.NewSelect().
+ q := a.db.NewSelect().
Table("accounts").
Column("id")
@@ -486,37 +486,37 @@ func (a *accountDB) GetAccountsUsingEmoji(ctx context.Context, emojiID string) (
// Execute the query, scanning destination into accountIDs.
if _, err := q.Exec(ctx, &accountIDs); err != nil {
- return nil, a.conn.ProcessError(err)
+ return nil, a.db.ProcessError(err)
}
// Convert account IDs into account objects.
return a.GetAccountsByIDs(ctx, accountIDs)
}
-func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) {
+func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, error) {
faves := new([]*gtsmodel.StatusFave)
- if err := a.conn.
+ if err := a.db.
NewSelect().
Model(faves).
Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
Scan(ctx); err != nil {
- return nil, a.conn.ProcessError(err)
+ return nil, a.db.ProcessError(err)
}
return *faves, nil
}
-func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) {
- return a.conn.
+func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, error) {
+ return a.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Where("? = ?", bun.Ident("status.account_id"), accountID).
Count(ctx)
}
-func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, db.Error) {
- return a.conn.
+func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, error) {
+ return a.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Where("? = ?", bun.Ident("status.account_id"), accountID).
@@ -524,7 +524,7 @@ func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (i
Count(ctx)
}
-func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, db.Error) {
+func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error) {
// Ensure reasonable
if limit < 0 {
limit = 0
@@ -536,7 +536,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
frontToBack = true
)
- q := a.conn.
+ q := a.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
// Select only IDs from table
@@ -562,7 +562,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
// implementation differs between SQLite and Postgres,
// so we have to be thorough to cover all eventualities
q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {
- switch a.conn.Dialect().Name() {
+ switch a.db.Dialect().Name() {
case dialect.PG:
return q.
Where("? IS NOT NULL", bun.Ident("status.attachments")).
@@ -613,7 +613,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
}
if err := q.Scan(ctx, &statusIDs); err != nil {
- return nil, a.conn.ProcessError(err)
+ return nil, a.db.ProcessError(err)
}
// If we're paging up, we still want statuses
@@ -628,10 +628,10 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
return a.statusesFromIDs(ctx, statusIDs)
}
-func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, db.Error) {
+func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, error) {
statusIDs := []string{}
- q := a.conn.
+ q := a.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("status.id").
@@ -640,13 +640,13 @@ func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID stri
Order("status.pinned_at DESC")
if err := q.Scan(ctx, &statusIDs); err != nil {
- return nil, a.conn.ProcessError(err)
+ return nil, a.db.ProcessError(err)
}
return a.statusesFromIDs(ctx, statusIDs)
}
-func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, db.Error) {
+func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) {
// Ensure reasonable
if limit < 0 {
limit = 0
@@ -655,7 +655,7 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,
// Make educated guess for slice size
statusIDs := make([]string, 0, limit)
- q := a.conn.
+ q := a.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
// Select only IDs from table
@@ -688,16 +688,16 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,
q = q.Order("status.id DESC")
if err := q.Scan(ctx, &statusIDs); err != nil {
- return nil, a.conn.ProcessError(err)
+ return nil, a.db.ProcessError(err)
}
return a.statusesFromIDs(ctx, statusIDs)
}
-func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
+func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) {
blocks := []*gtsmodel.Block{}
- fq := a.conn.
+ fq := a.db.
NewSelect().
Model(&blocks).
Where("? = ?", bun.Ident("block.account_id"), accountID).
@@ -717,7 +717,7 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI
}
if err := fq.Scan(ctx); err != nil {
- return nil, "", "", a.conn.ProcessError(err)
+ return nil, "", "", a.db.ProcessError(err)
}
if len(blocks) == 0 {
@@ -734,7 +734,7 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI
return accounts, nextMaxID, prevMinID, nil
}
-func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, db.Error) {
+func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, error) {
// Catch case of no statuses early
if len(statusIDs) == 0 {
return nil, db.ErrNoEntries