summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_insert.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_insert.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_insert.go59
1 files changed, 54 insertions, 5 deletions
diff --git a/vendor/github.com/uptrace/bun/query_insert.go b/vendor/github.com/uptrace/bun/query_insert.go
index e56089d24..42ee49962 100644
--- a/vendor/github.com/uptrace/bun/query_insert.go
+++ b/vendor/github.com/uptrace/bun/query_insert.go
@@ -84,6 +84,11 @@ func (q *InsertQuery) Column(columns ...string) *InsertQuery {
return q
}
+func (q *InsertQuery) ColumnExpr(query string, args ...interface{}) *InsertQuery {
+ q.addColumn(schema.SafeQuery(query, args))
+ return q
+}
+
func (q *InsertQuery) ExcludeColumn(columns ...string) *InsertQuery {
q.excludeColumn(columns)
return q
@@ -113,7 +118,7 @@ func (q *InsertQuery) WhereOr(query string, args ...interface{}) *InsertQuery {
// Returning adds a RETURNING clause to the query.
//
-// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`.
+// To suppress the auto-generated RETURNING clause, use `Returning("")`.
func (q *InsertQuery) Returning(query string, args ...interface{}) *InsertQuery {
q.addReturning(schema.SafeQuery(query, args))
return q
@@ -147,9 +152,6 @@ func (q *InsertQuery) Operation() string {
}
func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
- if q.err != nil {
- return nil, q.err
- }
fmter = formatterWithModel(fmter, q)
b, err = q.appendWith(fmter, b)
@@ -209,7 +211,18 @@ func (q *InsertQuery) appendColumnsValues(
b = append(b, ")"...)
}
- b = append(b, " SELECT * FROM "...)
+ b = append(b, " SELECT "...)
+
+ if q.columns != nil {
+ b, err = q.appendColumns(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ b = append(b, "*"...)
+ }
+
+ b = append(b, " FROM "...)
b, err = q.appendOtherTables(fmter, b)
if err != nil {
return nil, err
@@ -429,6 +442,17 @@ func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err
}
b = q.appendSetExcluded(b, fields)
+ } else if q.onDuplicateKeyUpdate() {
+ fields, err := q.getDataFields()
+ if err != nil {
+ return nil, err
+ }
+
+ if len(fields) == 0 {
+ fields = q.tableModel.Table().DataFields
+ }
+
+ b = q.appendSetValues(b, fields)
}
if len(q.where) > 0 {
@@ -447,6 +471,10 @@ func (q *InsertQuery) onConflictDoUpdate() bool {
return strings.HasSuffix(strings.ToUpper(q.on.Query), " DO UPDATE")
}
+func (q *InsertQuery) onDuplicateKeyUpdate() bool {
+ return strings.ToUpper(q.on.Query) == "DUPLICATE KEY UPDATE"
+}
+
func (q *InsertQuery) appendSetExcluded(b []byte, fields []*schema.Field) []byte {
b = append(b, " SET "...)
for i, f := range fields {
@@ -460,6 +488,20 @@ func (q *InsertQuery) appendSetExcluded(b []byte, fields []*schema.Field) []byte
return b
}
+func (q *InsertQuery) appendSetValues(b []byte, fields []*schema.Field) []byte {
+ b = append(b, " "...)
+ for i, f := range fields {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+ b = append(b, f.SQLName...)
+ b = append(b, " = VALUES("...)
+ b = append(b, f.SQLName...)
+ b = append(b, ")"...)
+ }
+ return b
+}
+
//------------------------------------------------------------------------------
func (q *InsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
@@ -469,6 +511,13 @@ func (q *InsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result
}
}
+ if q.err != nil {
+ return nil, q.err
+ }
+ if err := q.beforeAppendModel(ctx, q); err != nil {
+ return nil, err
+ }
+
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err