diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/db.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/db.go | 62 |
1 files changed, 40 insertions, 22 deletions
diff --git a/vendor/github.com/uptrace/bun/db.go b/vendor/github.com/uptrace/bun/db.go index a83b07d35..78969c019 100644 --- a/vendor/github.com/uptrace/bun/db.go +++ b/vendor/github.com/uptrace/bun/db.go @@ -32,6 +32,7 @@ func WithDiscardUnknownColumns() DBOption { type DB struct { *sql.DB + dialect schema.Dialect features feature.Feature @@ -125,7 +126,7 @@ func (db *DB) NewDropColumn() *DropColumnQuery { func (db *DB) ResetModel(ctx context.Context, models ...interface{}) error { for _, model := range models { - if _, err := db.NewDropTable().Model(model).IfExists().Exec(ctx); err != nil { + if _, err := db.NewDropTable().Model(model).IfExists().Cascade().Exec(ctx); err != nil { return err } if _, err := db.NewCreateTable().Model(model).Exec(ctx); err != nil { @@ -226,8 +227,9 @@ func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) { func (db *DB) ExecContext( ctx context.Context, query string, args ...interface{}, ) (sql.Result, error) { - ctx, event := db.beforeQuery(ctx, nil, query, args, nil) - res, err := db.DB.ExecContext(ctx, db.format(query, args)) + formattedQuery := db.format(query, args) + ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + res, err := db.DB.ExecContext(ctx, formattedQuery) db.afterQuery(ctx, event, res, err) return res, err } @@ -239,8 +241,9 @@ func (db *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { func (db *DB) QueryContext( ctx context.Context, query string, args ...interface{}, ) (*sql.Rows, error) { - ctx, event := db.beforeQuery(ctx, nil, query, args, nil) - rows, err := db.DB.QueryContext(ctx, db.format(query, args)) + formattedQuery := db.format(query, args) + ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + rows, err := db.DB.QueryContext(ctx, formattedQuery) db.afterQuery(ctx, event, nil, err) return rows, err } @@ -250,8 +253,9 @@ func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row { } func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - ctx, event := db.beforeQuery(ctx, nil, query, args, nil) - row := db.DB.QueryRowContext(ctx, db.format(query, args)) + formattedQuery := db.format(query, args) + ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + row := db.DB.QueryRowContext(ctx, formattedQuery) db.afterQuery(ctx, event, nil, row.Err()) return row } @@ -281,8 +285,9 @@ func (db *DB) Conn(ctx context.Context) (Conn, error) { func (c Conn) ExecContext( ctx context.Context, query string, args ...interface{}, ) (sql.Result, error) { - ctx, event := c.db.beforeQuery(ctx, nil, query, args, nil) - res, err := c.Conn.ExecContext(ctx, c.db.format(query, args)) + formattedQuery := c.db.format(query, args) + ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + res, err := c.Conn.ExecContext(ctx, formattedQuery) c.db.afterQuery(ctx, event, res, err) return res, err } @@ -290,19 +295,25 @@ func (c Conn) ExecContext( func (c Conn) QueryContext( ctx context.Context, query string, args ...interface{}, ) (*sql.Rows, error) { - ctx, event := c.db.beforeQuery(ctx, nil, query, args, nil) - rows, err := c.Conn.QueryContext(ctx, c.db.format(query, args)) + formattedQuery := c.db.format(query, args) + ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + rows, err := c.Conn.QueryContext(ctx, formattedQuery) c.db.afterQuery(ctx, event, nil, err) return rows, err } func (c Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - ctx, event := c.db.beforeQuery(ctx, nil, query, args, nil) - row := c.Conn.QueryRowContext(ctx, c.db.format(query, args)) + formattedQuery := c.db.format(query, args) + ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + row := c.Conn.QueryRowContext(ctx, formattedQuery) c.db.afterQuery(ctx, event, nil, row.Err()) return row } +func (c Conn) Dialect() schema.Dialect { + return c.db.Dialect() +} + func (c Conn) NewValues(model interface{}) *ValuesQuery { return NewValuesQuery(c.db, model).Conn(c) } @@ -408,7 +419,7 @@ func (db *DB) Begin() (Tx, error) { } func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { - ctx, event := db.beforeQuery(ctx, nil, "BEGIN", nil, nil) + ctx, event := db.beforeQuery(ctx, nil, "BEGIN", nil, "BEGIN", nil) tx, err := db.DB.BeginTx(ctx, opts) db.afterQuery(ctx, event, nil, err) if err != nil { @@ -422,14 +433,14 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { } func (tx Tx) Commit() error { - ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, nil) + ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, "COMMIT", nil) err := tx.Tx.Commit() tx.db.afterQuery(ctx, event, nil, err) return err } func (tx Tx) Rollback() error { - ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, nil) + ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, "ROLLBACK", nil) err := tx.Tx.Rollback() tx.db.afterQuery(ctx, event, nil, err) return err @@ -442,8 +453,9 @@ func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) { func (tx Tx) ExecContext( ctx context.Context, query string, args ...interface{}, ) (sql.Result, error) { - ctx, event := tx.db.beforeQuery(ctx, nil, query, args, nil) - res, err := tx.Tx.ExecContext(ctx, tx.db.format(query, args)) + formattedQuery := tx.db.format(query, args) + ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + res, err := tx.Tx.ExecContext(ctx, formattedQuery) tx.db.afterQuery(ctx, event, res, err) return res, err } @@ -455,8 +467,9 @@ func (tx Tx) Query(query string, args ...interface{}) (*sql.Rows, error) { func (tx Tx) QueryContext( ctx context.Context, query string, args ...interface{}, ) (*sql.Rows, error) { - ctx, event := tx.db.beforeQuery(ctx, nil, query, args, nil) - rows, err := tx.Tx.QueryContext(ctx, tx.db.format(query, args)) + formattedQuery := tx.db.format(query, args) + ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + rows, err := tx.Tx.QueryContext(ctx, formattedQuery) tx.db.afterQuery(ctx, event, nil, err) return rows, err } @@ -466,14 +479,19 @@ func (tx Tx) QueryRow(query string, args ...interface{}) *sql.Row { } func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - ctx, event := tx.db.beforeQuery(ctx, nil, query, args, nil) - row := tx.Tx.QueryRowContext(ctx, tx.db.format(query, args)) + formattedQuery := tx.db.format(query, args) + ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) + row := tx.Tx.QueryRowContext(ctx, formattedQuery) tx.db.afterQuery(ctx, event, nil, row.Err()) return row } //------------------------------------------------------------------------------ +func (tx Tx) Dialect() schema.Dialect { + return tx.db.Dialect() +} + func (tx Tx) NewValues(model interface{}) *ValuesQuery { return NewValuesQuery(tx.db, model).Conn(tx) } |