summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_update.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_update.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_update.go90
1 files changed, 77 insertions, 13 deletions
diff --git a/vendor/github.com/uptrace/bun/query_update.go b/vendor/github.com/uptrace/bun/query_update.go
index f4554e4d3..dbe06799d 100644
--- a/vendor/github.com/uptrace/bun/query_update.go
+++ b/vendor/github.com/uptrace/bun/query_update.go
@@ -92,13 +92,21 @@ func (q *UpdateQuery) Set(query string, args ...interface{}) *UpdateQuery {
return q
}
+func (q *UpdateQuery) SetColumn(column string, query string, args ...interface{}) *UpdateQuery {
+ if q.db.HasFeature(feature.UpdateMultiTable) {
+ column = q.table.Alias + "." + column
+ }
+ q.addSet(schema.SafeQuery(column+" = "+query, args))
+ return q
+}
+
// Value overwrites model value for the column.
-func (q *UpdateQuery) Value(column string, expr string, args ...interface{}) *UpdateQuery {
+func (q *UpdateQuery) Value(column string, query string, args ...interface{}) *UpdateQuery {
if q.table == nil {
q.err = errNilModel
return q
}
- q.addValue(q.table, column, expr, args)
+ q.addValue(q.table, column, query, args)
return q
}
@@ -187,8 +195,10 @@ func (q *UpdateQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
if fmter.HasFeature(feature.UpdateMultiTable) {
b, err = q.appendTablesWithAlias(fmter, b)
- } else {
+ } else if fmter.HasFeature(feature.UpdateTableAlias) {
b, err = q.appendFirstTableWithAlias(fmter, b)
+ } else {
+ b, err = q.appendFirstTable(fmter, b)
}
if err != nil {
return nil, err
@@ -206,12 +216,13 @@ func (q *UpdateQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
}
}
- b, err = q.mustAppendWhere(fmter, b, true)
+ b, err = q.mustAppendWhere(fmter, b, q.hasTableAlias(fmter))
if err != nil {
return nil, err
}
- if len(q.returning) > 0 {
+ if q.hasFeature(feature.Returning) && q.hasReturning() {
+ b = append(b, " RETURNING "...)
b, err = q.appendReturning(fmter, b)
if err != nil {
return nil, err
@@ -344,7 +355,7 @@ func (q *UpdateQuery) Bulk() *UpdateQuery {
Model(model).
TableExpr("_data").
Set(set).
- Where(q.updateSliceWhere(model))
+ Where(q.updateSliceWhere(q.db.fmter, model))
}
func (q *UpdateQuery) updateSliceSet(
@@ -371,13 +382,17 @@ func (q *UpdateQuery) updateSliceSet(
return internal.String(b), nil
}
-func (db *UpdateQuery) updateSliceWhere(model *sliceTableModel) string {
+func (q *UpdateQuery) updateSliceWhere(fmter schema.Formatter, model *sliceTableModel) string {
var b []byte
for i, pk := range model.table.PKs {
if i > 0 {
b = append(b, " AND "...)
}
- b = append(b, model.table.SQLAlias...)
+ if q.hasTableAlias(fmter) {
+ b = append(b, model.table.SQLAlias...)
+ } else {
+ b = append(b, model.table.SQLName...)
+ }
b = append(b, '.')
b = append(b, pk.SQLName...)
b = append(b, " = _data."...)
@@ -456,14 +471,63 @@ func (q *UpdateQuery) afterUpdateHook(ctx context.Context) error {
return nil
}
-// FQN returns a fully qualified column name. For MySQL, it returns the column name with
-// the table alias. For other RDBMS, it returns just the column name.
+// FQN returns a fully qualified column name, for example, table_name.column_name or
+// table_alias.column_alias.
func (q *UpdateQuery) FQN(column string) Ident {
if q.table == nil {
- panic("UpdateQuery.FQN requires a model")
+ panic("UpdateQuery.SetName requires a model")
}
- if q.db.HasFeature(feature.UpdateMultiTable) {
+ if q.hasTableAlias(q.db.fmter) {
return Ident(q.table.Alias + "." + column)
}
- return Ident(column)
+ return Ident(q.table.Name + "." + column)
+}
+
+func (q *UpdateQuery) hasTableAlias(fmter schema.Formatter) bool {
+ return fmter.HasFeature(feature.UpdateMultiTable | feature.UpdateTableAlias)
+}
+
+//------------------------------------------------------------------------------
+type updateQueryBuilder struct {
+ *UpdateQuery
+}
+
+func (q *updateQueryBuilder) WhereGroup(sep string, fn func(QueryBuilder) QueryBuilder) QueryBuilder {
+ q.UpdateQuery = q.UpdateQuery.WhereGroup(sep, func(qs *UpdateQuery) *UpdateQuery {
+ return fn(q).(*updateQueryBuilder).UpdateQuery
+ })
+ return q
+}
+
+func (q *updateQueryBuilder) Where(query string, args ...interface{}) QueryBuilder {
+ q.UpdateQuery.Where(query, args...)
+ return q
+}
+
+func (q *updateQueryBuilder) WhereOr(query string, args ...interface{}) QueryBuilder {
+ q.UpdateQuery.WhereOr(query, args...)
+ return q
+}
+
+func (q *updateQueryBuilder) WhereDeleted() QueryBuilder {
+ q.UpdateQuery.WhereDeleted()
+ return q
+}
+
+func (q *updateQueryBuilder) WhereAllWithDeleted() QueryBuilder {
+ q.UpdateQuery.WhereAllWithDeleted()
+ return q
+}
+
+func (q *updateQueryBuilder) WherePK(cols ...string) QueryBuilder {
+ q.UpdateQuery.WherePK(cols...)
+ return q
+}
+
+func (q *updateQueryBuilder) Unwrap() interface{} {
+ return q.UpdateQuery
+}
+
+func (q *UpdateQuery) Query() QueryBuilder {
+ return &updateQueryBuilder{q}
}