diff options
Diffstat (limited to 'internal/db/bundb/account.go')
-rw-r--r-- | internal/db/bundb/account.go | 75 |
1 files changed, 42 insertions, 33 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index aef1f3281..d7d45a739 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -25,7 +25,6 @@ import ( "strings" "time" - "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -34,8 +33,7 @@ import ( type accountDB struct { config *config.Config - conn *bun.DB - log *logrus.Logger + conn *DBConn } func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { @@ -52,9 +50,11 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac q := a.newAccountQ(account). Where("account.id = ?", id) - err := processErrorResponse(q.Scan(ctx)) - - return account, err + err := q.Scan(ctx) + if err != nil { + return nil, a.conn.ProcessError(err) + } + return account, nil } func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { @@ -63,9 +63,11 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel. q := a.newAccountQ(account). Where("account.uri = ?", uri) - err := processErrorResponse(q.Scan(ctx)) - - return account, err + err := q.Scan(ctx) + if err != nil { + return nil, a.conn.ProcessError(err) + } + return account, nil } func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { @@ -74,9 +76,11 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel. q := a.newAccountQ(account). Where("account.url = ?", uri) - err := processErrorResponse(q.Scan(ctx)) - - return account, err + err := q.Scan(ctx) + if err != nil { + return nil, a.conn.ProcessError(err) + } + return account, nil } func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { @@ -92,10 +96,10 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account WherePK() _, err := q.Exec(ctx) - - err = processErrorResponse(err) - - return account, err + if err != nil { + return nil, a.conn.ProcessError(err) + } + return account, nil } func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { @@ -113,9 +117,11 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts WhereGroup(" AND ", whereEmptyOrNull("domain")) } - err := processErrorResponse(q.Scan(ctx)) - - return account, err + err := q.Scan(ctx) + if err != nil { + return nil, a.conn.ProcessError(err) + } + return account, nil } func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, db.Error) { @@ -129,9 +135,11 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) Where("account_id = ?", accountID). Column("created_at") - err := processErrorResponse(q.Scan(ctx)) - - return status.CreatedAt, err + err := q.Scan(ctx) + if err != nil { + return time.Time{}, a.conn.ProcessError(err) + } + return status.CreatedAt, nil } func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error { @@ -153,17 +161,17 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen NewInsert(). Model(mediaAttachment). Exec(ctx); err != nil { - return err + return a.conn.ProcessError(err) } - if _, err := a.conn. NewUpdate(). Model(>smodel.Account{}). Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID). Where("id = ?", accountID). Exec(ctx); err != nil { - return err + return a.conn.ProcessError(err) } + return nil } @@ -174,9 +182,11 @@ func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username stri Where("username = ?", username). WhereGroup(" AND ", whereEmptyOrNull("domain")) - err := processErrorResponse(q.Scan(ctx)) - - return account, err + err := q.Scan(ctx) + if err != nil { + return nil, a.conn.ProcessError(err) + } + return account, nil } func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) { @@ -187,8 +197,9 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g Model(faves). Where("account_id = ?", accountID). Scan(ctx); err != nil { - return nil, err + return nil, a.conn.ProcessError(err) } + return *faves, nil } @@ -201,7 +212,6 @@ func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) } func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) { - a.log.Debugf("getting statuses for account %s", accountID) statuses := []*gtsmodel.Status{} q := a.conn. @@ -238,14 +248,13 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li } if err := q.Scan(ctx); err != nil { - return nil, err + return nil, a.conn.ProcessError(err) } if len(statuses) == 0 { return nil, db.ErrNoEntries } - a.log.Debugf("returning statuses for account %s", accountID) return statuses, nil } @@ -273,7 +282,7 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI err := fq.Scan(ctx) if err != nil { - return nil, "", "", err + return nil, "", "", a.conn.ProcessError(err) } if len(blocks) == 0 { |