diff options
Diffstat (limited to 'internal/db/bundb/basic.go')
-rw-r--r-- | internal/db/bundb/basic.go | 45 |
1 files changed, 16 insertions, 29 deletions
diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go index 983b6b810..a3a8d0ae9 100644 --- a/internal/db/bundb/basic.go +++ b/internal/db/bundb/basic.go @@ -21,9 +21,7 @@ package bundb import ( "context" "errors" - "strings" - "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/uptrace/bun" @@ -31,16 +29,12 @@ import ( type basicDB struct { config *config.Config - conn *bun.DB - log *logrus.Logger + conn *DBConn } func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error { _, err := b.conn.NewInsert().Model(i).Exec(ctx) - if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") { - return db.ErrAlreadyExists - } - return err + return b.conn.ProcessError(err) } func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Error { @@ -49,7 +43,8 @@ func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Erro Model(i). Where("id = ?", id) - return processErrorResponse(q.Scan(ctx)) + err := q.Scan(ctx) + return b.conn.ProcessError(err) } func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error { @@ -59,7 +54,6 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) q := b.conn.NewSelect().Model(i) for _, w := range where { - if w.Value == nil { q = q.Where("? IS NULL", bun.Ident(w.Key)) } else { @@ -71,7 +65,8 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) } } - return processErrorResponse(q.Scan(ctx)) + err := q.Scan(ctx) + return b.conn.ProcessError(err) } func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error { @@ -79,7 +74,8 @@ func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error { NewSelect(). Model(i) - return processErrorResponse(q.Scan(ctx)) + err := q.Scan(ctx) + return b.conn.ProcessError(err) } func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error { @@ -89,8 +85,7 @@ func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.E Where("id = ?", id) _, err := q.Exec(ctx) - - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error { @@ -107,8 +102,7 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface } _, err := q.Exec(ctx) - - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error { @@ -118,8 +112,7 @@ func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.E WherePK() _, err := q.Exec(ctx) - - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error { @@ -129,8 +122,7 @@ func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, valu WherePK() _, err := q.Exec(ctx) - - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error { @@ -151,8 +143,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, q = q.Set("? = ?", bun.Safe(key), value) _, err := q.Exec(ctx) - - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error { @@ -162,7 +153,7 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error { func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error { _, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx) - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) IsHealthy(ctx context.Context) db.Error { @@ -170,10 +161,6 @@ func (b *basicDB) IsHealthy(ctx context.Context) db.Error { } func (b *basicDB) Stop(ctx context.Context) db.Error { - b.log.Info("closing db connection") - if err := b.conn.Close(); err != nil { - // only cancel if there's a problem closing the db - return err - } - return nil + b.conn.log.Info("closing db connection") + return b.conn.Close() } |