summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_insert.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_insert.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_insert.go23
1 files changed, 14 insertions, 9 deletions
diff --git a/vendor/github.com/uptrace/bun/query_insert.go b/vendor/github.com/uptrace/bun/query_insert.go
index efddee407..d02633bf6 100644
--- a/vendor/github.com/uptrace/bun/query_insert.go
+++ b/vendor/github.com/uptrace/bun/query_insert.go
@@ -5,6 +5,7 @@ import (
"database/sql"
"fmt"
"reflect"
+ "strings"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
@@ -16,7 +17,7 @@ type InsertQuery struct {
returningQuery
customValueQuery
- onConflict schema.QueryWithArgs
+ on schema.QueryWithArgs
setQuery
ignore bool
@@ -88,13 +89,13 @@ func (q *InsertQuery) ExcludeColumn(columns ...string) *InsertQuery {
return q
}
-// Value overwrites model value for the column in INSERT and UPDATE queries.
-func (q *InsertQuery) Value(column string, value string, args ...interface{}) *InsertQuery {
+// Value overwrites model value for the column.
+func (q *InsertQuery) Value(column string, expr string, args ...interface{}) *InsertQuery {
if q.table == nil {
q.err = errNilModel
return q
}
- q.addValue(q.table, column, value, args)
+ q.addValue(q.table, column, expr, args)
return q
}
@@ -162,7 +163,7 @@ func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
}
b = append(b, "INTO "...)
- if q.db.features.Has(feature.InsertTableAlias) && !q.onConflict.IsZero() {
+ if q.db.features.Has(feature.InsertTableAlias) && !q.on.IsZero() {
b, err = q.appendFirstTableWithAlias(fmter, b)
} else {
b, err = q.appendFirstTable(fmter, b)
@@ -382,7 +383,7 @@ func (q *InsertQuery) appendFields(
//------------------------------------------------------------------------------
func (q *InsertQuery) On(s string, args ...interface{}) *InsertQuery {
- q.onConflict = schema.SafeQuery(s, args)
+ q.on = schema.SafeQuery(s, args)
return q
}
@@ -392,12 +393,12 @@ func (q *InsertQuery) Set(query string, args ...interface{}) *InsertQuery {
}
func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err error) {
- if q.onConflict.IsZero() {
+ if q.on.IsZero() {
return b, nil
}
b = append(b, " ON "...)
- b, err = q.onConflict.AppendQuery(fmter, b)
+ b, err = q.on.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
@@ -413,7 +414,7 @@ func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err
if err != nil {
return nil, err
}
- } else if len(q.columns) > 0 {
+ } else if q.onConflictDoUpdate() {
fields, err := q.getDataFields()
if err != nil {
return nil, err
@@ -434,6 +435,10 @@ func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err
return b, nil
}
+func (q *InsertQuery) onConflictDoUpdate() bool {
+ return strings.HasSuffix(strings.ToUpper(q.on.Query), " DO UPDATE")
+}
+
func (q *InsertQuery) appendSetExcluded(b []byte, fields []*schema.Field) []byte {
b = append(b, " SET "...)
for i, f := range fields {