diff options
author | 2021-11-13 12:29:08 +0100 | |
---|---|---|
committer | 2021-11-13 12:29:08 +0100 | |
commit | 829a934d23ab221049b4d54926305d8d5d64c9ad (patch) | |
tree | f4e382b289c113d3ba8a3c7a183507a5609c46c0 /vendor/github.com/uptrace/bun/query_base.go | |
parent | smtp + email confirmation (#285) (diff) | |
download | gotosocial-829a934d23ab221049b4d54926305d8d5d64c9ad.tar.xz |
update dependencies (#296)
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_base.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/query_base.go | 99 |
1 files changed, 70 insertions, 29 deletions
diff --git a/vendor/github.com/uptrace/bun/query_base.go b/vendor/github.com/uptrace/bun/query_base.go index 9739d93e2..4cf31d04e 100644 --- a/vendor/github.com/uptrace/bun/query_base.go +++ b/vendor/github.com/uptrace/bun/query_base.go @@ -14,8 +14,7 @@ import ( ) const ( - wherePKFlag internal.Flag = 1 << iota - forceDeleteFlag + forceDeleteFlag internal.Flag = 1 << iota deletedFlag allWithDeletedFlag ) @@ -580,7 +579,8 @@ func formatterWithModel(fmter schema.Formatter, model schema.NamedArgAppender) s type whereBaseQuery struct { baseQuery - where []schema.QueryWithSep + where []schema.QueryWithSep + whereFields []*schema.Field } func (q *whereBaseQuery) addWhere(where schema.QueryWithSep) { @@ -601,10 +601,46 @@ func (q *whereBaseQuery) addWhereGroup(sep string, where []schema.QueryWithSep) q.addWhere(schema.SafeQueryWithSep("", nil, ")")) } +func (q *whereBaseQuery) addWhereCols(cols []string) { + if q.table == nil { + err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model) + q.setErr(err) + return + } + + var fields []*schema.Field + + if cols == nil { + if err := q.table.CheckPKs(); err != nil { + q.setErr(err) + return + } + fields = q.table.PKs + } else { + fields = make([]*schema.Field, len(cols)) + for i, col := range cols { + field, err := q.table.Field(col) + if err != nil { + q.setErr(err) + return + } + fields[i] = field + } + } + + if q.whereFields != nil { + err := errors.New("bun: WherePK can only be called once") + q.setErr(err) + return + } + + q.whereFields = fields +} + func (q *whereBaseQuery) mustAppendWhere( fmter schema.Formatter, b []byte, withAlias bool, ) ([]byte, error) { - if len(q.where) == 0 && !q.flags.Has(wherePKFlag) { + if len(q.where) == 0 && q.whereFields == nil { err := errors.New("bun: Update and Delete queries require at least one Where") return nil, err } @@ -614,7 +650,7 @@ func (q *whereBaseQuery) mustAppendWhere( func (q *whereBaseQuery) appendWhere( fmter schema.Formatter, b []byte, withAlias bool, ) (_ []byte, err error) { - if len(q.where) == 0 && !q.isSoftDelete() && !q.flags.Has(wherePKFlag) { + if len(q.where) == 0 && q.whereFields == nil && !q.isSoftDelete() { return b, nil } @@ -656,11 +692,11 @@ func (q *whereBaseQuery) appendWhere( } } - if q.flags.Has(wherePKFlag) { + if q.whereFields != nil { if len(b) > startLen { b = append(b, " AND "...) } - b, err = q.appendWherePK(fmter, b, withAlias) + b, err = q.appendWhereFields(fmter, b, q.whereFields, withAlias) if err != nil { return nil, err } @@ -691,29 +727,30 @@ func appendWhere( return b, nil } -func (q *whereBaseQuery) appendWherePK( - fmter schema.Formatter, b []byte, withAlias bool, +func (q *whereBaseQuery) appendWhereFields( + fmter schema.Formatter, b []byte, fields []*schema.Field, withAlias bool, ) (_ []byte, err error) { if q.table == nil { - err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model) - return nil, err - } - if err := q.table.CheckPKs(); err != nil { + err := fmt.Errorf("bun: got %T, but WherePK requires struct or slice-based model", q.model) return nil, err } switch model := q.tableModel.(type) { case *structTableModel: - return q.appendWherePKStruct(fmter, b, model, withAlias) + return q.appendWhereStructFields(fmter, b, model, fields, withAlias) case *sliceTableModel: - return q.appendWherePKSlice(fmter, b, model, withAlias) + return q.appendWhereSliceFields(fmter, b, model, fields, withAlias) + default: + return nil, fmt.Errorf("bun: WhereColumn does not support %T", q.tableModel) } - - return nil, fmt.Errorf("bun: WherePK does not support %T", q.tableModel) } -func (q *whereBaseQuery) appendWherePKStruct( - fmter schema.Formatter, b []byte, model *structTableModel, withAlias bool, +func (q *whereBaseQuery) appendWhereStructFields( + fmter schema.Formatter, + b []byte, + model *structTableModel, + fields []*schema.Field, + withAlias bool, ) (_ []byte, err error) { if !model.strct.IsValid() { return nil, errNilModel @@ -721,7 +758,7 @@ func (q *whereBaseQuery) appendWherePKStruct( isTemplate := fmter.IsNop() b = append(b, '(') - for i, f := range q.table.PKs { + for i, f := range fields { if i > 0 { b = append(b, " AND "...) } @@ -741,18 +778,22 @@ func (q *whereBaseQuery) appendWherePKStruct( return b, nil } -func (q *whereBaseQuery) appendWherePKSlice( - fmter schema.Formatter, b []byte, model *sliceTableModel, withAlias bool, +func (q *whereBaseQuery) appendWhereSliceFields( + fmter schema.Formatter, + b []byte, + model *sliceTableModel, + fields []*schema.Field, + withAlias bool, ) (_ []byte, err error) { - if len(q.table.PKs) > 1 { + if len(fields) > 1 { b = append(b, '(') } if withAlias { - b = appendColumns(b, q.table.SQLAlias, q.table.PKs) + b = appendColumns(b, q.table.SQLAlias, fields) } else { - b = appendColumns(b, "", q.table.PKs) + b = appendColumns(b, "", fields) } - if len(q.table.PKs) > 1 { + if len(fields) > 1 { b = append(b, ')') } @@ -771,10 +812,10 @@ func (q *whereBaseQuery) appendWherePKSlice( el := indirect(slice.Index(i)) - if len(q.table.PKs) > 1 { + if len(fields) > 1 { b = append(b, '(') } - for i, f := range q.table.PKs { + for i, f := range fields { if i > 0 { b = append(b, ", "...) } @@ -784,7 +825,7 @@ func (q *whereBaseQuery) appendWherePKSlice( b = f.AppendValue(fmter, b, el) } } - if len(q.table.PKs) > 1 { + if len(fields) > 1 { b = append(b, ')') } } |