summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/db.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/db.go')
-rw-r--r--vendor/github.com/uptrace/bun/db.go35
1 files changed, 31 insertions, 4 deletions
diff --git a/vendor/github.com/uptrace/bun/db.go b/vendor/github.com/uptrace/bun/db.go
index 72fe118a9..d42669493 100644
--- a/vendor/github.com/uptrace/bun/db.go
+++ b/vendor/github.com/uptrace/bun/db.go
@@ -356,7 +356,8 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) {
//------------------------------------------------------------------------------
type Tx struct {
- db *DB
+ ctx context.Context
+ db *DB
*sql.Tx
}
@@ -369,11 +370,20 @@ func (db *DB) RunInTx(
if err != nil {
return err
}
- defer tx.Rollback() //nolint:errcheck
+
+ var done bool
+
+ defer func() {
+ if !done {
+ _ = tx.Rollback()
+ }
+ }()
if err := fn(ctx, tx); err != nil {
return err
}
+
+ done = true
return tx.Commit()
}
@@ -382,16 +392,33 @@ func (db *DB) Begin() (Tx, error) {
}
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
+ ctx, event := db.beforeQuery(ctx, nil, "BEGIN", nil, nil)
tx, err := db.DB.BeginTx(ctx, opts)
+ db.afterQuery(ctx, event, nil, err)
if err != nil {
return Tx{}, err
}
return Tx{
- db: db,
- Tx: tx,
+ ctx: ctx,
+ db: db,
+ Tx: tx,
}, nil
}
+func (tx Tx) Commit() error {
+ ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, nil)
+ err := tx.Tx.Commit()
+ tx.db.afterQuery(ctx, event, nil, err)
+ return err
+}
+
+func (tx Tx) Rollback() error {
+ ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, nil)
+ err := tx.Tx.Rollback()
+ tx.db.afterQuery(ctx, event, nil, err)
+ return err
+}
+
func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
return tx.ExecContext(context.TODO(), query, args...)
}