diff options
Diffstat (limited to 'vendor/github.com/go-pg/pg/v10/orm/update.go')
-rw-r--r-- | vendor/github.com/go-pg/pg/v10/orm/update.go | 378 |
1 files changed, 378 insertions, 0 deletions
diff --git a/vendor/github.com/go-pg/pg/v10/orm/update.go b/vendor/github.com/go-pg/pg/v10/orm/update.go new file mode 100644 index 000000000..ce6396fd3 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/update.go @@ -0,0 +1,378 @@ +package orm + +import ( + "fmt" + "reflect" + "sort" + + "github.com/go-pg/pg/v10/types" +) + +type UpdateQuery struct { + q *Query + omitZero bool + placeholder bool +} + +var ( + _ QueryAppender = (*UpdateQuery)(nil) + _ QueryCommand = (*UpdateQuery)(nil) +) + +func NewUpdateQuery(q *Query, omitZero bool) *UpdateQuery { + return &UpdateQuery{ + q: q, + omitZero: omitZero, + } +} + +func (q *UpdateQuery) String() string { + b, err := q.AppendQuery(defaultFmter, nil) + if err != nil { + panic(err) + } + return string(b) +} + +func (q *UpdateQuery) Operation() QueryOp { + return UpdateOp +} + +func (q *UpdateQuery) Clone() QueryCommand { + return &UpdateQuery{ + q: q.q.Clone(), + omitZero: q.omitZero, + placeholder: q.placeholder, + } +} + +func (q *UpdateQuery) Query() *Query { + return q.q +} + +func (q *UpdateQuery) AppendTemplate(b []byte) ([]byte, error) { + cp := q.Clone().(*UpdateQuery) + cp.placeholder = true + return cp.AppendQuery(dummyFormatter{}, b) +} + +func (q *UpdateQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if q.q.stickyErr != nil { + return nil, q.q.stickyErr + } + + if len(q.q.with) > 0 { + b, err = q.q.appendWith(fmter, b) + if err != nil { + return nil, err + } + } + + b = append(b, "UPDATE "...) + + b, err = q.q.appendFirstTableWithAlias(fmter, b) + if err != nil { + return nil, err + } + + b, err = q.mustAppendSet(fmter, b) + if err != nil { + return nil, err + } + + isSliceModelWithData := q.q.isSliceModelWithData() + if isSliceModelWithData || q.q.hasMultiTables() { + b = append(b, " FROM "...) + b, err = q.q.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + + if isSliceModelWithData { + b, err = q.appendSliceModelData(fmter, b) + if err != nil { + return nil, err + } + } + } + + b, err = q.mustAppendWhere(fmter, b, isSliceModelWithData) + if err != nil { + return nil, err + } + + if len(q.q.returning) > 0 { + b, err = q.q.appendReturning(fmter, b) + if err != nil { + return nil, err + } + } + + return b, q.q.stickyErr +} + +func (q *UpdateQuery) mustAppendWhere( + fmter QueryFormatter, b []byte, isSliceModelWithData bool, +) (_ []byte, err error) { + b = append(b, " WHERE "...) + + if !isSliceModelWithData { + return q.q.mustAppendWhere(fmter, b) + } + + if len(q.q.where) > 0 { + return q.q.appendWhere(fmter, b) + } + + table := q.q.tableModel.Table() + err = table.checkPKs() + if err != nil { + return nil, err + } + + b = appendWhereColumnAndColumn(b, table.Alias, table.PKs) + return b, nil +} + +func (q *UpdateQuery) mustAppendSet(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if len(q.q.set) > 0 { + return q.q.appendSet(fmter, b) + } + + b = append(b, " SET "...) + + if m, ok := q.q.model.(*mapModel); ok { + return q.appendMapSet(b, m.m), nil + } + + if !q.q.hasTableModel() { + return nil, errModelNil + } + + value := q.q.tableModel.Value() + if value.Kind() == reflect.Struct { + b, err = q.appendSetStruct(fmter, b, value) + } else { + if value.Len() > 0 { + b, err = q.appendSetSlice(b) + } else { + err = fmt.Errorf("pg: can't bulk-update empty slice %s", value.Type()) + } + } + if err != nil { + return nil, err + } + + return b, nil +} + +func (q *UpdateQuery) appendMapSet(b []byte, m map[string]interface{}) []byte { + keys := make([]string, 0, len(m)) + + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + + for i, k := range keys { + if i > 0 { + b = append(b, ", "...) + } + + b = types.AppendIdent(b, k, 1) + b = append(b, " = "...) + if q.placeholder { + b = append(b, '?') + } else { + b = types.Append(b, m[k], 1) + } + } + + return b +} + +func (q *UpdateQuery) appendSetStruct(fmter QueryFormatter, b []byte, strct reflect.Value) ([]byte, error) { + fields, err := q.q.getFields() + if err != nil { + return nil, err + } + + if len(fields) == 0 { + fields = q.q.tableModel.Table().DataFields + } + + pos := len(b) + for _, f := range fields { + if q.omitZero && f.NullZero() && f.HasZeroValue(strct) { + continue + } + + if len(b) != pos { + b = append(b, ", "...) + pos = len(b) + } + + b = append(b, f.Column...) + b = append(b, " = "...) + + if q.placeholder { + b = append(b, '?') + continue + } + + app, ok := q.q.modelValues[f.SQLName] + if ok { + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } else { + b = f.AppendValue(b, strct, 1) + } + } + + for i, v := range q.q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + + b = append(b, v.column...) + b = append(b, " = "...) + + b, err = v.value.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *UpdateQuery) appendSetSlice(b []byte) ([]byte, error) { + fields, err := q.q.getFields() + if err != nil { + return nil, err + } + + if len(fields) == 0 { + fields = q.q.tableModel.Table().DataFields + } + + var table *Table + if q.omitZero { + table = q.q.tableModel.Table() + } + + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + b = append(b, f.Column...) + b = append(b, " = "...) + if q.omitZero && table != nil { + b = append(b, "COALESCE("...) + } + b = append(b, "_data."...) + b = append(b, f.Column...) + if q.omitZero && table != nil { + b = append(b, ", "...) + if table.Alias != table.SQLName { + b = append(b, table.Alias...) + b = append(b, '.') + } + b = append(b, f.Column...) + b = append(b, ")"...) + } + } + + return b, nil +} + +func (q *UpdateQuery) appendSliceModelData(fmter QueryFormatter, b []byte) ([]byte, error) { + columns, err := q.q.getDataFields() + if err != nil { + return nil, err + } + + if len(columns) > 0 { + columns = append(columns, q.q.tableModel.Table().PKs...) + } else { + columns = q.q.tableModel.Table().Fields + } + + return q.appendSliceValues(fmter, b, columns, q.q.tableModel.Value()) +} + +func (q *UpdateQuery) appendSliceValues( + fmter QueryFormatter, b []byte, fields []*Field, slice reflect.Value, +) (_ []byte, err error) { + b = append(b, "(VALUES ("...) + + if q.placeholder { + b, err = q.appendValues(fmter, b, fields, reflect.Value{}) + if err != nil { + return nil, err + } + } else { + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, "), ("...) + } + b, err = q.appendValues(fmter, b, fields, slice.Index(i)) + if err != nil { + return nil, err + } + } + } + + b = append(b, ")) AS _data("...) + b = appendColumns(b, "", fields) + b = append(b, ")"...) + + return b, nil +} + +func (q *UpdateQuery) appendValues( + fmter QueryFormatter, b []byte, fields []*Field, strct reflect.Value, +) (_ []byte, err error) { + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + app, ok := q.q.modelValues[f.SQLName] + if ok { + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + continue + } + + if q.placeholder { + b = append(b, '?') + } else { + b = f.AppendValue(b, indirect(strct), 1) + } + + b = append(b, "::"...) + b = append(b, f.SQLType...) + } + return b, nil +} + +func appendWhereColumnAndColumn(b []byte, alias types.Safe, fields []*Field) []byte { + for i, f := range fields { + if i > 0 { + b = append(b, " AND "...) + } + b = append(b, alias...) + b = append(b, '.') + b = append(b, f.Column...) + b = append(b, " = _data."...) + b = append(b, f.Column...) + } + return b +} |