summaryrefslogtreecommitdiff
path: root/internal/db/bundb/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/conn.go')
-rw-r--r--internal/db/bundb/conn.go41
1 files changed, 24 insertions, 17 deletions
diff --git a/internal/db/bundb/conn.go b/internal/db/bundb/conn.go
index baa0baeae..1c85f6f6f 100644
--- a/internal/db/bundb/conn.go
+++ b/internal/db/bundb/conn.go
@@ -11,13 +11,11 @@ import (
// DBConn wrapps a bun.DB conn to provide SQL-type specific additional functionality
type DBConn struct {
- // TODO: move *Config here, no need to be in each struct type
-
errProc func(error) db.Error // errProc is the SQL-type specific error processor
*bun.DB // DB is the underlying bun.DB connection
}
-// WrapDBConn @TODO
+// WrapDBConn wraps a bun DB connection to provide our own error processing dependent on DB dialect.
func WrapDBConn(dbConn *bun.DB) *DBConn {
var errProc func(error) db.Error
switch dbConn.Dialect().Name() {
@@ -36,21 +34,31 @@ func WrapDBConn(dbConn *bun.DB) *DBConn {
// RunInTx wraps execution of the supplied transaction function.
func (conn *DBConn) RunInTx(ctx context.Context, fn func(bun.Tx) error) db.Error {
- // Acquire a new transaction
- tx, err := conn.BeginTx(ctx, nil)
- if err != nil {
- return conn.ProcessError(err)
- }
+ return conn.ProcessError(func() error {
+ // Acquire a new transaction
+ tx, err := conn.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
- // Perform supplied transaction
- if err = fn(tx); err != nil {
- tx.Rollback() //nolint
- return conn.ProcessError(err)
- }
+ var done bool
+
+ defer func() {
+ if !done {
+ _ = tx.Rollback()
+ }
+ }()
+
+ // Perform supplied transaction
+ if err := fn(tx); err != nil {
+ return err
+ }
- // Finally, commit transaction
- err = tx.Commit()
- return conn.ProcessError(err)
+ // Finally, commit
+ err = tx.Commit()
+ done = true
+ return err
+ }())
}
// ProcessError processes an error to replace any known values with our own db.Error types,
@@ -83,7 +91,6 @@ func (conn *DBConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, d
// NotExists is the functional opposite of conn.Exists()
func (conn *DBConn) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
- // Simply inverse of conn.exists()
exists, err := conn.Exists(ctx, query)
return !exists, err
}