diff options
author | 2022-04-24 12:26:22 +0200 | |
---|---|---|
committer | 2022-04-24 12:26:22 +0200 | |
commit | 88979b35d462516e1765524d70a41c0d26eec911 (patch) | |
tree | fd37cb19317217e226ee7717824f24031f53b031 /vendor/github.com/uptrace/bun/query_insert.go | |
parent | Revert "[chore] Tidy up federating db locks a tiny bit (#472)" (#479) (diff) | |
download | gotosocial-88979b35d462516e1765524d70a41c0d26eec911.tar.xz |
[chore] Update bun and sqlite dependencies (#478)
* update bun + sqlite versions
* step bun to v1.1.3
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_insert.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/query_insert.go | 69 |
1 files changed, 43 insertions, 26 deletions
diff --git a/vendor/github.com/uptrace/bun/query_insert.go b/vendor/github.com/uptrace/bun/query_insert.go index 6300ab815..fdbe0c275 100644 --- a/vendor/github.com/uptrace/bun/query_insert.go +++ b/vendor/github.com/uptrace/bun/query_insert.go @@ -126,13 +126,6 @@ func (q *InsertQuery) Returning(query string, args ...interface{}) *InsertQuery return q } -func (q *InsertQuery) hasReturning() bool { - if !q.db.features.Has(feature.Returning) { - return false - } - return q.returningQuery.hasReturning() -} - //------------------------------------------------------------------------------ // Ignore generates different queries depending on the DBMS: @@ -148,7 +141,7 @@ func (q *InsertQuery) Ignore() *InsertQuery { return q } -// Replaces generates a `REPLACE INTO` query (MySQL). +// Replaces generates a `REPLACE INTO` query (MySQL and MariaDB). func (q *InsertQuery) Replace() *InsertQuery { q.replace = true return q @@ -201,7 +194,8 @@ func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e return nil, err } - if q.hasReturning() { + if q.hasFeature(feature.InsertReturning) && q.hasReturning() { + b = append(b, " RETURNING "...) b, err = q.appendReturning(fmter, b) if err != nil { return nil, err @@ -224,6 +218,14 @@ func (q *InsertQuery) appendColumnsValues( b = append(b, ")"...) } + if q.hasFeature(feature.Output) && q.hasReturning() { + b = append(b, " OUTPUT "...) + b, err = q.appendOutput(fmter, b) + if err != nil { + return nil, err + } + } + b = append(b, " SELECT "...) if q.columns != nil { @@ -255,6 +257,7 @@ func (q *InsertQuery) appendColumnsValues( return nil, errNilModel } + // Build fields to populate RETURNING clause. fields, err := q.getFields() if err != nil { return nil, err @@ -262,7 +265,17 @@ func (q *InsertQuery) appendColumnsValues( b = append(b, " ("...) b = q.appendFields(fmter, b, fields) - b = append(b, ") VALUES ("...) + b = append(b, ")"...) + + if q.hasFeature(feature.Output) && q.hasReturning() { + b = append(b, " OUTPUT "...) + b, err = q.appendOutput(fmter, b) + if err != nil { + return nil, err + } + } + + b = append(b, " VALUES ("...) switch model := q.tableModel.(type) { case *structTableModel: @@ -306,7 +319,7 @@ func (q *InsertQuery) appendStructValues( switch { case isTemplate: b = append(b, '?') - case f.NullZero && f.HasZeroValue(strct): + case (f.IsPtr && f.HasNilValue(strct)) || (f.NullZero && f.HasZeroValue(strct)): if q.db.features.Has(feature.DefaultPlaceholder) { b = append(b, "DEFAULT"...) } else if f.SQLDefault != "" { @@ -353,22 +366,13 @@ func (q *InsertQuery) appendSliceValues( } } - for i, v := range q.extraValues { - if i > 0 || len(fields) > 0 { - b = append(b, ", "...) - } - - b, err = v.value.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - } - return b, nil } func (q *InsertQuery) getFields() ([]*schema.Field, error) { - if q.db.features.Has(feature.DefaultPlaceholder) || len(q.columns) > 0 { + hasIdentity := q.db.features.Has(feature.Identity) + + if len(q.columns) > 0 || q.db.features.Has(feature.DefaultPlaceholder) && !hasIdentity { return q.baseQuery.getFields() } @@ -382,15 +386,23 @@ func (q *InsertQuery) getFields() ([]*schema.Field, error) { return nil, fmt.Errorf("bun: Insert(empty %T)", model.slice.Type()) } strct = indirect(model.slice.Index(0)) + default: + return nil, errNilModel } fields := make([]*schema.Field, 0, len(q.table.Fields)) for _, f := range q.table.Fields { - if f.NotNull && f.NullZero && f.SQLDefault == "" && f.HasZeroValue(strct) { + if hasIdentity && f.AutoIncrement { q.addReturningField(f) continue } + if f.NotNull && f.SQLDefault == "" { + if (f.IsPtr && f.HasNilValue(strct)) || (f.NullZero && f.HasZeroValue(strct)) { + q.addReturningField(f) + continue + } + } fields = append(fields, f) } @@ -539,7 +551,8 @@ func (q *InsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result query := internal.String(queryBytes) var res sql.Result - if hasDest := len(dest) > 0; hasDest || q.hasReturning() { + if hasDest := len(dest) > 0; hasDest || + (q.hasReturning() && q.hasFeature(feature.InsertReturning|feature.Output)) { model, err := q.getModel(dest) if err != nil { return nil, err @@ -588,7 +601,11 @@ func (q *InsertQuery) afterInsertHook(ctx context.Context) error { } func (q *InsertQuery) tryLastInsertID(res sql.Result, dest []interface{}) error { - if q.db.features.Has(feature.Returning) || q.table == nil || len(q.table.PKs) != 1 { + if q.db.features.Has(feature.Returning) || + q.db.features.Has(feature.Output) || + q.table == nil || + len(q.table.PKs) != 1 || + !q.table.PKs[0].AutoIncrement { return nil } |