summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/db.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/db.go')
-rw-r--r--vendor/github.com/uptrace/bun/db.go138
1 files changed, 137 insertions, 1 deletions
diff --git a/vendor/github.com/uptrace/bun/db.go b/vendor/github.com/uptrace/bun/db.go
index 78969c019..47e654655 100644
--- a/vendor/github.com/uptrace/bun/db.go
+++ b/vendor/github.com/uptrace/bun/db.go
@@ -2,7 +2,9 @@ package bun
import (
"context"
+ "crypto/rand"
"database/sql"
+ "encoding/hex"
"fmt"
"reflect"
"strings"
@@ -141,13 +143,19 @@ func (db *DB) Dialect() schema.Dialect {
}
func (db *DB) ScanRows(ctx context.Context, rows *sql.Rows, dest ...interface{}) error {
+ defer rows.Close()
+
model, err := newModel(db, dest)
if err != nil {
return err
}
_, err = model.ScanRows(ctx, rows)
- return err
+ if err != nil {
+ return err
+ }
+
+ return rows.Err()
}
func (db *DB) ScanRow(ctx context.Context, rows *sql.Rows, dest ...interface{}) error {
@@ -362,6 +370,46 @@ func (c Conn) NewDropColumn() *DropColumnQuery {
return NewDropColumnQuery(c.db).Conn(c)
}
+// RunInTx runs the function in a transaction. If the function returns an error,
+// the transaction is rolled back. Otherwise, the transaction is committed.
+func (c Conn) RunInTx(
+ ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx Tx) error,
+) error {
+ tx, err := c.BeginTx(ctx, opts)
+ if err != nil {
+ return err
+ }
+
+ var done bool
+
+ defer func() {
+ if !done {
+ _ = tx.Rollback()
+ }
+ }()
+
+ if err := fn(ctx, tx); err != nil {
+ return err
+ }
+
+ done = true
+ return tx.Commit()
+}
+
+func (c Conn) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
+ ctx, event := c.db.beforeQuery(ctx, nil, "BEGIN", nil, "BEGIN", nil)
+ tx, err := c.Conn.BeginTx(ctx, opts)
+ c.db.afterQuery(ctx, event, nil, err)
+ if err != nil {
+ return Tx{}, err
+ }
+ return Tx{
+ ctx: ctx,
+ db: c.db,
+ Tx: tx,
+ }, nil
+}
+
//------------------------------------------------------------------------------
type Stmt struct {
@@ -385,6 +433,8 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) {
type Tx struct {
ctx context.Context
db *DB
+ // name is the name of a savepoint
+ name string
*sql.Tx
}
@@ -433,19 +483,51 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
}
func (tx Tx) Commit() error {
+ if tx.name == "" {
+ return tx.commitTX()
+ }
+ return tx.commitSP()
+}
+
+func (tx Tx) commitTX() error {
ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, "COMMIT", nil)
err := tx.Tx.Commit()
tx.db.afterQuery(ctx, event, nil, err)
return err
}
+func (tx Tx) commitSP() error {
+ if tx.Dialect().Features().Has(feature.MSSavepoint) {
+ return nil
+ }
+ query := "RELEASE SAVEPOINT " + tx.name
+ _, err := tx.ExecContext(tx.ctx, query)
+ return err
+}
+
func (tx Tx) Rollback() error {
+ if tx.name == "" {
+ return tx.rollbackTX()
+ }
+ return tx.rollbackSP()
+}
+
+func (tx Tx) rollbackTX() error {
ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, "ROLLBACK", nil)
err := tx.Tx.Rollback()
tx.db.afterQuery(ctx, event, nil, err)
return err
}
+func (tx Tx) rollbackSP() error {
+ query := "ROLLBACK TO SAVEPOINT " + tx.name
+ if tx.Dialect().Features().Has(feature.MSSavepoint) {
+ query = "ROLLBACK TRANSACTION " + tx.name
+ }
+ _, err := tx.ExecContext(tx.ctx, query)
+ return err
+}
+
func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
return tx.ExecContext(context.TODO(), query, args...)
}
@@ -488,6 +570,60 @@ func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interfac
//------------------------------------------------------------------------------
+func (tx Tx) Begin() (Tx, error) {
+ return tx.BeginTx(tx.ctx, nil)
+}
+
+// BeginTx will save a point in the running transaction.
+func (tx Tx) BeginTx(ctx context.Context, _ *sql.TxOptions) (Tx, error) {
+ // mssql savepoint names are limited to 32 characters
+ sp := make([]byte, 14)
+ _, err := rand.Read(sp)
+ if err != nil {
+ return Tx{}, err
+ }
+
+ qName := "SP_" + hex.EncodeToString(sp)
+ query := "SAVEPOINT " + qName
+ if tx.Dialect().Features().Has(feature.MSSavepoint) {
+ query = "SAVE TRANSACTION " + qName
+ }
+ _, err = tx.ExecContext(ctx, query)
+ if err != nil {
+ return Tx{}, err
+ }
+ return Tx{
+ ctx: ctx,
+ db: tx.db,
+ Tx: tx.Tx,
+ name: qName,
+ }, nil
+}
+
+func (tx Tx) RunInTx(
+ ctx context.Context, _ *sql.TxOptions, fn func(ctx context.Context, tx Tx) error,
+) error {
+ sp, err := tx.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+
+ var done bool
+
+ defer func() {
+ if !done {
+ _ = sp.Rollback()
+ }
+ }()
+
+ if err := fn(ctx, sp); err != nil {
+ return err
+ }
+
+ done = true
+ return sp.Commit()
+}
+
func (tx Tx) Dialect() schema.Dialect {
return tx.db.Dialect()
}