diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/db.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/db.go | 138 |
1 files changed, 137 insertions, 1 deletions
diff --git a/vendor/github.com/uptrace/bun/db.go b/vendor/github.com/uptrace/bun/db.go index 78969c019..47e654655 100644 --- a/vendor/github.com/uptrace/bun/db.go +++ b/vendor/github.com/uptrace/bun/db.go @@ -2,7 +2,9 @@ package bun import ( "context" + "crypto/rand" "database/sql" + "encoding/hex" "fmt" "reflect" "strings" @@ -141,13 +143,19 @@ func (db *DB) Dialect() schema.Dialect { } func (db *DB) ScanRows(ctx context.Context, rows *sql.Rows, dest ...interface{}) error { + defer rows.Close() + model, err := newModel(db, dest) if err != nil { return err } _, err = model.ScanRows(ctx, rows) - return err + if err != nil { + return err + } + + return rows.Err() } func (db *DB) ScanRow(ctx context.Context, rows *sql.Rows, dest ...interface{}) error { @@ -362,6 +370,46 @@ func (c Conn) NewDropColumn() *DropColumnQuery { return NewDropColumnQuery(c.db).Conn(c) } +// RunInTx runs the function in a transaction. If the function returns an error, +// the transaction is rolled back. Otherwise, the transaction is committed. +func (c Conn) RunInTx( + ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx Tx) error, +) error { + tx, err := c.BeginTx(ctx, opts) + if err != nil { + return err + } + + var done bool + + defer func() { + if !done { + _ = tx.Rollback() + } + }() + + if err := fn(ctx, tx); err != nil { + return err + } + + done = true + return tx.Commit() +} + +func (c Conn) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { + ctx, event := c.db.beforeQuery(ctx, nil, "BEGIN", nil, "BEGIN", nil) + tx, err := c.Conn.BeginTx(ctx, opts) + c.db.afterQuery(ctx, event, nil, err) + if err != nil { + return Tx{}, err + } + return Tx{ + ctx: ctx, + db: c.db, + Tx: tx, + }, nil +} + //------------------------------------------------------------------------------ type Stmt struct { @@ -385,6 +433,8 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) { type Tx struct { ctx context.Context db *DB + // name is the name of a savepoint + name string *sql.Tx } @@ -433,19 +483,51 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { } func (tx Tx) Commit() error { + if tx.name == "" { + return tx.commitTX() + } + return tx.commitSP() +} + +func (tx Tx) commitTX() error { ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, "COMMIT", nil) err := tx.Tx.Commit() tx.db.afterQuery(ctx, event, nil, err) return err } +func (tx Tx) commitSP() error { + if tx.Dialect().Features().Has(feature.MSSavepoint) { + return nil + } + query := "RELEASE SAVEPOINT " + tx.name + _, err := tx.ExecContext(tx.ctx, query) + return err +} + func (tx Tx) Rollback() error { + if tx.name == "" { + return tx.rollbackTX() + } + return tx.rollbackSP() +} + +func (tx Tx) rollbackTX() error { ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, "ROLLBACK", nil) err := tx.Tx.Rollback() tx.db.afterQuery(ctx, event, nil, err) return err } +func (tx Tx) rollbackSP() error { + query := "ROLLBACK TO SAVEPOINT " + tx.name + if tx.Dialect().Features().Has(feature.MSSavepoint) { + query = "ROLLBACK TRANSACTION " + tx.name + } + _, err := tx.ExecContext(tx.ctx, query) + return err +} + func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) { return tx.ExecContext(context.TODO(), query, args...) } @@ -488,6 +570,60 @@ func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interfac //------------------------------------------------------------------------------ +func (tx Tx) Begin() (Tx, error) { + return tx.BeginTx(tx.ctx, nil) +} + +// BeginTx will save a point in the running transaction. +func (tx Tx) BeginTx(ctx context.Context, _ *sql.TxOptions) (Tx, error) { + // mssql savepoint names are limited to 32 characters + sp := make([]byte, 14) + _, err := rand.Read(sp) + if err != nil { + return Tx{}, err + } + + qName := "SP_" + hex.EncodeToString(sp) + query := "SAVEPOINT " + qName + if tx.Dialect().Features().Has(feature.MSSavepoint) { + query = "SAVE TRANSACTION " + qName + } + _, err = tx.ExecContext(ctx, query) + if err != nil { + return Tx{}, err + } + return Tx{ + ctx: ctx, + db: tx.db, + Tx: tx.Tx, + name: qName, + }, nil +} + +func (tx Tx) RunInTx( + ctx context.Context, _ *sql.TxOptions, fn func(ctx context.Context, tx Tx) error, +) error { + sp, err := tx.BeginTx(ctx, nil) + if err != nil { + return err + } + + var done bool + + defer func() { + if !done { + _ = sp.Rollback() + } + }() + + if err := fn(ctx, sp); err != nil { + return err + } + + done = true + return sp.Commit() +} + func (tx Tx) Dialect() schema.Dialect { return tx.db.Dialect() } |