diff options
| author | 2021-11-13 12:29:08 +0100 | |
|---|---|---|
| committer | 2021-11-13 12:29:08 +0100 | |
| commit | 829a934d23ab221049b4d54926305d8d5d64c9ad (patch) | |
| tree | f4e382b289c113d3ba8a3c7a183507a5609c46c0 /vendor/github.com/uptrace/bun/db.go | |
| parent | smtp + email confirmation (#285) (diff) | |
| download | gotosocial-829a934d23ab221049b4d54926305d8d5d64c9ad.tar.xz | |
update dependencies (#296)
Diffstat (limited to 'vendor/github.com/uptrace/bun/db.go')
| -rw-r--r-- | vendor/github.com/uptrace/bun/db.go | 35 |
1 files changed, 31 insertions, 4 deletions
diff --git a/vendor/github.com/uptrace/bun/db.go b/vendor/github.com/uptrace/bun/db.go index 72fe118a9..d42669493 100644 --- a/vendor/github.com/uptrace/bun/db.go +++ b/vendor/github.com/uptrace/bun/db.go @@ -356,7 +356,8 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) { //------------------------------------------------------------------------------ type Tx struct { - db *DB + ctx context.Context + db *DB *sql.Tx } @@ -369,11 +370,20 @@ func (db *DB) RunInTx( if err != nil { return err } - defer tx.Rollback() //nolint:errcheck + + var done bool + + defer func() { + if !done { + _ = tx.Rollback() + } + }() if err := fn(ctx, tx); err != nil { return err } + + done = true return tx.Commit() } @@ -382,16 +392,33 @@ func (db *DB) Begin() (Tx, error) { } func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { + ctx, event := db.beforeQuery(ctx, nil, "BEGIN", nil, nil) tx, err := db.DB.BeginTx(ctx, opts) + db.afterQuery(ctx, event, nil, err) if err != nil { return Tx{}, err } return Tx{ - db: db, - Tx: tx, + ctx: ctx, + db: db, + Tx: tx, }, nil } +func (tx Tx) Commit() error { + ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, nil) + err := tx.Tx.Commit() + tx.db.afterQuery(ctx, event, nil, err) + return err +} + +func (tx Tx) Rollback() error { + ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, nil) + err := tx.Tx.Rollback() + tx.db.afterQuery(ctx, event, nil, err) + return err +} + func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) { return tx.ExecContext(context.TODO(), query, args...) } |
