summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_insert.go
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2022-04-24 12:26:22 +0200
committerLibravatar GitHub <noreply@github.com>2022-04-24 12:26:22 +0200
commit88979b35d462516e1765524d70a41c0d26eec911 (patch)
treefd37cb19317217e226ee7717824f24031f53b031 /vendor/github.com/uptrace/bun/query_insert.go
parentRevert "[chore] Tidy up federating db locks a tiny bit (#472)" (#479) (diff)
downloadgotosocial-88979b35d462516e1765524d70a41c0d26eec911.tar.xz
[chore] Update bun and sqlite dependencies (#478)
* update bun + sqlite versions * step bun to v1.1.3
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_insert.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_insert.go69
1 files changed, 43 insertions, 26 deletions
diff --git a/vendor/github.com/uptrace/bun/query_insert.go b/vendor/github.com/uptrace/bun/query_insert.go
index 6300ab815..fdbe0c275 100644
--- a/vendor/github.com/uptrace/bun/query_insert.go
+++ b/vendor/github.com/uptrace/bun/query_insert.go
@@ -126,13 +126,6 @@ func (q *InsertQuery) Returning(query string, args ...interface{}) *InsertQuery
return q
}
-func (q *InsertQuery) hasReturning() bool {
- if !q.db.features.Has(feature.Returning) {
- return false
- }
- return q.returningQuery.hasReturning()
-}
-
//------------------------------------------------------------------------------
// Ignore generates different queries depending on the DBMS:
@@ -148,7 +141,7 @@ func (q *InsertQuery) Ignore() *InsertQuery {
return q
}
-// Replaces generates a `REPLACE INTO` query (MySQL).
+// Replaces generates a `REPLACE INTO` query (MySQL and MariaDB).
func (q *InsertQuery) Replace() *InsertQuery {
q.replace = true
return q
@@ -201,7 +194,8 @@ func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
return nil, err
}
- if q.hasReturning() {
+ if q.hasFeature(feature.InsertReturning) && q.hasReturning() {
+ b = append(b, " RETURNING "...)
b, err = q.appendReturning(fmter, b)
if err != nil {
return nil, err
@@ -224,6 +218,14 @@ func (q *InsertQuery) appendColumnsValues(
b = append(b, ")"...)
}
+ if q.hasFeature(feature.Output) && q.hasReturning() {
+ b = append(b, " OUTPUT "...)
+ b, err = q.appendOutput(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
b = append(b, " SELECT "...)
if q.columns != nil {
@@ -255,6 +257,7 @@ func (q *InsertQuery) appendColumnsValues(
return nil, errNilModel
}
+ // Build fields to populate RETURNING clause.
fields, err := q.getFields()
if err != nil {
return nil, err
@@ -262,7 +265,17 @@ func (q *InsertQuery) appendColumnsValues(
b = append(b, " ("...)
b = q.appendFields(fmter, b, fields)
- b = append(b, ") VALUES ("...)
+ b = append(b, ")"...)
+
+ if q.hasFeature(feature.Output) && q.hasReturning() {
+ b = append(b, " OUTPUT "...)
+ b, err = q.appendOutput(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ b = append(b, " VALUES ("...)
switch model := q.tableModel.(type) {
case *structTableModel:
@@ -306,7 +319,7 @@ func (q *InsertQuery) appendStructValues(
switch {
case isTemplate:
b = append(b, '?')
- case f.NullZero && f.HasZeroValue(strct):
+ case (f.IsPtr && f.HasNilValue(strct)) || (f.NullZero && f.HasZeroValue(strct)):
if q.db.features.Has(feature.DefaultPlaceholder) {
b = append(b, "DEFAULT"...)
} else if f.SQLDefault != "" {
@@ -353,22 +366,13 @@ func (q *InsertQuery) appendSliceValues(
}
}
- for i, v := range q.extraValues {
- if i > 0 || len(fields) > 0 {
- b = append(b, ", "...)
- }
-
- b, err = v.value.AppendQuery(fmter, b)
- if err != nil {
- return nil, err
- }
- }
-
return b, nil
}
func (q *InsertQuery) getFields() ([]*schema.Field, error) {
- if q.db.features.Has(feature.DefaultPlaceholder) || len(q.columns) > 0 {
+ hasIdentity := q.db.features.Has(feature.Identity)
+
+ if len(q.columns) > 0 || q.db.features.Has(feature.DefaultPlaceholder) && !hasIdentity {
return q.baseQuery.getFields()
}
@@ -382,15 +386,23 @@ func (q *InsertQuery) getFields() ([]*schema.Field, error) {
return nil, fmt.Errorf("bun: Insert(empty %T)", model.slice.Type())
}
strct = indirect(model.slice.Index(0))
+ default:
+ return nil, errNilModel
}
fields := make([]*schema.Field, 0, len(q.table.Fields))
for _, f := range q.table.Fields {
- if f.NotNull && f.NullZero && f.SQLDefault == "" && f.HasZeroValue(strct) {
+ if hasIdentity && f.AutoIncrement {
q.addReturningField(f)
continue
}
+ if f.NotNull && f.SQLDefault == "" {
+ if (f.IsPtr && f.HasNilValue(strct)) || (f.NullZero && f.HasZeroValue(strct)) {
+ q.addReturningField(f)
+ continue
+ }
+ }
fields = append(fields, f)
}
@@ -539,7 +551,8 @@ func (q *InsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result
query := internal.String(queryBytes)
var res sql.Result
- if hasDest := len(dest) > 0; hasDest || q.hasReturning() {
+ if hasDest := len(dest) > 0; hasDest ||
+ (q.hasReturning() && q.hasFeature(feature.InsertReturning|feature.Output)) {
model, err := q.getModel(dest)
if err != nil {
return nil, err
@@ -588,7 +601,11 @@ func (q *InsertQuery) afterInsertHook(ctx context.Context) error {
}
func (q *InsertQuery) tryLastInsertID(res sql.Result, dest []interface{}) error {
- if q.db.features.Has(feature.Returning) || q.table == nil || len(q.table.PKs) != 1 {
+ if q.db.features.Has(feature.Returning) ||
+ q.db.features.Has(feature.Output) ||
+ q.table == nil ||
+ len(q.table.PKs) != 1 ||
+ !q.table.PKs[0].AutoIncrement {
return nil
}