diff options
Diffstat (limited to 'internal/db/bundb/conn.go')
-rw-r--r-- | internal/db/bundb/conn.go | 41 |
1 files changed, 24 insertions, 17 deletions
diff --git a/internal/db/bundb/conn.go b/internal/db/bundb/conn.go index baa0baeae..1c85f6f6f 100644 --- a/internal/db/bundb/conn.go +++ b/internal/db/bundb/conn.go @@ -11,13 +11,11 @@ import ( // DBConn wrapps a bun.DB conn to provide SQL-type specific additional functionality type DBConn struct { - // TODO: move *Config here, no need to be in each struct type - errProc func(error) db.Error // errProc is the SQL-type specific error processor *bun.DB // DB is the underlying bun.DB connection } -// WrapDBConn @TODO +// WrapDBConn wraps a bun DB connection to provide our own error processing dependent on DB dialect. func WrapDBConn(dbConn *bun.DB) *DBConn { var errProc func(error) db.Error switch dbConn.Dialect().Name() { @@ -36,21 +34,31 @@ func WrapDBConn(dbConn *bun.DB) *DBConn { // RunInTx wraps execution of the supplied transaction function. func (conn *DBConn) RunInTx(ctx context.Context, fn func(bun.Tx) error) db.Error { - // Acquire a new transaction - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - return conn.ProcessError(err) - } + return conn.ProcessError(func() error { + // Acquire a new transaction + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + return err + } - // Perform supplied transaction - if err = fn(tx); err != nil { - tx.Rollback() //nolint - return conn.ProcessError(err) - } + var done bool + + defer func() { + if !done { + _ = tx.Rollback() + } + }() + + // Perform supplied transaction + if err := fn(tx); err != nil { + return err + } - // Finally, commit transaction - err = tx.Commit() - return conn.ProcessError(err) + // Finally, commit + err = tx.Commit() + done = true + return err + }()) } // ProcessError processes an error to replace any known values with our own db.Error types, @@ -83,7 +91,6 @@ func (conn *DBConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, d // NotExists is the functional opposite of conn.Exists() func (conn *DBConn) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) { - // Simply inverse of conn.exists() exists, err := conn.Exists(ctx, query) return !exists, err } |