diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_insert.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/query_insert.go | 59 |
1 files changed, 54 insertions, 5 deletions
diff --git a/vendor/github.com/uptrace/bun/query_insert.go b/vendor/github.com/uptrace/bun/query_insert.go index e56089d24..42ee49962 100644 --- a/vendor/github.com/uptrace/bun/query_insert.go +++ b/vendor/github.com/uptrace/bun/query_insert.go @@ -84,6 +84,11 @@ func (q *InsertQuery) Column(columns ...string) *InsertQuery { return q } +func (q *InsertQuery) ColumnExpr(query string, args ...interface{}) *InsertQuery { + q.addColumn(schema.SafeQuery(query, args)) + return q +} + func (q *InsertQuery) ExcludeColumn(columns ...string) *InsertQuery { q.excludeColumn(columns) return q @@ -113,7 +118,7 @@ func (q *InsertQuery) WhereOr(query string, args ...interface{}) *InsertQuery { // Returning adds a RETURNING clause to the query. // -// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`. +// To suppress the auto-generated RETURNING clause, use `Returning("")`. func (q *InsertQuery) Returning(query string, args ...interface{}) *InsertQuery { q.addReturning(schema.SafeQuery(query, args)) return q @@ -147,9 +152,6 @@ func (q *InsertQuery) Operation() string { } func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { - if q.err != nil { - return nil, q.err - } fmter = formatterWithModel(fmter, q) b, err = q.appendWith(fmter, b) @@ -209,7 +211,18 @@ func (q *InsertQuery) appendColumnsValues( b = append(b, ")"...) } - b = append(b, " SELECT * FROM "...) + b = append(b, " SELECT "...) + + if q.columns != nil { + b, err = q.appendColumns(fmter, b) + if err != nil { + return nil, err + } + } else { + b = append(b, "*"...) + } + + b = append(b, " FROM "...) b, err = q.appendOtherTables(fmter, b) if err != nil { return nil, err @@ -429,6 +442,17 @@ func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err } b = q.appendSetExcluded(b, fields) + } else if q.onDuplicateKeyUpdate() { + fields, err := q.getDataFields() + if err != nil { + return nil, err + } + + if len(fields) == 0 { + fields = q.tableModel.Table().DataFields + } + + b = q.appendSetValues(b, fields) } if len(q.where) > 0 { @@ -447,6 +471,10 @@ func (q *InsertQuery) onConflictDoUpdate() bool { return strings.HasSuffix(strings.ToUpper(q.on.Query), " DO UPDATE") } +func (q *InsertQuery) onDuplicateKeyUpdate() bool { + return strings.ToUpper(q.on.Query) == "DUPLICATE KEY UPDATE" +} + func (q *InsertQuery) appendSetExcluded(b []byte, fields []*schema.Field) []byte { b = append(b, " SET "...) for i, f := range fields { @@ -460,6 +488,20 @@ func (q *InsertQuery) appendSetExcluded(b []byte, fields []*schema.Field) []byte return b } +func (q *InsertQuery) appendSetValues(b []byte, fields []*schema.Field) []byte { + b = append(b, " "...) + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + b = append(b, f.SQLName...) + b = append(b, " = VALUES("...) + b = append(b, f.SQLName...) + b = append(b, ")"...) + } + return b +} + //------------------------------------------------------------------------------ func (q *InsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { @@ -469,6 +511,13 @@ func (q *InsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result } } + if q.err != nil { + return nil, q.err + } + if err := q.beforeAppendModel(ctx, q); err != nil { + return nil, err + } + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { return nil, err |