diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/db.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/db.go | 35 |
1 files changed, 31 insertions, 4 deletions
diff --git a/vendor/github.com/uptrace/bun/db.go b/vendor/github.com/uptrace/bun/db.go index 72fe118a9..d42669493 100644 --- a/vendor/github.com/uptrace/bun/db.go +++ b/vendor/github.com/uptrace/bun/db.go @@ -356,7 +356,8 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) { //------------------------------------------------------------------------------ type Tx struct { - db *DB + ctx context.Context + db *DB *sql.Tx } @@ -369,11 +370,20 @@ func (db *DB) RunInTx( if err != nil { return err } - defer tx.Rollback() //nolint:errcheck + + var done bool + + defer func() { + if !done { + _ = tx.Rollback() + } + }() if err := fn(ctx, tx); err != nil { return err } + + done = true return tx.Commit() } @@ -382,16 +392,33 @@ func (db *DB) Begin() (Tx, error) { } func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { + ctx, event := db.beforeQuery(ctx, nil, "BEGIN", nil, nil) tx, err := db.DB.BeginTx(ctx, opts) + db.afterQuery(ctx, event, nil, err) if err != nil { return Tx{}, err } return Tx{ - db: db, - Tx: tx, + ctx: ctx, + db: db, + Tx: tx, }, nil } +func (tx Tx) Commit() error { + ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, nil) + err := tx.Tx.Commit() + tx.db.afterQuery(ctx, event, nil, err) + return err +} + +func (tx Tx) Rollback() error { + ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, nil) + err := tx.Tx.Rollback() + tx.db.afterQuery(ctx, event, nil, err) + return err +} + func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) { return tx.ExecContext(context.TODO(), query, args...) } |