summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/db.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/db.go')
-rw-r--r--vendor/github.com/uptrace/bun/db.go62
1 files changed, 40 insertions, 22 deletions
diff --git a/vendor/github.com/uptrace/bun/db.go b/vendor/github.com/uptrace/bun/db.go
index a83b07d35..78969c019 100644
--- a/vendor/github.com/uptrace/bun/db.go
+++ b/vendor/github.com/uptrace/bun/db.go
@@ -32,6 +32,7 @@ func WithDiscardUnknownColumns() DBOption {
type DB struct {
*sql.DB
+
dialect schema.Dialect
features feature.Feature
@@ -125,7 +126,7 @@ func (db *DB) NewDropColumn() *DropColumnQuery {
func (db *DB) ResetModel(ctx context.Context, models ...interface{}) error {
for _, model := range models {
- if _, err := db.NewDropTable().Model(model).IfExists().Exec(ctx); err != nil {
+ if _, err := db.NewDropTable().Model(model).IfExists().Cascade().Exec(ctx); err != nil {
return err
}
if _, err := db.NewCreateTable().Model(model).Exec(ctx); err != nil {
@@ -226,8 +227,9 @@ func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
func (db *DB) ExecContext(
ctx context.Context, query string, args ...interface{},
) (sql.Result, error) {
- ctx, event := db.beforeQuery(ctx, nil, query, args, nil)
- res, err := db.DB.ExecContext(ctx, db.format(query, args))
+ formattedQuery := db.format(query, args)
+ ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
+ res, err := db.DB.ExecContext(ctx, formattedQuery)
db.afterQuery(ctx, event, res, err)
return res, err
}
@@ -239,8 +241,9 @@ func (db *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
func (db *DB) QueryContext(
ctx context.Context, query string, args ...interface{},
) (*sql.Rows, error) {
- ctx, event := db.beforeQuery(ctx, nil, query, args, nil)
- rows, err := db.DB.QueryContext(ctx, db.format(query, args))
+ formattedQuery := db.format(query, args)
+ ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
+ rows, err := db.DB.QueryContext(ctx, formattedQuery)
db.afterQuery(ctx, event, nil, err)
return rows, err
}
@@ -250,8 +253,9 @@ func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row {
}
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
- ctx, event := db.beforeQuery(ctx, nil, query, args, nil)
- row := db.DB.QueryRowContext(ctx, db.format(query, args))
+ formattedQuery := db.format(query, args)
+ ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
+ row := db.DB.QueryRowContext(ctx, formattedQuery)
db.afterQuery(ctx, event, nil, row.Err())
return row
}
@@ -281,8 +285,9 @@ func (db *DB) Conn(ctx context.Context) (Conn, error) {
func (c Conn) ExecContext(
ctx context.Context, query string, args ...interface{},
) (sql.Result, error) {
- ctx, event := c.db.beforeQuery(ctx, nil, query, args, nil)
- res, err := c.Conn.ExecContext(ctx, c.db.format(query, args))
+ formattedQuery := c.db.format(query, args)
+ ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
+ res, err := c.Conn.ExecContext(ctx, formattedQuery)
c.db.afterQuery(ctx, event, res, err)
return res, err
}
@@ -290,19 +295,25 @@ func (c Conn) ExecContext(
func (c Conn) QueryContext(
ctx context.Context, query string, args ...interface{},
) (*sql.Rows, error) {
- ctx, event := c.db.beforeQuery(ctx, nil, query, args, nil)
- rows, err := c.Conn.QueryContext(ctx, c.db.format(query, args))
+ formattedQuery := c.db.format(query, args)
+ ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
+ rows, err := c.Conn.QueryContext(ctx, formattedQuery)
c.db.afterQuery(ctx, event, nil, err)
return rows, err
}
func (c Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
- ctx, event := c.db.beforeQuery(ctx, nil, query, args, nil)
- row := c.Conn.QueryRowContext(ctx, c.db.format(query, args))
+ formattedQuery := c.db.format(query, args)
+ ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
+ row := c.Conn.QueryRowContext(ctx, formattedQuery)
c.db.afterQuery(ctx, event, nil, row.Err())
return row
}
+func (c Conn) Dialect() schema.Dialect {
+ return c.db.Dialect()
+}
+
func (c Conn) NewValues(model interface{}) *ValuesQuery {
return NewValuesQuery(c.db, model).Conn(c)
}
@@ -408,7 +419,7 @@ func (db *DB) Begin() (Tx, error) {
}
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
- ctx, event := db.beforeQuery(ctx, nil, "BEGIN", nil, nil)
+ ctx, event := db.beforeQuery(ctx, nil, "BEGIN", nil, "BEGIN", nil)
tx, err := db.DB.BeginTx(ctx, opts)
db.afterQuery(ctx, event, nil, err)
if err != nil {
@@ -422,14 +433,14 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
}
func (tx Tx) Commit() error {
- ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, nil)
+ ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, "COMMIT", nil)
err := tx.Tx.Commit()
tx.db.afterQuery(ctx, event, nil, err)
return err
}
func (tx Tx) Rollback() error {
- ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, nil)
+ ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, "ROLLBACK", nil)
err := tx.Tx.Rollback()
tx.db.afterQuery(ctx, event, nil, err)
return err
@@ -442,8 +453,9 @@ func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
func (tx Tx) ExecContext(
ctx context.Context, query string, args ...interface{},
) (sql.Result, error) {
- ctx, event := tx.db.beforeQuery(ctx, nil, query, args, nil)
- res, err := tx.Tx.ExecContext(ctx, tx.db.format(query, args))
+ formattedQuery := tx.db.format(query, args)
+ ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
+ res, err := tx.Tx.ExecContext(ctx, formattedQuery)
tx.db.afterQuery(ctx, event, res, err)
return res, err
}
@@ -455,8 +467,9 @@ func (tx Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
func (tx Tx) QueryContext(
ctx context.Context, query string, args ...interface{},
) (*sql.Rows, error) {
- ctx, event := tx.db.beforeQuery(ctx, nil, query, args, nil)
- rows, err := tx.Tx.QueryContext(ctx, tx.db.format(query, args))
+ formattedQuery := tx.db.format(query, args)
+ ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
+ rows, err := tx.Tx.QueryContext(ctx, formattedQuery)
tx.db.afterQuery(ctx, event, nil, err)
return rows, err
}
@@ -466,14 +479,19 @@ func (tx Tx) QueryRow(query string, args ...interface{}) *sql.Row {
}
func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
- ctx, event := tx.db.beforeQuery(ctx, nil, query, args, nil)
- row := tx.Tx.QueryRowContext(ctx, tx.db.format(query, args))
+ formattedQuery := tx.db.format(query, args)
+ ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
+ row := tx.Tx.QueryRowContext(ctx, formattedQuery)
tx.db.afterQuery(ctx, event, nil, row.Err())
return row
}
//------------------------------------------------------------------------------
+func (tx Tx) Dialect() schema.Dialect {
+ return tx.db.Dialect()
+}
+
func (tx Tx) NewValues(model interface{}) *ValuesQuery {
return NewValuesQuery(tx.db, model).Conn(tx)
}