summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_update.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_update.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_update.go64
1 files changed, 58 insertions, 6 deletions
diff --git a/vendor/github.com/uptrace/bun/query_update.go b/vendor/github.com/uptrace/bun/query_update.go
index dbe06799d..b415ff201 100644
--- a/vendor/github.com/uptrace/bun/query_update.go
+++ b/vendor/github.com/uptrace/bun/query_update.go
@@ -6,6 +6,8 @@ import (
"errors"
"fmt"
+ "github.com/uptrace/bun/dialect"
+
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
@@ -16,6 +18,7 @@ type UpdateQuery struct {
returningQuery
customValueQuery
setQuery
+ idxHintsQuery
omitZero bool
}
@@ -204,6 +207,11 @@ func (q *UpdateQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
return nil, err
}
+ b, err = q.appendIndexHints(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
b, err = q.mustAppendSet(fmter, b)
if err != nil {
return nil, err
@@ -273,7 +281,13 @@ func (q *UpdateQuery) appendSetStruct(
isTemplate := fmter.IsNop()
pos := len(b)
for _, f := range fields {
- if q.omitZero && f.HasZeroValue(model.strct) {
+ if f.SkipUpdate() {
+ continue
+ }
+
+ app, hasValue := q.modelValues[f.Name]
+
+ if !hasValue && q.omitZero && f.HasZeroValue(model.strct) {
continue
}
@@ -290,8 +304,7 @@ func (q *UpdateQuery) appendSetStruct(
continue
}
- app, ok := q.modelValues[f.Name]
- if ok {
+ if hasValue {
b, err = app.AppendQuery(fmter, b)
if err != nil {
return nil, err
@@ -487,12 +500,32 @@ func (q *UpdateQuery) hasTableAlias(fmter schema.Formatter) bool {
return fmter.HasFeature(feature.UpdateMultiTable | feature.UpdateTableAlias)
}
+func (q *UpdateQuery) String() string {
+ buf, err := q.AppendQuery(q.db.Formatter(), nil)
+ if err != nil {
+ panic(err)
+ }
+
+ return string(buf)
+}
+
//------------------------------------------------------------------------------
+
+func (q *UpdateQuery) QueryBuilder() QueryBuilder {
+ return &updateQueryBuilder{q}
+}
+
+func (q *UpdateQuery) ApplyQueryBuilder(fn func(QueryBuilder) QueryBuilder) *UpdateQuery {
+ return fn(q.QueryBuilder()).Unwrap().(*UpdateQuery)
+}
+
type updateQueryBuilder struct {
*UpdateQuery
}
-func (q *updateQueryBuilder) WhereGroup(sep string, fn func(QueryBuilder) QueryBuilder) QueryBuilder {
+func (q *updateQueryBuilder) WhereGroup(
+ sep string, fn func(QueryBuilder) QueryBuilder,
+) QueryBuilder {
q.UpdateQuery = q.UpdateQuery.WhereGroup(sep, func(qs *UpdateQuery) *UpdateQuery {
return fn(q).(*updateQueryBuilder).UpdateQuery
})
@@ -528,6 +561,25 @@ func (q *updateQueryBuilder) Unwrap() interface{} {
return q.UpdateQuery
}
-func (q *UpdateQuery) Query() QueryBuilder {
- return &updateQueryBuilder{q}
+//------------------------------------------------------------------------------
+
+func (q *UpdateQuery) UseIndex(indexes ...string) *UpdateQuery {
+ if q.db.dialect.Name() == dialect.MySQL {
+ q.addUseIndex(indexes...)
+ }
+ return q
+}
+
+func (q *UpdateQuery) IgnoreIndex(indexes ...string) *UpdateQuery {
+ if q.db.dialect.Name() == dialect.MySQL {
+ q.addIgnoreIndex(indexes...)
+ }
+ return q
+}
+
+func (q *UpdateQuery) ForceIndex(indexes ...string) *UpdateQuery {
+ if q.db.dialect.Name() == dialect.MySQL {
+ q.addForceIndex(indexes...)
+ }
+ return q
}