summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_base.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_base.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_base.go99
1 files changed, 70 insertions, 29 deletions
diff --git a/vendor/github.com/uptrace/bun/query_base.go b/vendor/github.com/uptrace/bun/query_base.go
index 9739d93e2..4cf31d04e 100644
--- a/vendor/github.com/uptrace/bun/query_base.go
+++ b/vendor/github.com/uptrace/bun/query_base.go
@@ -14,8 +14,7 @@ import (
)
const (
- wherePKFlag internal.Flag = 1 << iota
- forceDeleteFlag
+ forceDeleteFlag internal.Flag = 1 << iota
deletedFlag
allWithDeletedFlag
)
@@ -580,7 +579,8 @@ func formatterWithModel(fmter schema.Formatter, model schema.NamedArgAppender) s
type whereBaseQuery struct {
baseQuery
- where []schema.QueryWithSep
+ where []schema.QueryWithSep
+ whereFields []*schema.Field
}
func (q *whereBaseQuery) addWhere(where schema.QueryWithSep) {
@@ -601,10 +601,46 @@ func (q *whereBaseQuery) addWhereGroup(sep string, where []schema.QueryWithSep)
q.addWhere(schema.SafeQueryWithSep("", nil, ")"))
}
+func (q *whereBaseQuery) addWhereCols(cols []string) {
+ if q.table == nil {
+ err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model)
+ q.setErr(err)
+ return
+ }
+
+ var fields []*schema.Field
+
+ if cols == nil {
+ if err := q.table.CheckPKs(); err != nil {
+ q.setErr(err)
+ return
+ }
+ fields = q.table.PKs
+ } else {
+ fields = make([]*schema.Field, len(cols))
+ for i, col := range cols {
+ field, err := q.table.Field(col)
+ if err != nil {
+ q.setErr(err)
+ return
+ }
+ fields[i] = field
+ }
+ }
+
+ if q.whereFields != nil {
+ err := errors.New("bun: WherePK can only be called once")
+ q.setErr(err)
+ return
+ }
+
+ q.whereFields = fields
+}
+
func (q *whereBaseQuery) mustAppendWhere(
fmter schema.Formatter, b []byte, withAlias bool,
) ([]byte, error) {
- if len(q.where) == 0 && !q.flags.Has(wherePKFlag) {
+ if len(q.where) == 0 && q.whereFields == nil {
err := errors.New("bun: Update and Delete queries require at least one Where")
return nil, err
}
@@ -614,7 +650,7 @@ func (q *whereBaseQuery) mustAppendWhere(
func (q *whereBaseQuery) appendWhere(
fmter schema.Formatter, b []byte, withAlias bool,
) (_ []byte, err error) {
- if len(q.where) == 0 && !q.isSoftDelete() && !q.flags.Has(wherePKFlag) {
+ if len(q.where) == 0 && q.whereFields == nil && !q.isSoftDelete() {
return b, nil
}
@@ -656,11 +692,11 @@ func (q *whereBaseQuery) appendWhere(
}
}
- if q.flags.Has(wherePKFlag) {
+ if q.whereFields != nil {
if len(b) > startLen {
b = append(b, " AND "...)
}
- b, err = q.appendWherePK(fmter, b, withAlias)
+ b, err = q.appendWhereFields(fmter, b, q.whereFields, withAlias)
if err != nil {
return nil, err
}
@@ -691,29 +727,30 @@ func appendWhere(
return b, nil
}
-func (q *whereBaseQuery) appendWherePK(
- fmter schema.Formatter, b []byte, withAlias bool,
+func (q *whereBaseQuery) appendWhereFields(
+ fmter schema.Formatter, b []byte, fields []*schema.Field, withAlias bool,
) (_ []byte, err error) {
if q.table == nil {
- err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model)
- return nil, err
- }
- if err := q.table.CheckPKs(); err != nil {
+ err := fmt.Errorf("bun: got %T, but WherePK requires struct or slice-based model", q.model)
return nil, err
}
switch model := q.tableModel.(type) {
case *structTableModel:
- return q.appendWherePKStruct(fmter, b, model, withAlias)
+ return q.appendWhereStructFields(fmter, b, model, fields, withAlias)
case *sliceTableModel:
- return q.appendWherePKSlice(fmter, b, model, withAlias)
+ return q.appendWhereSliceFields(fmter, b, model, fields, withAlias)
+ default:
+ return nil, fmt.Errorf("bun: WhereColumn does not support %T", q.tableModel)
}
-
- return nil, fmt.Errorf("bun: WherePK does not support %T", q.tableModel)
}
-func (q *whereBaseQuery) appendWherePKStruct(
- fmter schema.Formatter, b []byte, model *structTableModel, withAlias bool,
+func (q *whereBaseQuery) appendWhereStructFields(
+ fmter schema.Formatter,
+ b []byte,
+ model *structTableModel,
+ fields []*schema.Field,
+ withAlias bool,
) (_ []byte, err error) {
if !model.strct.IsValid() {
return nil, errNilModel
@@ -721,7 +758,7 @@ func (q *whereBaseQuery) appendWherePKStruct(
isTemplate := fmter.IsNop()
b = append(b, '(')
- for i, f := range q.table.PKs {
+ for i, f := range fields {
if i > 0 {
b = append(b, " AND "...)
}
@@ -741,18 +778,22 @@ func (q *whereBaseQuery) appendWherePKStruct(
return b, nil
}
-func (q *whereBaseQuery) appendWherePKSlice(
- fmter schema.Formatter, b []byte, model *sliceTableModel, withAlias bool,
+func (q *whereBaseQuery) appendWhereSliceFields(
+ fmter schema.Formatter,
+ b []byte,
+ model *sliceTableModel,
+ fields []*schema.Field,
+ withAlias bool,
) (_ []byte, err error) {
- if len(q.table.PKs) > 1 {
+ if len(fields) > 1 {
b = append(b, '(')
}
if withAlias {
- b = appendColumns(b, q.table.SQLAlias, q.table.PKs)
+ b = appendColumns(b, q.table.SQLAlias, fields)
} else {
- b = appendColumns(b, "", q.table.PKs)
+ b = appendColumns(b, "", fields)
}
- if len(q.table.PKs) > 1 {
+ if len(fields) > 1 {
b = append(b, ')')
}
@@ -771,10 +812,10 @@ func (q *whereBaseQuery) appendWherePKSlice(
el := indirect(slice.Index(i))
- if len(q.table.PKs) > 1 {
+ if len(fields) > 1 {
b = append(b, '(')
}
- for i, f := range q.table.PKs {
+ for i, f := range fields {
if i > 0 {
b = append(b, ", "...)
}
@@ -784,7 +825,7 @@ func (q *whereBaseQuery) appendWherePKSlice(
b = f.AppendValue(fmter, b, el)
}
}
- if len(q.table.PKs) > 1 {
+ if len(fields) > 1 {
b = append(b, ')')
}
}