summaryrefslogtreecommitdiff
path: root/internal/db/bundb/account.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/account.go')
-rw-r--r--internal/db/bundb/account.go75
1 files changed, 42 insertions, 33 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index aef1f3281..d7d45a739 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -25,7 +25,6 @@ import (
"strings"
"time"
- "github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -34,8 +33,7 @@ import (
type accountDB struct {
config *config.Config
- conn *bun.DB
- log *logrus.Logger
+ conn *DBConn
}
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
@@ -52,9 +50,11 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac
q := a.newAccountQ(account).
Where("account.id = ?", id)
- err := processErrorResponse(q.Scan(ctx))
-
- return account, err
+ err := q.Scan(ctx)
+ if err != nil {
+ return nil, a.conn.ProcessError(err)
+ }
+ return account, nil
}
func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
@@ -63,9 +63,11 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
q := a.newAccountQ(account).
Where("account.uri = ?", uri)
- err := processErrorResponse(q.Scan(ctx))
-
- return account, err
+ err := q.Scan(ctx)
+ if err != nil {
+ return nil, a.conn.ProcessError(err)
+ }
+ return account, nil
}
func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
@@ -74,9 +76,11 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.
q := a.newAccountQ(account).
Where("account.url = ?", uri)
- err := processErrorResponse(q.Scan(ctx))
-
- return account, err
+ err := q.Scan(ctx)
+ if err != nil {
+ return nil, a.conn.ProcessError(err)
+ }
+ return account, nil
}
func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
@@ -92,10 +96,10 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
WherePK()
_, err := q.Exec(ctx)
-
- err = processErrorResponse(err)
-
- return account, err
+ if err != nil {
+ return nil, a.conn.ProcessError(err)
+ }
+ return account, nil
}
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
@@ -113,9 +117,11 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
WhereGroup(" AND ", whereEmptyOrNull("domain"))
}
- err := processErrorResponse(q.Scan(ctx))
-
- return account, err
+ err := q.Scan(ctx)
+ if err != nil {
+ return nil, a.conn.ProcessError(err)
+ }
+ return account, nil
}
func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, db.Error) {
@@ -129,9 +135,11 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string)
Where("account_id = ?", accountID).
Column("created_at")
- err := processErrorResponse(q.Scan(ctx))
-
- return status.CreatedAt, err
+ err := q.Scan(ctx)
+ if err != nil {
+ return time.Time{}, a.conn.ProcessError(err)
+ }
+ return status.CreatedAt, nil
}
func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
@@ -153,17 +161,17 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
NewInsert().
Model(mediaAttachment).
Exec(ctx); err != nil {
- return err
+ return a.conn.ProcessError(err)
}
-
if _, err := a.conn.
NewUpdate().
Model(&gtsmodel.Account{}).
Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).
Where("id = ?", accountID).
Exec(ctx); err != nil {
- return err
+ return a.conn.ProcessError(err)
}
+
return nil
}
@@ -174,9 +182,11 @@ func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username stri
Where("username = ?", username).
WhereGroup(" AND ", whereEmptyOrNull("domain"))
- err := processErrorResponse(q.Scan(ctx))
-
- return account, err
+ err := q.Scan(ctx)
+ if err != nil {
+ return nil, a.conn.ProcessError(err)
+ }
+ return account, nil
}
func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) {
@@ -187,8 +197,9 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
Model(faves).
Where("account_id = ?", accountID).
Scan(ctx); err != nil {
- return nil, err
+ return nil, a.conn.ProcessError(err)
}
+
return *faves, nil
}
@@ -201,7 +212,6 @@ func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string)
}
func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
- a.log.Debugf("getting statuses for account %s", accountID)
statuses := []*gtsmodel.Status{}
q := a.conn.
@@ -238,14 +248,13 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
}
if err := q.Scan(ctx); err != nil {
- return nil, err
+ return nil, a.conn.ProcessError(err)
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries
}
- a.log.Debugf("returning statuses for account %s", accountID)
return statuses, nil
}
@@ -273,7 +282,7 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI
err := fq.Scan(ctx)
if err != nil {
- return nil, "", "", err
+ return nil, "", "", a.conn.ProcessError(err)
}
if len(blocks) == 0 {