summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_update.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_update.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_update.go34
1 files changed, 25 insertions, 9 deletions
diff --git a/vendor/github.com/uptrace/bun/query_update.go b/vendor/github.com/uptrace/bun/query_update.go
index ea74e1419..1e032c548 100644
--- a/vendor/github.com/uptrace/bun/query_update.go
+++ b/vendor/github.com/uptrace/bun/query_update.go
@@ -90,13 +90,13 @@ func (q *UpdateQuery) Set(query string, args ...interface{}) *UpdateQuery {
return q
}
-// Value overwrites model value for the column in INSERT and UPDATE queries.
-func (q *UpdateQuery) Value(column string, value string, args ...interface{}) *UpdateQuery {
+// Value overwrites model value for the column.
+func (q *UpdateQuery) Value(column string, expr string, args ...interface{}) *UpdateQuery {
if q.table == nil {
q.err = errNilModel
return q
}
- q.addValue(q.table, column, value, args)
+ q.addValue(q.table, column, expr, args)
return q
}
@@ -321,20 +321,36 @@ func (q *UpdateQuery) Bulk() *UpdateQuery {
return q
}
- return q.With("_data", q.db.NewValues(model)).
+ set, err := q.updateSliceSet(q.db.fmter, model)
+ if err != nil {
+ q.setErr(err)
+ return q
+ }
+
+ values := q.db.NewValues(model)
+ values.customValueQuery = q.customValueQuery
+
+ return q.With("_data", values).
Model(model).
TableExpr("_data").
- Set(q.updateSliceSet(model)).
+ Set(set).
Where(q.updateSliceWhere(model))
}
-func (q *UpdateQuery) updateSliceSet(model *sliceTableModel) string {
+func (q *UpdateQuery) updateSliceSet(
+ fmter schema.Formatter, model *sliceTableModel,
+) (string, error) {
+ fields, err := q.getDataFields()
+ if err != nil {
+ return "", err
+ }
+
var b []byte
- for i, field := range model.table.DataFields {
+ for i, field := range fields {
if i > 0 {
b = append(b, ", "...)
}
- if q.db.fmter.HasFeature(feature.UpdateMultiTable) {
+ if fmter.HasFeature(feature.UpdateMultiTable) {
b = append(b, model.table.SQLAlias...)
b = append(b, '.')
}
@@ -342,7 +358,7 @@ func (q *UpdateQuery) updateSliceSet(model *sliceTableModel) string {
b = append(b, " = _data."...)
b = append(b, field.SQLName...)
}
- return internal.String(b)
+ return internal.String(b), nil
}
func (db *UpdateQuery) updateSliceWhere(model *sliceTableModel) string {