diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_insert.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/query_insert.go | 23 |
1 files changed, 14 insertions, 9 deletions
diff --git a/vendor/github.com/uptrace/bun/query_insert.go b/vendor/github.com/uptrace/bun/query_insert.go index efddee407..d02633bf6 100644 --- a/vendor/github.com/uptrace/bun/query_insert.go +++ b/vendor/github.com/uptrace/bun/query_insert.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "reflect" + "strings" "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/internal" @@ -16,7 +17,7 @@ type InsertQuery struct { returningQuery customValueQuery - onConflict schema.QueryWithArgs + on schema.QueryWithArgs setQuery ignore bool @@ -88,13 +89,13 @@ func (q *InsertQuery) ExcludeColumn(columns ...string) *InsertQuery { return q } -// Value overwrites model value for the column in INSERT and UPDATE queries. -func (q *InsertQuery) Value(column string, value string, args ...interface{}) *InsertQuery { +// Value overwrites model value for the column. +func (q *InsertQuery) Value(column string, expr string, args ...interface{}) *InsertQuery { if q.table == nil { q.err = errNilModel return q } - q.addValue(q.table, column, value, args) + q.addValue(q.table, column, expr, args) return q } @@ -162,7 +163,7 @@ func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e } b = append(b, "INTO "...) - if q.db.features.Has(feature.InsertTableAlias) && !q.onConflict.IsZero() { + if q.db.features.Has(feature.InsertTableAlias) && !q.on.IsZero() { b, err = q.appendFirstTableWithAlias(fmter, b) } else { b, err = q.appendFirstTable(fmter, b) @@ -382,7 +383,7 @@ func (q *InsertQuery) appendFields( //------------------------------------------------------------------------------ func (q *InsertQuery) On(s string, args ...interface{}) *InsertQuery { - q.onConflict = schema.SafeQuery(s, args) + q.on = schema.SafeQuery(s, args) return q } @@ -392,12 +393,12 @@ func (q *InsertQuery) Set(query string, args ...interface{}) *InsertQuery { } func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err error) { - if q.onConflict.IsZero() { + if q.on.IsZero() { return b, nil } b = append(b, " ON "...) - b, err = q.onConflict.AppendQuery(fmter, b) + b, err = q.on.AppendQuery(fmter, b) if err != nil { return nil, err } @@ -413,7 +414,7 @@ func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err if err != nil { return nil, err } - } else if len(q.columns) > 0 { + } else if q.onConflictDoUpdate() { fields, err := q.getDataFields() if err != nil { return nil, err @@ -434,6 +435,10 @@ func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err return b, nil } +func (q *InsertQuery) onConflictDoUpdate() bool { + return strings.HasSuffix(strings.ToUpper(q.on.Query), " DO UPDATE") +} + func (q *InsertQuery) appendSetExcluded(b []byte, fields []*schema.Field) []byte { b = append(b, " SET "...) for i, f := range fields { |