summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_base.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_base.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_base.go95
1 files changed, 41 insertions, 54 deletions
diff --git a/vendor/github.com/uptrace/bun/query_base.go b/vendor/github.com/uptrace/bun/query_base.go
index 4cf31d04e..8b78d25e1 100644
--- a/vendor/github.com/uptrace/bun/query_base.go
+++ b/vendor/github.com/uptrace/bun/query_base.go
@@ -43,6 +43,7 @@ var (
// IDB is a common interface for *bun.DB, bun.Conn, and bun.Tx.
type IDB interface {
IConn
+ Dialect() schema.Dialect
NewValues(model interface{}) *ValuesQuery
NewSelect() *SelectQuery
@@ -59,9 +60,9 @@ type IDB interface {
}
var (
- _ IConn = (*DB)(nil)
- _ IConn = (*Conn)(nil)
- _ IConn = (*Tx)(nil)
+ _ IDB = (*DB)(nil)
+ _ IDB = (*Conn)(nil)
+ _ IDB = (*Tx)(nil)
)
type baseQuery struct {
@@ -74,10 +75,10 @@ type baseQuery struct {
tableModel TableModel
table *schema.Table
- with []withQuery
- modelTable schema.QueryWithArgs
- tables []schema.QueryWithArgs
- columns []schema.QueryWithArgs
+ with []withQuery
+ modelTableName schema.QueryWithArgs
+ tables []schema.QueryWithArgs
+ columns []schema.QueryWithArgs
flags internal.Flag
}
@@ -86,13 +87,6 @@ func (q *baseQuery) DB() *DB {
return q.db
}
-type query interface {
- GetModel() Model
- GetTableName() string
-}
-
-var _ query = (*baseQuery)(nil)
-
func (q *baseQuery) GetModel() Model {
return q.model
}
@@ -103,15 +97,16 @@ func (q *baseQuery) GetTableName() string {
}
for _, wq := range q.with {
- if v, ok := wq.query.(query); ok {
+ if v, ok := wq.query.(Query); ok {
if model := v.GetModel(); model != nil {
return v.GetTableName()
}
}
}
- if q.modelTable.Query != "" {
- return q.modelTable.Query
+ if q.modelTableName.Query != "" {
+ b, _ := q.modelTableName.AppendQuery(q.db.fmter, nil)
+ return string(b)
}
if len(q.tables) > 0 {
return q.tables[0].Query
@@ -304,8 +299,8 @@ func (q *baseQuery) _excludeColumn(column string) bool {
//------------------------------------------------------------------------------
func (q *baseQuery) modelHasTableName() bool {
- if !q.modelTable.IsZero() {
- return q.modelTable.Query != ""
+ if !q.modelTableName.IsZero() {
+ return q.modelTableName.Query != ""
}
return q.table != nil
}
@@ -332,8 +327,8 @@ func (q *baseQuery) _appendTables(
startLen := len(b)
if q.modelHasTableName() {
- if !q.modelTable.IsZero() {
- b, err = q.modelTable.AppendQuery(fmter, b)
+ if !q.modelTableName.IsZero() {
+ b, err = q.modelTableName.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
@@ -372,8 +367,8 @@ func (q *baseQuery) appendFirstTableWithAlias(
func (q *baseQuery) _appendFirstTable(
fmter schema.Formatter, b []byte, withAlias bool,
) ([]byte, error) {
- if !q.modelTable.IsZero() {
- return q.modelTable.AppendQuery(fmter, b)
+ if !q.modelTableName.IsZero() {
+ return q.modelTableName.AppendQuery(fmter, b)
}
if q.table != nil {
@@ -473,7 +468,7 @@ func (q *baseQuery) scan(
model Model,
hasDest bool,
) (sql.Result, error) {
- ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, q.model)
+ ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model)
rows, err := q.conn.QueryContext(ctx, query)
if err != nil {
@@ -503,16 +498,10 @@ func (q *baseQuery) exec(
iquery Query,
query string,
) (sql.Result, error) {
- ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, q.model)
-
+ ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model)
res, err := q.conn.ExecContext(ctx, query)
- if err != nil {
- q.db.afterQuery(ctx, event, nil, err)
- return res, err
- }
-
- q.db.afterQuery(ctx, event, res, err)
- return res, nil
+ q.db.afterQuery(ctx, event, nil, err)
+ return res, err
}
//------------------------------------------------------------------------------
@@ -607,34 +596,30 @@ func (q *whereBaseQuery) addWhereCols(cols []string) {
q.setErr(err)
return
}
-
- var fields []*schema.Field
+ if q.whereFields != nil {
+ err := errors.New("bun: WherePK can only be called once")
+ q.setErr(err)
+ return
+ }
if cols == nil {
if err := q.table.CheckPKs(); err != nil {
q.setErr(err)
return
}
- fields = q.table.PKs
- } else {
- fields = make([]*schema.Field, len(cols))
- for i, col := range cols {
- field, err := q.table.Field(col)
- if err != nil {
- q.setErr(err)
- return
- }
- fields[i] = field
- }
- }
-
- if q.whereFields != nil {
- err := errors.New("bun: WherePK can only be called once")
- q.setErr(err)
+ q.whereFields = q.table.PKs
return
}
- q.whereFields = fields
+ q.whereFields = make([]*schema.Field, len(cols))
+ for i, col := range cols {
+ field, err := q.table.Field(col)
+ if err != nil {
+ q.setErr(err)
+ return
+ }
+ q.whereFields[i] = field
+ }
}
func (q *whereBaseQuery) mustAppendWhere(
@@ -951,6 +936,7 @@ func (q setQuery) appendSet(fmter schema.Formatter, b []byte) (_ []byte, err err
//------------------------------------------------------------------------------
type cascadeQuery struct {
+ cascade bool
restrict bool
}
@@ -958,10 +944,11 @@ func (q cascadeQuery) appendCascade(fmter schema.Formatter, b []byte) []byte {
if !fmter.HasFeature(feature.TableCascade) {
return b
}
+ if q.cascade {
+ b = append(b, " CASCADE"...)
+ }
if q.restrict {
b = append(b, " RESTRICT"...)
- } else {
- b = append(b, " CASCADE"...)
}
return b
}