diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_update.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/query_update.go | 64 |
1 files changed, 58 insertions, 6 deletions
diff --git a/vendor/github.com/uptrace/bun/query_update.go b/vendor/github.com/uptrace/bun/query_update.go index dbe06799d..b415ff201 100644 --- a/vendor/github.com/uptrace/bun/query_update.go +++ b/vendor/github.com/uptrace/bun/query_update.go @@ -6,6 +6,8 @@ import ( "errors" "fmt" + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/internal" "github.com/uptrace/bun/schema" @@ -16,6 +18,7 @@ type UpdateQuery struct { returningQuery customValueQuery setQuery + idxHintsQuery omitZero bool } @@ -204,6 +207,11 @@ func (q *UpdateQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e return nil, err } + b, err = q.appendIndexHints(fmter, b) + if err != nil { + return nil, err + } + b, err = q.mustAppendSet(fmter, b) if err != nil { return nil, err @@ -273,7 +281,13 @@ func (q *UpdateQuery) appendSetStruct( isTemplate := fmter.IsNop() pos := len(b) for _, f := range fields { - if q.omitZero && f.HasZeroValue(model.strct) { + if f.SkipUpdate() { + continue + } + + app, hasValue := q.modelValues[f.Name] + + if !hasValue && q.omitZero && f.HasZeroValue(model.strct) { continue } @@ -290,8 +304,7 @@ func (q *UpdateQuery) appendSetStruct( continue } - app, ok := q.modelValues[f.Name] - if ok { + if hasValue { b, err = app.AppendQuery(fmter, b) if err != nil { return nil, err @@ -487,12 +500,32 @@ func (q *UpdateQuery) hasTableAlias(fmter schema.Formatter) bool { return fmter.HasFeature(feature.UpdateMultiTable | feature.UpdateTableAlias) } +func (q *UpdateQuery) String() string { + buf, err := q.AppendQuery(q.db.Formatter(), nil) + if err != nil { + panic(err) + } + + return string(buf) +} + //------------------------------------------------------------------------------ + +func (q *UpdateQuery) QueryBuilder() QueryBuilder { + return &updateQueryBuilder{q} +} + +func (q *UpdateQuery) ApplyQueryBuilder(fn func(QueryBuilder) QueryBuilder) *UpdateQuery { + return fn(q.QueryBuilder()).Unwrap().(*UpdateQuery) +} + type updateQueryBuilder struct { *UpdateQuery } -func (q *updateQueryBuilder) WhereGroup(sep string, fn func(QueryBuilder) QueryBuilder) QueryBuilder { +func (q *updateQueryBuilder) WhereGroup( + sep string, fn func(QueryBuilder) QueryBuilder, +) QueryBuilder { q.UpdateQuery = q.UpdateQuery.WhereGroup(sep, func(qs *UpdateQuery) *UpdateQuery { return fn(q).(*updateQueryBuilder).UpdateQuery }) @@ -528,6 +561,25 @@ func (q *updateQueryBuilder) Unwrap() interface{} { return q.UpdateQuery } -func (q *UpdateQuery) Query() QueryBuilder { - return &updateQueryBuilder{q} +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) UseIndex(indexes ...string) *UpdateQuery { + if q.db.dialect.Name() == dialect.MySQL { + q.addUseIndex(indexes...) + } + return q +} + +func (q *UpdateQuery) IgnoreIndex(indexes ...string) *UpdateQuery { + if q.db.dialect.Name() == dialect.MySQL { + q.addIgnoreIndex(indexes...) + } + return q +} + +func (q *UpdateQuery) ForceIndex(indexes ...string) *UpdateQuery { + if q.db.dialect.Name() == dialect.MySQL { + q.addForceIndex(indexes...) + } + return q } |