diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun')
42 files changed, 1320 insertions, 1150 deletions
diff --git a/vendor/github.com/uptrace/bun/CHANGELOG.md b/vendor/github.com/uptrace/bun/CHANGELOG.md index a5ae7761a..ded6f1f40 100644 --- a/vendor/github.com/uptrace/bun/CHANGELOG.md +++ b/vendor/github.com/uptrace/bun/CHANGELOG.md @@ -1,3 +1,70 @@ +## [1.2.5](https://github.com/uptrace/bun/compare/v1.2.3...v1.2.5) (2024-10-26) + + +### Bug Fixes + +* allow Limit() without Order() with MSSQL ([#1009](https://github.com/uptrace/bun/issues/1009)) ([1a46ddc](https://github.com/uptrace/bun/commit/1a46ddc0d3ca0bdc60ca8be5ad1886799d14c8b0)) +* copy bytes in mapModel.Scan ([#1030](https://github.com/uptrace/bun/issues/1030)) ([#1032](https://github.com/uptrace/bun/issues/1032)) ([39fda4e](https://github.com/uptrace/bun/commit/39fda4e3d341e59e4955f751cb354a939e57c1b1)) +* fix issue with has-many join and pointer fields ([#950](https://github.com/uptrace/bun/issues/950)) ([#983](https://github.com/uptrace/bun/issues/983)) ([cbc5177](https://github.com/uptrace/bun/commit/cbc517792ba6cdcef1828f3699d3d4dfe3c5e0eb)) +* restore explicit column: name override ([#984](https://github.com/uptrace/bun/issues/984)) ([169f258](https://github.com/uptrace/bun/commit/169f258a9460cad451f3025d2ef8df1bbd42a003)) +* return column option back ([#1036](https://github.com/uptrace/bun/issues/1036)) ([a3ccbea](https://github.com/uptrace/bun/commit/a3ccbeab39151d3eed6cb245fe15cfb5d71ba557)) +* sql.NullString mistaken as custom struct ([#1019](https://github.com/uptrace/bun/issues/1019)) ([87c77b8](https://github.com/uptrace/bun/commit/87c77b8911f2035b0ee8ea96356a2c7600b5b94d)) +* typos ([#1026](https://github.com/uptrace/bun/issues/1026)) ([760de7d](https://github.com/uptrace/bun/commit/760de7d0fad15dc761475670a4dde056aef9210d)) + + +### Features + +* add transaction isolation level support to pgdriver ([#1034](https://github.com/uptrace/bun/issues/1034)) ([3ef44ce](https://github.com/uptrace/bun/commit/3ef44ce1cdd969a21b76d6c803119cf12c375cb0)) + + +### Performance Improvements + +* refactor SelectQuery.ScanAndCount to optimize performance when there is no limit and offset ([#1035](https://github.com/uptrace/bun/issues/1035)) ([8638613](https://github.com/uptrace/bun/commit/86386135897485bbada6c50ec9a2743626111433)) + + + +## [1.2.4](https://github.com/uptrace/bun/compare/v1.2.3...v1.2.4) (2024-10-26) + + +### Bug Fixes + +* allow Limit() without Order() with MSSQL ([#1009](https://github.com/uptrace/bun/issues/1009)) ([1a46ddc](https://github.com/uptrace/bun/commit/1a46ddc0d3ca0bdc60ca8be5ad1886799d14c8b0)) +* copy bytes in mapModel.Scan ([#1030](https://github.com/uptrace/bun/issues/1030)) ([#1032](https://github.com/uptrace/bun/issues/1032)) ([39fda4e](https://github.com/uptrace/bun/commit/39fda4e3d341e59e4955f751cb354a939e57c1b1)) +* return column option back ([#1036](https://github.com/uptrace/bun/issues/1036)) ([a3ccbea](https://github.com/uptrace/bun/commit/a3ccbeab39151d3eed6cb245fe15cfb5d71ba557)) +* sql.NullString mistaken as custom struct ([#1019](https://github.com/uptrace/bun/issues/1019)) ([87c77b8](https://github.com/uptrace/bun/commit/87c77b8911f2035b0ee8ea96356a2c7600b5b94d)) +* typos ([#1026](https://github.com/uptrace/bun/issues/1026)) ([760de7d](https://github.com/uptrace/bun/commit/760de7d0fad15dc761475670a4dde056aef9210d)) + + +### Features + +* add transaction isolation level support to pgdriver ([#1034](https://github.com/uptrace/bun/issues/1034)) ([3ef44ce](https://github.com/uptrace/bun/commit/3ef44ce1cdd969a21b76d6c803119cf12c375cb0)) + + +### Performance Improvements + +* refactor SelectQuery.ScanAndCount to optimize performance when there is no limit and offset ([#1035](https://github.com/uptrace/bun/issues/1035)) ([8638613](https://github.com/uptrace/bun/commit/86386135897485bbada6c50ec9a2743626111433)) + + + +## [1.2.3](https://github.com/uptrace/bun/compare/v1.2.2...v1.2.3) (2024-08-31) + + + +## [1.2.2](https://github.com/uptrace/bun/compare/v1.2.1...v1.2.2) (2024-08-29) + + +### Bug Fixes + +* gracefully handle empty hstore in pgdialect ([#1010](https://github.com/uptrace/bun/issues/1010)) ([2f73d8a](https://github.com/uptrace/bun/commit/2f73d8a8e16c8718ebfc956036d9c9a01a0888bc)) +* number each unit test ([#974](https://github.com/uptrace/bun/issues/974)) ([b005dc2](https://github.com/uptrace/bun/commit/b005dc2a9034715c6f59dcfc8e76aa3b85df38ab)) + + +### Features + +* add ModelTableExpr to TruncateTableQuery ([#969](https://github.com/uptrace/bun/issues/969)) ([7bc330f](https://github.com/uptrace/bun/commit/7bc330f152cf0d9dc30956478e2731ea5816f012)) + + + ## [1.2.1](https://github.com/uptrace/bun/compare/v1.2.0...v1.2.1) (2024-04-02) @@ -14,7 +81,7 @@ ### Features -* Allow overiding of Warn and Deprecated loggers ([#952](https://github.com/uptrace/bun/issues/952)) ([0e9d737](https://github.com/uptrace/bun/commit/0e9d737e4ca2deb86930237ee32a39cf3f7e8157)) +* Allow overriding of Warn and Deprecated loggers ([#952](https://github.com/uptrace/bun/issues/952)) ([0e9d737](https://github.com/uptrace/bun/commit/0e9d737e4ca2deb86930237ee32a39cf3f7e8157)) * enable SNI ([#953](https://github.com/uptrace/bun/issues/953)) ([4071ffb](https://github.com/uptrace/bun/commit/4071ffb5bcb1b233cda239c92504d8139dcf1d2f)) * **idb:** add NewMerge method to IDB ([#966](https://github.com/uptrace/bun/issues/966)) ([664e2f1](https://github.com/uptrace/bun/commit/664e2f154f1153d2a80cd062a5074f1692edaee7)) @@ -100,7 +167,7 @@ ### Bug Fixes -* add support for inserting values with unicode encoding for mssql dialect ([e98c6c0](https://github.com/uptrace/bun/commit/e98c6c0f033b553bea3bbc783aa56c2eaa17718f)) +* add support for inserting values with Unicode encoding for mssql dialect ([e98c6c0](https://github.com/uptrace/bun/commit/e98c6c0f033b553bea3bbc783aa56c2eaa17718f)) * fix relation tag ([a3eedff](https://github.com/uptrace/bun/commit/a3eedff49700490d4998dcdcdc04f554d8f17166)) @@ -136,7 +203,7 @@ ### Bug Fixes -* addng dialect override for append-bool ([#695](https://github.com/uptrace/bun/issues/695)) ([338f2f0](https://github.com/uptrace/bun/commit/338f2f04105ad89e64530db86aeb387e2ad4789e)) +* adding dialect override for append-bool ([#695](https://github.com/uptrace/bun/issues/695)) ([338f2f0](https://github.com/uptrace/bun/commit/338f2f04105ad89e64530db86aeb387e2ad4789e)) * don't call hooks twice for whereExists ([9057857](https://github.com/uptrace/bun/commit/90578578e717f248e4b6eb114c5b495fd8d4ed41)) * don't lock migrations when running Migrate and Rollback ([69a7354](https://github.com/uptrace/bun/commit/69a7354d987ff2ed5338c9ef5f4ce320724299ab)) * **query:** make WhereDeleted compatible with ForceDelete ([299c3fd](https://github.com/uptrace/bun/commit/299c3fd57866aaecd127a8f219c95332898475db)), closes [#673](https://github.com/uptrace/bun/issues/673) @@ -304,7 +371,7 @@ recommended to upgrade to v1.0.24 before upgrading to v1.1.x. - append slice values ([4a65129](https://github.com/uptrace/bun/commit/4a651294fb0f1e73079553024810c3ead9777311)) -- check for nils when appeding driver.Value +- check for nils when appending driver.Value ([7bb1640](https://github.com/uptrace/bun/commit/7bb1640a00fceca1e1075fe6544b9a4842ab2b26)) - cleanup soft deletes for mssql ([e72e2c5](https://github.com/uptrace/bun/commit/e72e2c5d0a85f3d26c3fa22c7284c2de1dcfda8e)) @@ -323,7 +390,7 @@ recommended to upgrade to v1.0.24 before upgrading to v1.1.x. ### Deprecated -In the comming v1.1.x release, Bun will stop automatically adding `,pk,autoincrement` options on +In the coming v1.1.x release, Bun will stop automatically adding `,pk,autoincrement` options on `ID int64/int32` fields. This version (v1.0.23) only prints a warning when it encounters such fields, but the code will continue working as before. @@ -441,7 +508,7 @@ In v1.1.x, such options as `,nopk` and `,allowzero` will not be necessary and wi ([693f1e1](https://github.com/uptrace/bun/commit/693f1e135999fc31cf83b99a2530a695b20f4e1b)) - add model embedding via embed:prefix\_ ([9a2cedc](https://github.com/uptrace/bun/commit/9a2cedc8b08fa8585d4bfced338bd0a40d736b1d)) -- change the default logoutput to stderr +- change the default log output to stderr ([4bf5773](https://github.com/uptrace/bun/commit/4bf577382f19c64457cbf0d64490401450954654)), closes [#349](https://github.com/uptrace/bun/issues/349) diff --git a/vendor/github.com/uptrace/bun/Makefile b/vendor/github.com/uptrace/bun/Makefile index fc295561c..50a1903e7 100644 --- a/vendor/github.com/uptrace/bun/Makefile +++ b/vendor/github.com/uptrace/bun/Makefile @@ -15,7 +15,7 @@ go_mod_tidy: echo "go mod tidy in $${dir}"; \ (cd "$${dir}" && \ go get -u ./... && \ - go mod tidy -go=1.21); \ + go mod tidy); \ done fmt: diff --git a/vendor/github.com/uptrace/bun/README.md b/vendor/github.com/uptrace/bun/README.md index 07a01aa61..dbe5bc0b4 100644 --- a/vendor/github.com/uptrace/bun/README.md +++ b/vendor/github.com/uptrace/bun/README.md @@ -1,4 +1,4 @@ -# SQL-first Golang ORM for PostgreSQL, MySQL, MSSQL, and SQLite +# SQL-first Golang ORM for PostgreSQL, MySQL, MSSQL, SQLite and Oracle [](https://github.com/uptrace/bun/actions) [](https://pkg.go.dev/github.com/uptrace/bun) @@ -19,6 +19,7 @@ [MySQL](https://bun.uptrace.dev/guide/drivers.html#mysql) (including MariaDB), [MSSQL](https://bun.uptrace.dev/guide/drivers.html#mssql), [SQLite](https://bun.uptrace.dev/guide/drivers.html#sqlite). + [Oracle](https://bun.uptrace.dev/guide/drivers.html#oracle). - [ORM-like](/example/basic/) experience using good old SQL. Bun supports structs, map, scalars, and slices of map/structs/scalars. - [Bulk inserts](https://bun.uptrace.dev/guide/query-insert.html). diff --git a/vendor/github.com/uptrace/bun/bun.go b/vendor/github.com/uptrace/bun/bun.go index 8f71db8fc..626f0bf4b 100644 --- a/vendor/github.com/uptrace/bun/bun.go +++ b/vendor/github.com/uptrace/bun/bun.go @@ -22,6 +22,10 @@ type ( AfterScanRowHook = schema.AfterScanRowHook ) +func SafeQuery(query string, args ...interface{}) schema.QueryWithArgs { + return schema.SafeQuery(query, args) +} + type BeforeSelectHook interface { BeforeSelect(ctx context.Context, query *SelectQuery) error } @@ -70,7 +74,7 @@ type AfterDropTableHook interface { AfterDropTable(ctx context.Context, query *DropTableQuery) error } -// SetLogger overwriters default Bun logger. +// SetLogger overwrites default Bun logger. func SetLogger(logger internal.Logging) { internal.SetLogger(logger) } diff --git a/vendor/github.com/uptrace/bun/dialect/dialect.go b/vendor/github.com/uptrace/bun/dialect/dialect.go index 03b81fbbc..4dde63c92 100644 --- a/vendor/github.com/uptrace/bun/dialect/dialect.go +++ b/vendor/github.com/uptrace/bun/dialect/dialect.go @@ -12,6 +12,8 @@ func (n Name) String() string { return "mysql" case MSSQL: return "mssql" + case Oracle: + return "oracle" default: return "invalid" } @@ -23,4 +25,5 @@ const ( SQLite MySQL MSSQL + Oracle ) diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go index 7e9491abc..c95fa86e7 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go @@ -2,12 +2,9 @@ package pgdialect import ( "database/sql/driver" - "encoding/hex" "fmt" "reflect" - "strconv" "time" - "unicode/utf8" "github.com/uptrace/bun/dialect" "github.com/uptrace/bun/schema" @@ -32,316 +29,10 @@ var ( sliceTimeType = reflect.TypeOf([]time.Time(nil)) ) -func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte { - switch v := v.(type) { - case int64: - return strconv.AppendInt(b, v, 10) - case float64: - return dialect.AppendFloat64(b, v) - case bool: - return dialect.AppendBool(b, v) - case []byte: - return arrayAppendBytes(b, v) - case string: - return arrayAppendString(b, v) - case time.Time: - return fmter.Dialect().AppendTime(b, v) - default: - err := fmt.Errorf("pgdialect: can't append %T", v) - return dialect.AppendError(b, err) - } -} - -func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - return arrayAppendString(b, v.String()) -} - -func arrayAppendBytesValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - return arrayAppendBytes(b, v.Bytes()) -} - -func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - iface, err := v.Interface().(driver.Valuer).Value() - if err != nil { - return dialect.AppendError(b, err) - } - return arrayAppend(fmter, b, iface) -} - -//------------------------------------------------------------------------------ - -func (d *Dialect) arrayAppender(typ reflect.Type) schema.AppenderFunc { - kind := typ.Kind() - - switch kind { - case reflect.Ptr: - if fn := d.arrayAppender(typ.Elem()); fn != nil { - return schema.PtrAppender(fn) - } - case reflect.Slice, reflect.Array: - // ok: - default: - return nil - } - - elemType := typ.Elem() - - if kind == reflect.Slice { - switch elemType { - case stringType: - return appendStringSliceValue - case intType: - return appendIntSliceValue - case int64Type: - return appendInt64SliceValue - case float64Type: - return appendFloat64SliceValue - case timeType: - return appendTimeSliceValue - } - } - - appendElem := d.arrayElemAppender(elemType) - if appendElem == nil { - panic(fmt.Errorf("pgdialect: %s is not supported", typ)) - } - - return func(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - kind := v.Kind() - switch kind { - case reflect.Ptr, reflect.Slice: - if v.IsNil() { - return dialect.AppendNull(b) - } - } - - if kind == reflect.Ptr { - v = v.Elem() - } - - b = append(b, '\'') - - b = append(b, '{') - ln := v.Len() - for i := 0; i < ln; i++ { - elem := v.Index(i) - b = appendElem(fmter, b, elem) - b = append(b, ',') - } - if v.Len() > 0 { - b[len(b)-1] = '}' // Replace trailing comma. - } else { - b = append(b, '}') - } - - b = append(b, '\'') - - return b - } -} - -func (d *Dialect) arrayElemAppender(typ reflect.Type) schema.AppenderFunc { - if typ.Implements(driverValuerType) { - return arrayAppendDriverValue - } - switch typ.Kind() { - case reflect.String: - return arrayAppendStringValue - case reflect.Slice: - if typ.Elem().Kind() == reflect.Uint8 { - return arrayAppendBytesValue - } - } - return schema.Appender(d, typ) -} - -func appendStringSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - ss := v.Convert(sliceStringType).Interface().([]string) - return appendStringSlice(b, ss) +func appendTime(buf []byte, tm time.Time) []byte { + return tm.UTC().AppendFormat(buf, "2006-01-02 15:04:05.999999-07:00") } -func appendStringSlice(b []byte, ss []string) []byte { - if ss == nil { - return dialect.AppendNull(b) - } - - b = append(b, '\'') - - b = append(b, '{') - for _, s := range ss { - b = arrayAppendString(b, s) - b = append(b, ',') - } - if len(ss) > 0 { - b[len(b)-1] = '}' // Replace trailing comma. - } else { - b = append(b, '}') - } - - b = append(b, '\'') - - return b -} - -func appendIntSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - ints := v.Convert(sliceIntType).Interface().([]int) - return appendIntSlice(b, ints) -} - -func appendIntSlice(b []byte, ints []int) []byte { - if ints == nil { - return dialect.AppendNull(b) - } - - b = append(b, '\'') - - b = append(b, '{') - for _, n := range ints { - b = strconv.AppendInt(b, int64(n), 10) - b = append(b, ',') - } - if len(ints) > 0 { - b[len(b)-1] = '}' // Replace trailing comma. - } else { - b = append(b, '}') - } - - b = append(b, '\'') - - return b -} - -func appendInt64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - ints := v.Convert(sliceInt64Type).Interface().([]int64) - return appendInt64Slice(b, ints) -} - -func appendInt64Slice(b []byte, ints []int64) []byte { - if ints == nil { - return dialect.AppendNull(b) - } - - b = append(b, '\'') - - b = append(b, '{') - for _, n := range ints { - b = strconv.AppendInt(b, n, 10) - b = append(b, ',') - } - if len(ints) > 0 { - b[len(b)-1] = '}' // Replace trailing comma. - } else { - b = append(b, '}') - } - - b = append(b, '\'') - - return b -} - -func appendFloat64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - floats := v.Convert(sliceFloat64Type).Interface().([]float64) - return appendFloat64Slice(b, floats) -} - -func appendFloat64Slice(b []byte, floats []float64) []byte { - if floats == nil { - return dialect.AppendNull(b) - } - - b = append(b, '\'') - - b = append(b, '{') - for _, n := range floats { - b = dialect.AppendFloat64(b, n) - b = append(b, ',') - } - if len(floats) > 0 { - b[len(b)-1] = '}' // Replace trailing comma. - } else { - b = append(b, '}') - } - - b = append(b, '\'') - - return b -} - -//------------------------------------------------------------------------------ - -func arrayAppendBytes(b []byte, bs []byte) []byte { - if bs == nil { - return dialect.AppendNull(b) - } - - b = append(b, `"\\x`...) - - s := len(b) - b = append(b, make([]byte, hex.EncodedLen(len(bs)))...) - hex.Encode(b[s:], bs) - - b = append(b, '"') - - return b -} - -func arrayAppendString(b []byte, s string) []byte { - b = append(b, '"') - for _, r := range s { - switch r { - case 0: - // ignore - case '\'': - b = append(b, "''"...) - case '"': - b = append(b, '\\', '"') - case '\\': - b = append(b, '\\', '\\') - default: - if r < utf8.RuneSelf { - b = append(b, byte(r)) - break - } - l := len(b) - if cap(b)-l < utf8.UTFMax { - b = append(b, make([]byte, utf8.UTFMax)...) - } - n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) - b = b[:l+n] - } - } - b = append(b, '"') - return b -} - -func appendTimeSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { - ts := v.Convert(sliceTimeType).Interface().([]time.Time) - return appendTimeSlice(fmter, b, ts) -} - -func appendTimeSlice(fmter schema.Formatter, b []byte, ts []time.Time) []byte { - if ts == nil { - return dialect.AppendNull(b) - } - b = append(b, '\'') - b = append(b, '{') - for _, t := range ts { - b = append(b, '"') - b = t.UTC().AppendFormat(b, "2006-01-02 15:04:05.999999-07:00") - b = append(b, '"') - b = append(b, ',') - } - if len(ts) > 0 { - b[len(b)-1] = '}' // Replace trailing comma. - } else { - b = append(b, '}') - } - b = append(b, '\'') - return b -} - -//------------------------------------------------------------------------------ - var mapStringStringType = reflect.TypeOf(map[string]string(nil)) func (d *Dialect) hstoreAppender(typ reflect.Type) schema.AppenderFunc { diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go index 281cff733..46b55659b 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go @@ -2,9 +2,16 @@ package pgdialect import ( "database/sql" + "database/sql/driver" + "encoding/hex" "fmt" "reflect" + "strconv" + "time" + "unicode/utf8" + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/internal" "github.com/uptrace/bun/schema" ) @@ -20,7 +27,7 @@ type ArrayValue struct { // // For struct fields you can use array tag: // -// Emails []string `bun:",array"` +// Emails []string `bun:",array"` func Array(vi interface{}) *ArrayValue { v := reflect.ValueOf(vi) if !v.IsValid() { @@ -63,3 +70,576 @@ func (a *ArrayValue) Value() interface{} { } return nil } + +//------------------------------------------------------------------------------ + +func (d *Dialect) arrayAppender(typ reflect.Type) schema.AppenderFunc { + kind := typ.Kind() + + switch kind { + case reflect.Ptr: + if fn := d.arrayAppender(typ.Elem()); fn != nil { + return schema.PtrAppender(fn) + } + case reflect.Slice, reflect.Array: + // continue below + default: + return nil + } + + elemType := typ.Elem() + + if kind == reflect.Slice { + switch elemType { + case stringType: + return appendStringSliceValue + case intType: + return appendIntSliceValue + case int64Type: + return appendInt64SliceValue + case float64Type: + return appendFloat64SliceValue + case timeType: + return appendTimeSliceValue + } + } + + appendElem := d.arrayElemAppender(elemType) + if appendElem == nil { + panic(fmt.Errorf("pgdialect: %s is not supported", typ)) + } + + return func(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + kind := v.Kind() + switch kind { + case reflect.Ptr, reflect.Slice: + if v.IsNil() { + return dialect.AppendNull(b) + } + } + + if kind == reflect.Ptr { + v = v.Elem() + } + + b = append(b, "'{"...) + + ln := v.Len() + for i := 0; i < ln; i++ { + elem := v.Index(i) + if i > 0 { + b = append(b, ',') + } + b = appendElem(fmter, b, elem) + } + + b = append(b, "}'"...) + + return b + } +} + +func (d *Dialect) arrayElemAppender(typ reflect.Type) schema.AppenderFunc { + if typ.Implements(driverValuerType) { + return arrayAppendDriverValue + } + switch typ.Kind() { + case reflect.String: + return arrayAppendStringValue + case reflect.Slice: + if typ.Elem().Kind() == reflect.Uint8 { + return arrayAppendBytesValue + } + } + return schema.Appender(d, typ) +} + +func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte { + switch v := v.(type) { + case int64: + return strconv.AppendInt(b, v, 10) + case float64: + return dialect.AppendFloat64(b, v) + case bool: + return dialect.AppendBool(b, v) + case []byte: + return arrayAppendBytes(b, v) + case string: + return arrayAppendString(b, v) + case time.Time: + return fmter.Dialect().AppendTime(b, v) + default: + err := fmt.Errorf("pgdialect: can't append %T", v) + return dialect.AppendError(b, err) + } +} + +func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return arrayAppendString(b, v.String()) +} + +func arrayAppendBytesValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return arrayAppendBytes(b, v.Bytes()) +} + +func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + iface, err := v.Interface().(driver.Valuer).Value() + if err != nil { + return dialect.AppendError(b, err) + } + return arrayAppend(fmter, b, iface) +} + +func appendStringSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ss := v.Convert(sliceStringType).Interface().([]string) + return appendStringSlice(b, ss) +} + +func appendStringSlice(b []byte, ss []string) []byte { + if ss == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, s := range ss { + b = arrayAppendString(b, s) + b = append(b, ',') + } + if len(ss) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendIntSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ints := v.Convert(sliceIntType).Interface().([]int) + return appendIntSlice(b, ints) +} + +func appendIntSlice(b []byte, ints []int) []byte { + if ints == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range ints { + b = strconv.AppendInt(b, int64(n), 10) + b = append(b, ',') + } + if len(ints) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendInt64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ints := v.Convert(sliceInt64Type).Interface().([]int64) + return appendInt64Slice(b, ints) +} + +func appendInt64Slice(b []byte, ints []int64) []byte { + if ints == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range ints { + b = strconv.AppendInt(b, n, 10) + b = append(b, ',') + } + if len(ints) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendFloat64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + floats := v.Convert(sliceFloat64Type).Interface().([]float64) + return appendFloat64Slice(b, floats) +} + +func appendFloat64Slice(b []byte, floats []float64) []byte { + if floats == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range floats { + b = dialect.AppendFloat64(b, n) + b = append(b, ',') + } + if len(floats) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendTimeSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ts := v.Convert(sliceTimeType).Interface().([]time.Time) + return appendTimeSlice(fmter, b, ts) +} + +func appendTimeSlice(fmter schema.Formatter, b []byte, ts []time.Time) []byte { + if ts == nil { + return dialect.AppendNull(b) + } + b = append(b, '\'') + b = append(b, '{') + for _, t := range ts { + b = append(b, '"') + b = appendTime(b, t) + b = append(b, '"') + b = append(b, ',') + } + if len(ts) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + b = append(b, '\'') + return b +} + +//------------------------------------------------------------------------------ + +func arrayScanner(typ reflect.Type) schema.ScannerFunc { + kind := typ.Kind() + + switch kind { + case reflect.Ptr: + if fn := arrayScanner(typ.Elem()); fn != nil { + return schema.PtrScanner(fn) + } + case reflect.Slice, reflect.Array: + // ok: + default: + return nil + } + + elemType := typ.Elem() + + if kind == reflect.Slice { + switch elemType { + case stringType: + return scanStringSliceValue + case intType: + return scanIntSliceValue + case int64Type: + return scanInt64SliceValue + case float64Type: + return scanFloat64SliceValue + } + } + + scanElem := schema.Scanner(elemType) + return func(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + kind := dest.Kind() + + if src == nil { + if kind != reflect.Slice || !dest.IsNil() { + dest.Set(reflect.Zero(dest.Type())) + } + return nil + } + + if kind == reflect.Slice { + if dest.IsNil() { + dest.Set(reflect.MakeSlice(dest.Type(), 0, 0)) + } else if dest.Len() > 0 { + dest.Set(dest.Slice(0, 0)) + } + } + + b, err := toBytes(src) + if err != nil { + return err + } + + p := newArrayParser(b) + nextValue := internal.MakeSliceNextElemFunc(dest) + for p.Next() { + elem := p.Elem() + elemValue := nextValue() + if err := scanElem(elemValue, elem); err != nil { + return fmt.Errorf("scanElem failed: %w", err) + } + } + return p.Err() + } +} + +func scanStringSliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := decodeStringSlice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func decodeStringSlice(src interface{}) ([]string, error) { + if src == nil { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]string, 0) + + p := newArrayParser(b) + for p.Next() { + elem := p.Elem() + slice = append(slice, string(elem)) + } + if err := p.Err(); err != nil { + return nil, err + } + return slice, nil +} + +func scanIntSliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := decodeIntSlice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func decodeIntSlice(src interface{}) ([]int, error) { + if src == nil { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]int, 0) + + p := newArrayParser(b) + for p.Next() { + elem := p.Elem() + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := strconv.Atoi(bytesToString(elem)) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + if err := p.Err(); err != nil { + return nil, err + } + return slice, nil +} + +func scanInt64SliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := decodeInt64Slice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func decodeInt64Slice(src interface{}) ([]int64, error) { + if src == nil { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]int64, 0) + + p := newArrayParser(b) + for p.Next() { + elem := p.Elem() + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := strconv.ParseInt(bytesToString(elem), 10, 64) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + if err := p.Err(); err != nil { + return nil, err + } + return slice, nil +} + +func scanFloat64SliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := scanFloat64Slice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func scanFloat64Slice(src interface{}) ([]float64, error) { + if src == -1 { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]float64, 0) + + p := newArrayParser(b) + for p.Next() { + elem := p.Elem() + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := strconv.ParseFloat(bytesToString(elem), 64) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + if err := p.Err(); err != nil { + return nil, err + } + return slice, nil +} + +func toBytes(src interface{}) ([]byte, error) { + switch src := src.(type) { + case string: + return stringToBytes(src), nil + case []byte: + return src, nil + default: + return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src) + } +} + +//------------------------------------------------------------------------------ + +func arrayAppendBytes(b []byte, bs []byte) []byte { + if bs == nil { + return dialect.AppendNull(b) + } + + b = append(b, `"\\x`...) + + s := len(b) + b = append(b, make([]byte, hex.EncodedLen(len(bs)))...) + hex.Encode(b[s:], bs) + + b = append(b, '"') + + return b +} + +func arrayAppendString(b []byte, s string) []byte { + b = append(b, '"') + for _, r := range s { + switch r { + case 0: + // ignore + case '\'': + b = append(b, "''"...) + case '"': + b = append(b, '\\', '"') + case '\\': + b = append(b, '\\', '\\') + default: + if r < utf8.RuneSelf { + b = append(b, byte(r)) + break + } + l := len(b) + if cap(b)-l < utf8.UTFMax { + b = append(b, make([]byte, utf8.UTFMax)...) + } + n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) + b = b[:l+n] + } + } + b = append(b, '"') + return b +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/array_parser.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_parser.go index a8358337e..462f8d91d 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/array_parser.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_parser.go @@ -2,132 +2,92 @@ package pgdialect import ( "bytes" - "encoding/hex" "fmt" "io" ) type arrayParser struct { - *streamParser - err error + p pgparser + + elem []byte + err error } func newArrayParser(b []byte) *arrayParser { - p := &arrayParser{ - streamParser: newStreamParser(b, 1), - } + p := new(arrayParser) + if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' { - p.err = fmt.Errorf("bun: can't parse array: %q", b) + p.err = fmt.Errorf("pgdialect: can't parse array: %q", b) + return p } + + p.p.Reset(b[1 : len(b)-1]) return p } -func (p *arrayParser) NextElem() ([]byte, error) { +func (p *arrayParser) Next() bool { if p.err != nil { - return nil, p.err + return false } + p.err = p.readNext() + return p.err == nil +} + +func (p *arrayParser) Err() error { + if p.err != io.EOF { + return p.err + } + return nil +} - c, err := p.readByte() - if err != nil { - return nil, err +func (p *arrayParser) Elem() []byte { + return p.elem +} + +func (p *arrayParser) readNext() error { + ch := p.p.Read() + if ch == 0 { + return io.EOF } - switch c { + switch ch { case '}': - return nil, io.EOF + return io.EOF case '"': - b, err := p.readSubstring() + b, err := p.p.ReadSubstring(ch) if err != nil { - return nil, err - } - - if p.peek() == ',' { - p.skipNext() + return err } - return b, nil - default: - b := p.readSimple() - if bytes.Equal(b, []byte("NULL")) { - b = nil + if p.p.Peek() == ',' { + p.p.Advance() } - if p.peek() == ',' { - p.skipNext() + p.elem = b + return nil + case '[', '(': + rng, err := p.p.ReadRange(ch) + if err != nil { + return err } - return b, nil - } -} - -func (p *arrayParser) readSimple() []byte { - p.unreadByte() - - if i := bytes.IndexByte(p.b[p.i:], ','); i >= 0 { - b := p.b[p.i : p.i+i] - p.i += i - return b - } - - b := p.b[p.i : len(p.b)-1] - p.i = len(p.b) - 1 - return b -} - -func (p *arrayParser) readSubstring() ([]byte, error) { - c, err := p.readByte() - if err != nil { - return nil, err - } - - p.buf = p.buf[:0] - for { - if c == '"' { - break + if p.p.Peek() == ',' { + p.p.Advance() } - next, err := p.readByte() - if err != nil { - return nil, err + p.elem = rng + return nil + default: + lit := p.p.ReadLiteral(ch) + if bytes.Equal(lit, []byte("NULL")) { + lit = nil } - if c == '\\' { - switch next { - case '\\', '"': - p.buf = append(p.buf, next) - - c, err = p.readByte() - if err != nil { - return nil, err - } - default: - p.buf = append(p.buf, '\\') - c = next - } - continue + if p.p.Peek() == ',' { + p.p.Advance() } - if c == '\'' && next == '\'' { - p.buf = append(p.buf, next) - c, err = p.readByte() - if err != nil { - return nil, err - } - continue - } - - p.buf = append(p.buf, c) - c = next - } - if bytes.HasPrefix(p.buf, []byte("\\x")) && len(p.buf)%2 == 0 { - data := p.buf[2:] - buf := make([]byte, hex.DecodedLen(len(data))) - n, err := hex.Decode(buf, data) - if err != nil { - return nil, err - } - return buf[:n], nil + p.elem = lit + return nil } - - return p.buf, nil } diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/array_scan.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_scan.go index a8ff29715..6b8abda3d 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/array_scan.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_scan.go @@ -1,302 +1 @@ package pgdialect - -import ( - "fmt" - "io" - "reflect" - "strconv" - - "github.com/uptrace/bun/internal" - "github.com/uptrace/bun/schema" -) - -func arrayScanner(typ reflect.Type) schema.ScannerFunc { - kind := typ.Kind() - - switch kind { - case reflect.Ptr: - if fn := arrayScanner(typ.Elem()); fn != nil { - return schema.PtrScanner(fn) - } - case reflect.Slice, reflect.Array: - // ok: - default: - return nil - } - - elemType := typ.Elem() - - if kind == reflect.Slice { - switch elemType { - case stringType: - return scanStringSliceValue - case intType: - return scanIntSliceValue - case int64Type: - return scanInt64SliceValue - case float64Type: - return scanFloat64SliceValue - } - } - - scanElem := schema.Scanner(elemType) - return func(dest reflect.Value, src interface{}) error { - dest = reflect.Indirect(dest) - if !dest.CanSet() { - return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) - } - - kind := dest.Kind() - - if src == nil { - if kind != reflect.Slice || !dest.IsNil() { - dest.Set(reflect.Zero(dest.Type())) - } - return nil - } - - if kind == reflect.Slice { - if dest.IsNil() { - dest.Set(reflect.MakeSlice(dest.Type(), 0, 0)) - } else if dest.Len() > 0 { - dest.Set(dest.Slice(0, 0)) - } - } - - b, err := toBytes(src) - if err != nil { - return err - } - - p := newArrayParser(b) - nextValue := internal.MakeSliceNextElemFunc(dest) - for { - elem, err := p.NextElem() - if err != nil { - if err == io.EOF { - break - } - return err - } - - elemValue := nextValue() - if err := scanElem(elemValue, elem); err != nil { - return err - } - } - - return nil - } -} - -func scanStringSliceValue(dest reflect.Value, src interface{}) error { - dest = reflect.Indirect(dest) - if !dest.CanSet() { - return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) - } - - slice, err := decodeStringSlice(src) - if err != nil { - return err - } - - dest.Set(reflect.ValueOf(slice)) - return nil -} - -func decodeStringSlice(src interface{}) ([]string, error) { - if src == nil { - return nil, nil - } - - b, err := toBytes(src) - if err != nil { - return nil, err - } - - slice := make([]string, 0) - - p := newArrayParser(b) - for { - elem, err := p.NextElem() - if err != nil { - if err == io.EOF { - break - } - return nil, err - } - slice = append(slice, string(elem)) - } - - return slice, nil -} - -func scanIntSliceValue(dest reflect.Value, src interface{}) error { - dest = reflect.Indirect(dest) - if !dest.CanSet() { - return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) - } - - slice, err := decodeIntSlice(src) - if err != nil { - return err - } - - dest.Set(reflect.ValueOf(slice)) - return nil -} - -func decodeIntSlice(src interface{}) ([]int, error) { - if src == nil { - return nil, nil - } - - b, err := toBytes(src) - if err != nil { - return nil, err - } - - slice := make([]int, 0) - - p := newArrayParser(b) - for { - elem, err := p.NextElem() - if err != nil { - if err == io.EOF { - break - } - return nil, err - } - - if elem == nil { - slice = append(slice, 0) - continue - } - - n, err := strconv.Atoi(bytesToString(elem)) - if err != nil { - return nil, err - } - - slice = append(slice, n) - } - - return slice, nil -} - -func scanInt64SliceValue(dest reflect.Value, src interface{}) error { - dest = reflect.Indirect(dest) - if !dest.CanSet() { - return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) - } - - slice, err := decodeInt64Slice(src) - if err != nil { - return err - } - - dest.Set(reflect.ValueOf(slice)) - return nil -} - -func decodeInt64Slice(src interface{}) ([]int64, error) { - if src == nil { - return nil, nil - } - - b, err := toBytes(src) - if err != nil { - return nil, err - } - - slice := make([]int64, 0) - - p := newArrayParser(b) - for { - elem, err := p.NextElem() - if err != nil { - if err == io.EOF { - break - } - return nil, err - } - - if elem == nil { - slice = append(slice, 0) - continue - } - - n, err := strconv.ParseInt(bytesToString(elem), 10, 64) - if err != nil { - return nil, err - } - - slice = append(slice, n) - } - - return slice, nil -} - -func scanFloat64SliceValue(dest reflect.Value, src interface{}) error { - dest = reflect.Indirect(dest) - if !dest.CanSet() { - return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) - } - - slice, err := scanFloat64Slice(src) - if err != nil { - return err - } - - dest.Set(reflect.ValueOf(slice)) - return nil -} - -func scanFloat64Slice(src interface{}) ([]float64, error) { - if src == -1 { - return nil, nil - } - - b, err := toBytes(src) - if err != nil { - return nil, err - } - - slice := make([]float64, 0) - - p := newArrayParser(b) - for { - elem, err := p.NextElem() - if err != nil { - if err == io.EOF { - break - } - return nil, err - } - - if elem == nil { - slice = append(slice, 0) - continue - } - - n, err := strconv.ParseFloat(bytesToString(elem), 64) - if err != nil { - return nil, err - } - - slice = append(slice, n) - } - - return slice, nil -} - -func toBytes(src interface{}) ([]byte, error) { - switch src := src.(type) { - case string: - return stringToBytes(src), nil - case []byte: - return src, nil - default: - return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src) - } -} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go index f100e682c..358971f61 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go @@ -89,9 +89,17 @@ func (d *Dialect) onField(field *schema.Field) { if field.Tag.HasOption("array") || strings.HasSuffix(field.UserSQLType, "[]") { field.Append = d.arrayAppender(field.StructField.Type) field.Scan = arrayScanner(field.StructField.Type) + return } - if field.DiscoveredSQLType == sqltype.HSTORE { + if field.Tag.HasOption("multirange") { + field.Append = d.arrayAppender(field.StructField.Type) + field.Scan = arrayScanner(field.StructField.Type) + return + } + + switch field.DiscoveredSQLType { + case sqltype.HSTORE: field.Append = d.hstoreAppender(field.StructField.Type) field.Scan = hstoreScanner(field.StructField.Type) } diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/hstore_parser.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/hstore_parser.go index 7a18b50b1..fec401786 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/hstore_parser.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/hstore_parser.go @@ -3,140 +3,98 @@ package pgdialect import ( "bytes" "fmt" + "io" ) type hstoreParser struct { - *streamParser - err error + p pgparser + + key string + value string + err error } func newHStoreParser(b []byte) *hstoreParser { - p := &hstoreParser{ - streamParser: newStreamParser(b, 0), - } - if len(b) < 6 || b[0] != '"' { - p.err = fmt.Errorf("bun: can't parse hstore: %q", b) + p := new(hstoreParser) + if len(b) != 0 && (len(b) < 6 || b[0] != '"') { + p.err = fmt.Errorf("pgdialect: can't parse hstore: %q", b) + return p } + p.p.Reset(b) return p } -func (p *hstoreParser) NextKey() (string, error) { +func (p *hstoreParser) Next() bool { if p.err != nil { - return "", p.err + return false } + p.err = p.readNext() + return p.err == nil +} - err := p.skipByte('"') - if err != nil { - return "", err +func (p *hstoreParser) Err() error { + if p.err != io.EOF { + return p.err } + return nil +} - key, err := p.readSubstring() - if err != nil { - return "", err - } +func (p *hstoreParser) Key() string { + return p.key +} - const separator = "=>" +func (p *hstoreParser) Value() string { + return p.value +} - for i := range separator { - err = p.skipByte(separator[i]) - if err != nil { - return "", err - } +func (p *hstoreParser) readNext() error { + if !p.p.Valid() { + return io.EOF } - return string(key), nil -} + if err := p.p.Skip('"'); err != nil { + return err + } -func (p *hstoreParser) NextValue() (string, error) { - if p.err != nil { - return "", p.err + key, err := p.p.ReadUnescapedSubstring('"') + if err != nil { + return err + } + p.key = string(key) + + if err := p.p.SkipPrefix([]byte("=>")); err != nil { + return err } - c, err := p.readByte() + ch, err := p.p.ReadByte() if err != nil { - return "", err + return err } - switch c { + switch ch { case '"': - value, err := p.readSubstring() + value, err := p.p.ReadUnescapedSubstring(ch) if err != nil { - return "", err - } - - if p.peek() == ',' { - p.skipNext() - } - - if p.peek() == ' ' { - p.skipNext() + return err } - - return string(value), nil + p.skipComma() + p.value = string(value) + return nil default: - value := p.readSimple() + value := p.p.ReadLiteral(ch) if bytes.Equal(value, []byte("NULL")) { - value = nil - } - - if p.peek() == ',' { - p.skipNext() + p.value = "" } - - return string(value), nil - } -} - -func (p *hstoreParser) readSimple() []byte { - p.unreadByte() - - if i := bytes.IndexByte(p.b[p.i:], ','); i >= 0 { - b := p.b[p.i : p.i+i] - p.i += i - return b + p.skipComma() + return nil } - - b := p.b[p.i:len(p.b)] - p.i = len(p.b) - return b } -func (p *hstoreParser) readSubstring() ([]byte, error) { - c, err := p.readByte() - if err != nil { - return nil, err +func (p *hstoreParser) skipComma() { + if p.p.Peek() == ',' { + p.p.Advance() } - - p.buf = p.buf[:0] - for { - if c == '"' { - break - } - - next, err := p.readByte() - if err != nil { - return nil, err - } - - if c == '\\' { - switch next { - case '\\', '"': - p.buf = append(p.buf, next) - - c, err = p.readByte() - if err != nil { - return nil, err - } - default: - p.buf = append(p.buf, '\\') - c = next - } - continue - } - - p.buf = append(p.buf, c) - c = next + if p.p.Peek() == ' ' { + p.p.Advance() } - - return p.buf, nil } diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/hstore_scan.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/hstore_scan.go index b10b06b8d..62ab89a3a 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/hstore_scan.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/hstore_scan.go @@ -2,7 +2,6 @@ package pgdialect import ( "fmt" - "io" "reflect" "github.com/uptrace/bun/schema" @@ -58,25 +57,11 @@ func decodeMapStringString(src interface{}) (map[string]string, error) { m := make(map[string]string) p := newHStoreParser(b) - for { - key, err := p.NextKey() - if err != nil { - if err == io.EOF { - break - } - return nil, err - } - - value, err := p.NextValue() - if err != nil { - if err == io.EOF { - break - } - return nil, err - } - - m[key] = value + for p.Next() { + m[p.Key()] = p.Value() + } + if err := p.Err(); err != nil { + return nil, err } - return m, nil } diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/range.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/range.go new file mode 100644 index 000000000..b942a068e --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/range.go @@ -0,0 +1,240 @@ +package pgdialect + +import ( + "bytes" + "database/sql" + "encoding/hex" + "fmt" + "io" + "time" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/internal/parser" + "github.com/uptrace/bun/schema" +) + +type MultiRange[T any] []Range[T] + +type Range[T any] struct { + Lower, Upper T + LowerBound, UpperBound RangeBound +} + +type RangeBound byte + +const ( + RangeBoundInclusiveLeft RangeBound = '[' + RangeBoundInclusiveRight RangeBound = ']' + RangeBoundExclusiveLeft RangeBound = '(' + RangeBoundExclusiveRight RangeBound = ')' +) + +func NewRange[T any](lower, upper T) Range[T] { + return Range[T]{ + Lower: lower, + Upper: upper, + LowerBound: RangeBoundInclusiveLeft, + UpperBound: RangeBoundExclusiveRight, + } +} + +var _ sql.Scanner = (*Range[any])(nil) + +func (r *Range[T]) Scan(anySrc any) (err error) { + src := anySrc.([]byte) + + if len(src) == 0 { + return io.ErrUnexpectedEOF + } + r.LowerBound = RangeBound(src[0]) + src = src[1:] + + src, err = scanElem(&r.Lower, src) + if err != nil { + return err + } + + if len(src) == 0 { + return io.ErrUnexpectedEOF + } + if ch := src[0]; ch != ',' { + return fmt.Errorf("got %q, wanted %q", ch, ',') + } + src = src[1:] + + src, err = scanElem(&r.Upper, src) + if err != nil { + return err + } + + if len(src) == 0 { + return io.ErrUnexpectedEOF + } + r.UpperBound = RangeBound(src[0]) + src = src[1:] + + if len(src) > 0 { + return fmt.Errorf("unread data: %q", src) + } + return nil +} + +var _ schema.QueryAppender = (*Range[any])(nil) + +func (r *Range[T]) AppendQuery(fmt schema.Formatter, buf []byte) ([]byte, error) { + buf = append(buf, byte(r.LowerBound)) + buf = appendElem(buf, r.Lower) + buf = append(buf, ',') + buf = appendElem(buf, r.Upper) + buf = append(buf, byte(r.UpperBound)) + return buf, nil +} + +func appendElem(buf []byte, val any) []byte { + switch val := val.(type) { + case time.Time: + buf = append(buf, '"') + buf = appendTime(buf, val) + buf = append(buf, '"') + return buf + default: + panic(fmt.Errorf("unsupported range type: %T", val)) + } +} + +func scanElem(ptr any, src []byte) ([]byte, error) { + switch ptr := ptr.(type) { + case *time.Time: + src, str, err := readStringLiteral(src) + if err != nil { + return nil, err + } + + tm, err := internal.ParseTime(internal.String(str)) + if err != nil { + return nil, err + } + *ptr = tm + + return src, nil + default: + panic(fmt.Errorf("unsupported range type: %T", ptr)) + } +} + +func readStringLiteral(src []byte) ([]byte, []byte, error) { + p := newParser(src) + + if err := p.Skip('"'); err != nil { + return nil, nil, err + } + + str, err := p.ReadSubstring('"') + if err != nil { + return nil, nil, err + } + + src = p.Remaining() + return src, str, nil +} + +//------------------------------------------------------------------------------ + +type pgparser struct { + parser.Parser + buf []byte +} + +func newParser(b []byte) *pgparser { + p := new(pgparser) + p.Reset(b) + return p +} + +func (p *pgparser) ReadLiteral(ch byte) []byte { + p.Unread() + lit, _ := p.ReadSep(',') + return lit +} + +func (p *pgparser) ReadUnescapedSubstring(ch byte) ([]byte, error) { + return p.readSubstring(ch, false) +} + +func (p *pgparser) ReadSubstring(ch byte) ([]byte, error) { + return p.readSubstring(ch, true) +} + +func (p *pgparser) readSubstring(ch byte, escaped bool) ([]byte, error) { + ch, err := p.ReadByte() + if err != nil { + return nil, err + } + + p.buf = p.buf[:0] + for { + if ch == '"' { + break + } + + next, err := p.ReadByte() + if err != nil { + return nil, err + } + + if ch == '\\' { + switch next { + case '\\', '"': + p.buf = append(p.buf, next) + + ch, err = p.ReadByte() + if err != nil { + return nil, err + } + default: + p.buf = append(p.buf, '\\') + ch = next + } + continue + } + + if escaped && ch == '\'' && next == '\'' { + p.buf = append(p.buf, next) + ch, err = p.ReadByte() + if err != nil { + return nil, err + } + continue + } + + p.buf = append(p.buf, ch) + ch = next + } + + if bytes.HasPrefix(p.buf, []byte("\\x")) && len(p.buf)%2 == 0 { + data := p.buf[2:] + buf := make([]byte, hex.DecodedLen(len(data))) + n, err := hex.Decode(buf, data) + if err != nil { + return nil, err + } + return buf[:n], nil + } + + return p.buf, nil +} + +func (p *pgparser) ReadRange(ch byte) ([]byte, error) { + p.buf = p.buf[:0] + p.buf = append(p.buf, ch) + + for p.Valid() { + ch = p.Read() + p.buf = append(p.buf, ch) + if ch == ']' || ch == ')' { + break + } + } + + return p.buf, nil +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go index dadea5c1c..fad84209d 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go @@ -1,6 +1,7 @@ package pgdialect import ( + "database/sql" "encoding/json" "net" "reflect" @@ -27,14 +28,6 @@ const ( pgTypeSerial = "SERIAL" // 4 byte autoincrementing integer pgTypeBigSerial = "BIGSERIAL" // 8 byte autoincrementing integer - // Character Types - pgTypeChar = "CHAR" // fixed length string (blank padded) - pgTypeText = "TEXT" // variable length string without limit - - // JSON Types - pgTypeJSON = "JSON" // text representation of json data - pgTypeJSONB = "JSONB" // binary representation of json data - // Binary Data Types pgTypeBytea = "BYTEA" // binary string ) @@ -43,6 +36,7 @@ var ( ipType = reflect.TypeOf((*net.IP)(nil)).Elem() ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() + nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem() ) func (d *Dialect) DefaultVarcharLen() int { @@ -78,12 +72,14 @@ func fieldSQLType(field *schema.Field) string { func sqlType(typ reflect.Type) string { switch typ { + case nullStringType: // typ.Kind() == reflect.Struct, test for exact match + return sqltype.VarChar case ipType: return pgTypeInet case ipNetType: return pgTypeCidr case jsonRawMessageType: - return pgTypeJSONB + return sqltype.JSONB } sqlType := schema.DiscoverSQLType(typ) @@ -93,16 +89,16 @@ func sqlType(typ reflect.Type) string { } switch typ.Kind() { - case reflect.Map, reflect.Struct: + case reflect.Map, reflect.Struct: // except typ == nullStringType, see above if sqlType == sqltype.VarChar { - return pgTypeJSONB + return sqltype.JSONB } return sqlType case reflect.Array, reflect.Slice: if typ.Elem().Kind() == reflect.Uint8 { return pgTypeBytea } - return pgTypeJSONB + return sqltype.JSONB } return sqlType diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/stream_parser.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/stream_parser.go deleted file mode 100644 index 7b9a15f62..000000000 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/stream_parser.go +++ /dev/null @@ -1,60 +0,0 @@ -package pgdialect - -import ( - "fmt" - "io" -) - -type streamParser struct { - b []byte - i int - - buf []byte -} - -func newStreamParser(b []byte, start int) *streamParser { - return &streamParser{ - b: b, - i: start, - } -} - -func (p *streamParser) valid() bool { - return p.i < len(p.b) -} - -func (p *streamParser) skipByte(skip byte) error { - c, err := p.readByte() - if err != nil { - return err - } - if c == skip { - return nil - } - p.unreadByte() - return fmt.Errorf("got %q, wanted %q", c, skip) -} - -func (p *streamParser) readByte() (byte, error) { - if p.valid() { - c := p.b[p.i] - p.i++ - return c, nil - } - return 0, io.EOF -} - -func (p *streamParser) unreadByte() { - p.i-- -} - -func (p *streamParser) peek() byte { - if p.valid() { - return p.b[p.i] - } - return 0 -} - -func (p *streamParser) skipNext() { - p.i++ -} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/version.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/version.go index b5e5e3cb0..c06043647 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/version.go +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/version.go @@ -2,5 +2,5 @@ package pgdialect // Version is the current release version. func Version() string { - return "1.2.1" + return "1.2.5" } diff --git a/vendor/github.com/uptrace/bun/dialect/sqlitedialect/version.go b/vendor/github.com/uptrace/bun/dialect/sqlitedialect/version.go index 06e5bb1a4..e3cceaa77 100644 --- a/vendor/github.com/uptrace/bun/dialect/sqlitedialect/version.go +++ b/vendor/github.com/uptrace/bun/dialect/sqlitedialect/version.go @@ -2,5 +2,5 @@ package sqlitedialect // Version is the current release version. func Version() string { - return "1.2.1" + return "1.2.5" } diff --git a/vendor/github.com/uptrace/bun/extra/bunotel/README.md b/vendor/github.com/uptrace/bun/extra/bunotel/README.md index 50b3e6c48..1773ecf02 100644 --- a/vendor/github.com/uptrace/bun/extra/bunotel/README.md +++ b/vendor/github.com/uptrace/bun/extra/bunotel/README.md @@ -1,3 +1,3 @@ # OpenTelemetry instrumentation for Bun -See [example](../example/opentelemetry) for details. +See [example](../../example/opentelemetry) for details. diff --git a/vendor/github.com/uptrace/bun/extra/bunotel/unsafe.go b/vendor/github.com/uptrace/bun/extra/bunotel/unsafe.go index 23accd40e..67b687cbe 100644 --- a/vendor/github.com/uptrace/bun/extra/bunotel/unsafe.go +++ b/vendor/github.com/uptrace/bun/extra/bunotel/unsafe.go @@ -1,3 +1,4 @@ +//go:build !appengine // +build !appengine package bunotel @@ -5,14 +6,15 @@ package bunotel import "unsafe" func bytesToString(b []byte) string { - return *(*string)(unsafe.Pointer(&b)) + if len(b) == 0 { + return "" + } + return unsafe.String(&b[0], len(b)) } func stringToBytes(s string) []byte { - return *(*[]byte)(unsafe.Pointer( - &struct { - string - Cap int - }{s, len(s)}, - )) + if s == "" { + return []byte{} + } + return unsafe.Slice(unsafe.StringData(s), len(s)) } diff --git a/vendor/github.com/uptrace/bun/internal/parser/parser.go b/vendor/github.com/uptrace/bun/internal/parser/parser.go index cdfc0be16..1f2704478 100644 --- a/vendor/github.com/uptrace/bun/internal/parser/parser.go +++ b/vendor/github.com/uptrace/bun/internal/parser/parser.go @@ -2,6 +2,8 @@ package parser import ( "bytes" + "fmt" + "io" "strconv" "github.com/uptrace/bun/internal" @@ -22,23 +24,43 @@ func NewString(s string) *Parser { return New(internal.Bytes(s)) } +func (p *Parser) Reset(b []byte) { + p.b = b + p.i = 0 +} + func (p *Parser) Valid() bool { return p.i < len(p.b) } -func (p *Parser) Bytes() []byte { +func (p *Parser) Remaining() []byte { return p.b[p.i:] } +func (p *Parser) ReadByte() (byte, error) { + if p.Valid() { + ch := p.b[p.i] + p.Advance() + return ch, nil + } + return 0, io.ErrUnexpectedEOF +} + func (p *Parser) Read() byte { if p.Valid() { - c := p.b[p.i] + ch := p.b[p.i] p.Advance() - return c + return ch } return 0 } +func (p *Parser) Unread() { + if p.i > 0 { + p.i-- + } +} + func (p *Parser) Peek() byte { if p.Valid() { return p.b[p.i] @@ -50,19 +72,25 @@ func (p *Parser) Advance() { p.i++ } -func (p *Parser) Skip(skip byte) bool { - if p.Peek() == skip { +func (p *Parser) Skip(skip byte) error { + ch := p.Peek() + if ch == skip { p.Advance() - return true + return nil } - return false + return fmt.Errorf("got %q, wanted %q", ch, skip) } -func (p *Parser) SkipBytes(skip []byte) bool { - if len(skip) > len(p.b[p.i:]) { - return false +func (p *Parser) SkipPrefix(skip []byte) error { + if !bytes.HasPrefix(p.b[p.i:], skip) { + return fmt.Errorf("got %q, wanted prefix %q", p.b, skip) } - if !bytes.Equal(p.b[p.i:p.i+len(skip)], skip) { + p.i += len(skip) + return nil +} + +func (p *Parser) CutPrefix(skip []byte) bool { + if !bytes.HasPrefix(p.b[p.i:], skip) { return false } p.i += len(skip) diff --git a/vendor/github.com/uptrace/bun/internal/unsafe.go b/vendor/github.com/uptrace/bun/internal/unsafe.go index 4bc79701f..1a0331297 100644 --- a/vendor/github.com/uptrace/bun/internal/unsafe.go +++ b/vendor/github.com/uptrace/bun/internal/unsafe.go @@ -1,3 +1,4 @@ +//go:build !appengine // +build !appengine package internal @@ -6,15 +7,16 @@ import "unsafe" // String converts byte slice to string. func String(b []byte) string { - return *(*string)(unsafe.Pointer(&b)) + if len(b) == 0 { + return "" + } + return unsafe.String(&b[0], len(b)) } // Bytes converts string to byte slice. func Bytes(s string) []byte { - return *(*[]byte)(unsafe.Pointer( - &struct { - string - Cap int - }{s, len(s)}, - )) + if s == "" { + return []byte{} + } + return unsafe.Slice(unsafe.StringData(s), len(s)) } diff --git a/vendor/github.com/uptrace/bun/migrate/migrations.go b/vendor/github.com/uptrace/bun/migrate/migrations.go index 289735270..1a7ea5668 100644 --- a/vendor/github.com/uptrace/bun/migrate/migrations.go +++ b/vendor/github.com/uptrace/bun/migrate/migrations.go @@ -96,10 +96,6 @@ func (m *Migrations) Discover(fsys fs.FS) error { } migration := m.getOrCreateMigration(name) - if err != nil { - return err - } - migration.Comment = comment migrationFunc := NewSQLMigrationFunc(fsys, path) diff --git a/vendor/github.com/uptrace/bun/migrate/migrator.go b/vendor/github.com/uptrace/bun/migrate/migrator.go index 33c5bd16f..e6d70e39f 100644 --- a/vendor/github.com/uptrace/bun/migrate/migrator.go +++ b/vendor/github.com/uptrace/bun/migrate/migrator.go @@ -362,7 +362,10 @@ func (m *Migrator) MarkUnapplied(ctx context.Context, migration *Migration) erro } func (m *Migrator) TruncateTable(ctx context.Context) error { - _, err := m.db.NewTruncateTable().TableExpr(m.table).Exec(ctx) + _, err := m.db.NewTruncateTable(). + Model((*Migration)(nil)). + ModelTableExpr(m.table). + Exec(ctx) return err } diff --git a/vendor/github.com/uptrace/bun/model_map.go b/vendor/github.com/uptrace/bun/model_map.go index 814d636e6..d7342576f 100644 --- a/vendor/github.com/uptrace/bun/model_map.go +++ b/vendor/github.com/uptrace/bun/model_map.go @@ -1,6 +1,7 @@ package bun import ( + "bytes" "context" "database/sql" "reflect" @@ -82,6 +83,8 @@ func (m *mapModel) Scan(src interface{}) error { return m.scanRaw(src) case reflect.Slice: if scanType.Elem().Kind() == reflect.Uint8 { + // Reference types such as []byte are only valid until the next call to Scan. + src := bytes.Clone(src.([]byte)) return m.scanRaw(src) } } diff --git a/vendor/github.com/uptrace/bun/model_table_has_many.go b/vendor/github.com/uptrace/bun/model_table_has_many.go index 3d8a5da6f..544cdf5d6 100644 --- a/vendor/github.com/uptrace/bun/model_table_has_many.go +++ b/vendor/github.com/uptrace/bun/model_table_has_many.go @@ -24,7 +24,7 @@ var _ TableModel = (*hasManyModel)(nil) func newHasManyModel(j *relationJoin) *hasManyModel { baseTable := j.BaseModel.Table() joinModel := j.JoinModel.(*sliceTableModel) - baseValues := baseValues(joinModel, j.Relation.BaseFields) + baseValues := baseValues(joinModel, j.Relation.BasePKs) if len(baseValues) == 0 { return nil } @@ -92,9 +92,9 @@ func (m *hasManyModel) Scan(src interface{}) error { return err } - for _, f := range m.rel.JoinFields { + for _, f := range m.rel.JoinPKs { if f.Name == field.Name { - m.structKey = append(m.structKey, field.Value(m.strct).Interface()) + m.structKey = append(m.structKey, indirectFieldValue(field.Value(m.strct))) break } } @@ -103,6 +103,7 @@ func (m *hasManyModel) Scan(src interface{}) error { } func (m *hasManyModel) parkStruct() error { + baseValues, ok := m.baseValues[internal.NewMapKey(m.structKey)] if !ok { return fmt.Errorf( @@ -143,7 +144,19 @@ func baseValues(model TableModel, fields []*schema.Field) map[internal.MapKey][] func modelKey(key []interface{}, strct reflect.Value, fields []*schema.Field) []interface{} { for _, f := range fields { - key = append(key, f.Value(strct).Interface()) + key = append(key, indirectFieldValue(f.Value(strct))) } return key } + +// indirectFieldValue return the field value dereferencing the pointer if necessary. +// The value is then used as a map key. +func indirectFieldValue(field reflect.Value) interface{} { + if field.Kind() != reflect.Ptr { + return field.Interface() + } + if field.IsNil() { + return nil + } + return field.Elem().Interface() +} diff --git a/vendor/github.com/uptrace/bun/model_table_m2m.go b/vendor/github.com/uptrace/bun/model_table_m2m.go index 88d8a1268..14d385e62 100644 --- a/vendor/github.com/uptrace/bun/model_table_m2m.go +++ b/vendor/github.com/uptrace/bun/model_table_m2m.go @@ -24,7 +24,7 @@ var _ TableModel = (*m2mModel)(nil) func newM2MModel(j *relationJoin) *m2mModel { baseTable := j.BaseModel.Table() joinModel := j.JoinModel.(*sliceTableModel) - baseValues := baseValues(joinModel, baseTable.PKs) + baseValues := baseValues(joinModel, j.Relation.BasePKs) if len(baseValues) == 0 { return nil } @@ -83,27 +83,21 @@ func (m *m2mModel) Scan(src interface{}) error { column := m.columns[m.scanIndex] m.scanIndex++ - field, ok := m.table.FieldMap[column] - if !ok { + // Base pks must come first. + if m.scanIndex <= len(m.rel.M2MBasePKs) { return m.scanM2MColumn(column, src) } - if err := field.ScanValue(m.strct, src); err != nil { - return err - } - - for _, fk := range m.rel.M2MBaseFields { - if fk.Name == field.Name { - m.structKey = append(m.structKey, field.Value(m.strct).Interface()) - break - } + if field, ok := m.table.FieldMap[column]; ok { + return field.ScanValue(m.strct, src) } - return nil + _, err := m.scanColumn(column, src) + return err } func (m *m2mModel) scanM2MColumn(column string, src interface{}) error { - for _, field := range m.rel.M2MBaseFields { + for _, field := range m.rel.M2MBasePKs { if field.Name == column { dest := reflect.New(field.IndirectType).Elem() if err := field.Scan(dest, src); err != nil { diff --git a/vendor/github.com/uptrace/bun/model_table_struct.go b/vendor/github.com/uptrace/bun/model_table_struct.go index a5c9a7bc3..a8860908e 100644 --- a/vendor/github.com/uptrace/bun/model_table_struct.go +++ b/vendor/github.com/uptrace/bun/model_table_struct.go @@ -242,7 +242,7 @@ func (m *structTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, e n++ // And discard the rest. This is especially important for SQLite3, which can return - // a row like it was inserted sucessfully and then return an actual error for the next row. + // a row like it was inserted successfully and then return an actual error for the next row. // See issues/100. for rows.Next() { n++ diff --git a/vendor/github.com/uptrace/bun/package.json b/vendor/github.com/uptrace/bun/package.json index 331e4be8b..6a8d7082e 100644 --- a/vendor/github.com/uptrace/bun/package.json +++ b/vendor/github.com/uptrace/bun/package.json @@ -1,6 +1,6 @@ { "name": "gobun", - "version": "1.2.1", + "version": "1.2.5", "main": "index.js", "repository": "git@github.com:uptrace/bun.git", "author": "Vladimir Mihailenco <vladimir.webdev@gmail.com>", diff --git a/vendor/github.com/uptrace/bun/query_base.go b/vendor/github.com/uptrace/bun/query_base.go index 2321a7537..8a26a4c8a 100644 --- a/vendor/github.com/uptrace/bun/query_base.go +++ b/vendor/github.com/uptrace/bun/query_base.go @@ -8,6 +8,7 @@ import ( "fmt" "time" + "github.com/uptrace/bun/dialect" "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/internal" "github.com/uptrace/bun/schema" @@ -418,7 +419,11 @@ func (q *baseQuery) _appendTables( } else { b = fmter.AppendQuery(b, string(q.table.SQLNameForSelects)) if withAlias && q.table.SQLAlias != q.table.SQLNameForSelects { - b = append(b, " AS "...) + if q.db.dialect.Name() == dialect.Oracle { + b = append(b, ' ') + } else { + b = append(b, " AS "...) + } b = append(b, q.table.SQLAlias...) } } diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go index c0e145110..5bb329143 100644 --- a/vendor/github.com/uptrace/bun/query_select.go +++ b/vendor/github.com/uptrace/bun/query_select.go @@ -538,6 +538,11 @@ func (q *SelectQuery) appendQuery( if count && !cteCount { b = append(b, "count(*)"...) } else { + // MSSQL: allows Limit() without Order() as per https://stackoverflow.com/a/36156953 + if q.limit > 0 && len(q.order) == 0 && fmter.Dialect().Name() == dialect.MSSQL { + b = append(b, "0 AS _temp_sort, "...) + } + b, err = q.appendColumns(fmter, b) if err != nil { return nil, err @@ -564,8 +569,8 @@ func (q *SelectQuery) appendQuery( return nil, err } - for _, j := range q.joins { - b, err = j.AppendQuery(fmter, b) + for _, join := range q.joins { + b, err = join.AppendQuery(fmter, b) if err != nil { return nil, err } @@ -793,6 +798,12 @@ func (q *SelectQuery) appendOrder(fmter schema.Formatter, b []byte) (_ []byte, e return b, nil } + + // MSSQL: allows Limit() without Order() as per https://stackoverflow.com/a/36156953 + if q.limit > 0 && fmter.Dialect().Name() == dialect.MSSQL { + return append(b, " ORDER BY _temp_sort"...), nil + } + return b, nil } @@ -856,52 +867,57 @@ func (q *SelectQuery) Exec(ctx context.Context, dest ...interface{}) (res sql.Re } func (q *SelectQuery) Scan(ctx context.Context, dest ...interface{}) error { + _, err := q.scanResult(ctx, dest...) + return err +} + +func (q *SelectQuery) scanResult(ctx context.Context, dest ...interface{}) (sql.Result, error) { if q.err != nil { - return q.err + return nil, q.err } model, err := q.getModel(dest) if err != nil { - return err + return nil, err } if q.table != nil { if err := q.beforeSelectHook(ctx); err != nil { - return err + return nil, err } } if err := q.beforeAppendModel(ctx, q); err != nil { - return err + return nil, err } queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { - return err + return nil, err } query := internal.String(queryBytes) res, err := q.scan(ctx, q, query, model, true) if err != nil { - return err + return nil, err } if n, _ := res.RowsAffected(); n > 0 { if tableModel, ok := model.(TableModel); ok { if err := q.selectJoins(ctx, tableModel.getJoins()); err != nil { - return err + return nil, err } } } if q.table != nil { if err := q.afterSelectHook(ctx); err != nil { - return err + return nil, err } } - return nil + return res, nil } func (q *SelectQuery) beforeSelectHook(ctx context.Context) error { @@ -946,6 +962,16 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) { } func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (int, error) { + if q.offset == 0 && q.limit == 0 { + // If there is no limit and offset, we can use a single query to get the count and scan + if res, err := q.scanResult(ctx, dest...); err != nil { + return 0, err + } else if n, err := res.RowsAffected(); err != nil { + return 0, err + } else { + return int(n), nil + } + } if _, ok := q.conn.(*DB); ok { return q.scanAndCountConc(ctx, dest...) } diff --git a/vendor/github.com/uptrace/bun/query_table_create.go b/vendor/github.com/uptrace/bun/query_table_create.go index 3d98da07b..aeb79cd37 100644 --- a/vendor/github.com/uptrace/bun/query_table_create.go +++ b/vendor/github.com/uptrace/bun/query_table_create.go @@ -9,6 +9,7 @@ import ( "strconv" "strings" + "github.com/uptrace/bun/dialect" "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/dialect/sqltype" "github.com/uptrace/bun/internal" @@ -165,7 +166,7 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by b = append(b, field.SQLName...) b = append(b, " "...) b = q.appendSQLType(b, field) - if field.NotNull { + if field.NotNull && q.db.dialect.Name() != dialect.Oracle { b = append(b, " NOT NULL"...) } @@ -246,7 +247,11 @@ func (q *CreateTableQuery) appendSQLType(b []byte, field *schema.Field) []byte { return append(b, field.CreateTableSQLType...) } - b = append(b, sqltype.VarChar...) + if q.db.dialect.Name() == dialect.Oracle { + b = append(b, "VARCHAR2"...) + } else { + b = append(b, sqltype.VarChar...) + } b = append(b, "("...) b = strconv.AppendInt(b, int64(q.varchar), 10) b = append(b, ")"...) @@ -297,9 +302,9 @@ func (q *CreateTableQuery) appendFKConstraintsRel(fmter schema.Formatter, b []by b, err = q.appendFK(fmter, b, schema.QueryWithArgs{ Query: "(?) REFERENCES ? (?) ? ?", Args: []interface{}{ - Safe(appendColumns(nil, "", rel.BaseFields)), + Safe(appendColumns(nil, "", rel.BasePKs)), rel.JoinTable.SQLName, - Safe(appendColumns(nil, "", rel.JoinFields)), + Safe(appendColumns(nil, "", rel.JoinPKs)), Safe(rel.OnUpdate), Safe(rel.OnDelete), }, diff --git a/vendor/github.com/uptrace/bun/query_table_truncate.go b/vendor/github.com/uptrace/bun/query_table_truncate.go index a704b7b10..9ac5599d9 100644 --- a/vendor/github.com/uptrace/bun/query_table_truncate.go +++ b/vendor/github.com/uptrace/bun/query_table_truncate.go @@ -57,6 +57,11 @@ func (q *TruncateTableQuery) TableExpr(query string, args ...interface{}) *Trunc return q } +func (q *TruncateTableQuery) ModelTableExpr(query string, args ...interface{}) *TruncateTableQuery { + q.modelTableName = schema.SafeQuery(query, args) + return q +} + //------------------------------------------------------------------------------ func (q *TruncateTableQuery) ContinueIdentity() *TruncateTableQuery { diff --git a/vendor/github.com/uptrace/bun/relation_join.go b/vendor/github.com/uptrace/bun/relation_join.go index ba542666d..487f776ed 100644 --- a/vendor/github.com/uptrace/bun/relation_join.go +++ b/vendor/github.com/uptrace/bun/relation_join.go @@ -70,11 +70,11 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery { } func (j *relationJoin) manyQueryCompositeIn(where []byte, q *SelectQuery) *SelectQuery { - if len(j.Relation.JoinFields) > 1 { + if len(j.Relation.JoinPKs) > 1 { where = append(where, '(') } - where = appendColumns(where, j.JoinModel.Table().SQLAlias, j.Relation.JoinFields) - if len(j.Relation.JoinFields) > 1 { + where = appendColumns(where, j.JoinModel.Table().SQLAlias, j.Relation.JoinPKs) + if len(j.Relation.JoinPKs) > 1 { where = append(where, ')') } where = append(where, " IN ("...) @@ -83,7 +83,7 @@ func (j *relationJoin) manyQueryCompositeIn(where []byte, q *SelectQuery) *Selec where, j.JoinModel.rootValue(), j.JoinModel.parentIndex(), - j.Relation.BaseFields, + j.Relation.BasePKs, ) where = append(where, ")"...) q = q.Where(internal.String(where)) @@ -104,8 +104,8 @@ func (j *relationJoin) manyQueryMulti(where []byte, q *SelectQuery) *SelectQuery where, j.JoinModel.rootValue(), j.JoinModel.parentIndex(), - j.Relation.BaseFields, - j.Relation.JoinFields, + j.Relation.BasePKs, + j.Relation.JoinPKs, j.JoinModel.Table().SQLAlias, ) @@ -175,10 +175,10 @@ func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery { q = q.Model(m2mModel) index := j.JoinModel.parentIndex() - baseTable := j.BaseModel.Table() if j.Relation.M2MTable != nil { - fields := append(j.Relation.M2MBaseFields, j.Relation.M2MJoinFields...) + // We only need base pks to park joined models to the base model. + fields := j.Relation.M2MBasePKs b := make([]byte, 0, len(fields)) b = appendColumns(b, j.Relation.M2MTable.SQLAlias, fields) @@ -193,7 +193,7 @@ func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery { join = append(join, " AS "...) join = append(join, j.Relation.M2MTable.SQLAlias...) join = append(join, " ON ("...) - for i, col := range j.Relation.M2MBaseFields { + for i, col := range j.Relation.M2MBasePKs { if i > 0 { join = append(join, ", "...) } @@ -202,13 +202,13 @@ func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery { join = append(join, col.SQLName...) } join = append(join, ") IN ("...) - join = appendChildValues(fmter, join, j.BaseModel.rootValue(), index, baseTable.PKs) + join = appendChildValues(fmter, join, j.BaseModel.rootValue(), index, j.Relation.BasePKs) join = append(join, ")"...) q = q.Join(internal.String(join)) joinTable := j.JoinModel.Table() - for i, m2mJoinField := range j.Relation.M2MJoinFields { - joinField := j.Relation.JoinFields[i] + for i, m2mJoinField := range j.Relation.M2MJoinPKs { + joinField := j.Relation.JoinPKs[i] q = q.Where("?.? = ?.?", joinTable.SQLAlias, joinField.SQLName, j.Relation.M2MTable.SQLAlias, m2mJoinField.SQLName) @@ -310,13 +310,13 @@ func (j *relationJoin) appendHasOneJoin( b = append(b, " ON "...) b = append(b, '(') - for i, baseField := range j.Relation.BaseFields { + for i, baseField := range j.Relation.BasePKs { if i > 0 { b = append(b, " AND "...) } b = j.appendAlias(fmter, b) b = append(b, '.') - b = append(b, j.Relation.JoinFields[i].SQLName...) + b = append(b, j.Relation.JoinPKs[i].SQLName...) b = append(b, " = "...) b = j.appendBaseAlias(fmter, b) b = append(b, '.') @@ -367,13 +367,13 @@ func appendChildValues( } // appendMultiValues is an alternative to appendChildValues that doesn't use the sql keyword ID -// but instead use a old style ((k1=v1) AND (k2=v2)) OR (...) of conditions. +// but instead uses old style ((k1=v1) AND (k2=v2)) OR (...) conditions. func appendMultiValues( fmter schema.Formatter, b []byte, v reflect.Value, index []int, baseFields, joinFields []*schema.Field, joinTable schema.Safe, ) []byte { // This is based on a mix of appendChildValues and query_base.appendColumns - // These should never missmatch in length but nice to know if it does + // These should never mismatch in length but nice to know if it does if len(joinFields) != len(baseFields) { panic("not reached") } diff --git a/vendor/github.com/uptrace/bun/schema/append_value.go b/vendor/github.com/uptrace/bun/schema/append_value.go index 9f0782e0f..a67b41e38 100644 --- a/vendor/github.com/uptrace/bun/schema/append_value.go +++ b/vendor/github.com/uptrace/bun/schema/append_value.go @@ -7,9 +7,9 @@ import ( "reflect" "strconv" "strings" - "sync" "time" + "github.com/puzpuzpuz/xsync/v3" "github.com/uptrace/bun/dialect" "github.com/uptrace/bun/dialect/sqltype" "github.com/uptrace/bun/extra/bunjson" @@ -51,7 +51,7 @@ var appenders = []AppenderFunc{ reflect.UnsafePointer: nil, } -var appenderMap sync.Map +var appenderCache = xsync.NewMapOf[reflect.Type, AppenderFunc]() func FieldAppender(dialect Dialect, field *Field) AppenderFunc { if field.Tag.HasOption("msgpack") { @@ -67,7 +67,7 @@ func FieldAppender(dialect Dialect, field *Field) AppenderFunc { } if fieldType.Kind() != reflect.Ptr { - if reflect.PtrTo(fieldType).Implements(driverValuerType) { + if reflect.PointerTo(fieldType).Implements(driverValuerType) { return addrAppender(appendDriverValue) } } @@ -79,14 +79,14 @@ func FieldAppender(dialect Dialect, field *Field) AppenderFunc { } func Appender(dialect Dialect, typ reflect.Type) AppenderFunc { - if v, ok := appenderMap.Load(typ); ok { - return v.(AppenderFunc) + if v, ok := appenderCache.Load(typ); ok { + return v } fn := appender(dialect, typ) - if v, ok := appenderMap.LoadOrStore(typ, fn); ok { - return v.(AppenderFunc) + if v, ok := appenderCache.LoadOrStore(typ, fn); ok { + return v } return fn } @@ -99,10 +99,10 @@ func appender(dialect Dialect, typ reflect.Type) AppenderFunc { return appendTimeValue case timePtrType: return PtrAppender(appendTimeValue) - case ipType: - return appendIPValue case ipNetType: return appendIPNetValue + case ipType, netipPrefixType, netipAddrType: + return appendStringer case jsonRawMessageType: return appendJSONRawMessageValue } @@ -123,7 +123,7 @@ func appender(dialect Dialect, typ reflect.Type) AppenderFunc { } if kind != reflect.Ptr { - ptr := reflect.PtrTo(typ) + ptr := reflect.PointerTo(typ) if ptr.Implements(queryAppenderType) { return addrAppender(appendQueryAppenderValue) } @@ -247,16 +247,15 @@ func appendTimeValue(fmter Formatter, b []byte, v reflect.Value) []byte { return fmter.Dialect().AppendTime(b, tm) } -func appendIPValue(fmter Formatter, b []byte, v reflect.Value) []byte { - ip := v.Interface().(net.IP) - return fmter.Dialect().AppendString(b, ip.String()) -} - func appendIPNetValue(fmter Formatter, b []byte, v reflect.Value) []byte { ipnet := v.Interface().(net.IPNet) return fmter.Dialect().AppendString(b, ipnet.String()) } +func appendStringer(fmter Formatter, b []byte, v reflect.Value) []byte { + return fmter.Dialect().AppendString(b, v.Interface().(fmt.Stringer).String()) +} + func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byte { bytes := v.Bytes() if bytes == nil { diff --git a/vendor/github.com/uptrace/bun/schema/dialect.go b/vendor/github.com/uptrace/bun/schema/dialect.go index 8814313f7..330293444 100644 --- a/vendor/github.com/uptrace/bun/schema/dialect.go +++ b/vendor/github.com/uptrace/bun/schema/dialect.go @@ -118,7 +118,7 @@ func (BaseDialect) AppendJSON(b, jsonb []byte) []byte { case '\000': continue case '\\': - if p.SkipBytes([]byte("u0000")) { + if p.CutPrefix([]byte("u0000")) { b = append(b, `\\u0000`...) } else { b = append(b, '\\') diff --git a/vendor/github.com/uptrace/bun/schema/reflect.go b/vendor/github.com/uptrace/bun/schema/reflect.go index 89be8eeb6..75980b102 100644 --- a/vendor/github.com/uptrace/bun/schema/reflect.go +++ b/vendor/github.com/uptrace/bun/schema/reflect.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "encoding/json" "net" + "net/netip" "reflect" "time" ) @@ -14,6 +15,8 @@ var ( timeType = timePtrType.Elem() ipType = reflect.TypeOf((*net.IP)(nil)).Elem() ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() + netipPrefixType = reflect.TypeOf((*netip.Prefix)(nil)).Elem() + netipAddrType = reflect.TypeOf((*netip.Addr)(nil)).Elem() jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() diff --git a/vendor/github.com/uptrace/bun/schema/relation.go b/vendor/github.com/uptrace/bun/schema/relation.go index 9eb74f7e9..f653cd7a3 100644 --- a/vendor/github.com/uptrace/bun/schema/relation.go +++ b/vendor/github.com/uptrace/bun/schema/relation.go @@ -13,21 +13,25 @@ const ( ) type Relation struct { - Type int - Field *Field - JoinTable *Table - BaseFields []*Field - JoinFields []*Field - OnUpdate string - OnDelete string - Condition []string + // Base and Join can be explained with this query: + // + // SELECT * FROM base_table JOIN join_table + + Type int + Field *Field + JoinTable *Table + BasePKs []*Field + JoinPKs []*Field + OnUpdate string + OnDelete string + Condition []string PolymorphicField *Field PolymorphicValue string - M2MTable *Table - M2MBaseFields []*Field - M2MJoinFields []*Field + M2MTable *Table + M2MBasePKs []*Field + M2MJoinPKs []*Field } // References returns true if the table to which the Relation belongs needs to declare a foreign key constraint to create the relation. diff --git a/vendor/github.com/uptrace/bun/schema/scan.go b/vendor/github.com/uptrace/bun/schema/scan.go index 96b31caf3..4da160daf 100644 --- a/vendor/github.com/uptrace/bun/schema/scan.go +++ b/vendor/github.com/uptrace/bun/schema/scan.go @@ -8,9 +8,9 @@ import ( "reflect" "strconv" "strings" - "sync" "time" + "github.com/puzpuzpuz/xsync/v3" "github.com/vmihailenco/msgpack/v5" "github.com/uptrace/bun/dialect/sqltype" @@ -53,7 +53,7 @@ func init() { } } -var scannerMap sync.Map +var scannerCache = xsync.NewMapOf[reflect.Type, ScannerFunc]() func FieldScanner(dialect Dialect, field *Field) ScannerFunc { if field.Tag.HasOption("msgpack") { @@ -72,14 +72,14 @@ func FieldScanner(dialect Dialect, field *Field) ScannerFunc { } func Scanner(typ reflect.Type) ScannerFunc { - if v, ok := scannerMap.Load(typ); ok { - return v.(ScannerFunc) + if v, ok := scannerCache.Load(typ); ok { + return v } fn := scanner(typ) - if v, ok := scannerMap.LoadOrStore(typ, fn); ok { - return v.(ScannerFunc) + if v, ok := scannerCache.LoadOrStore(typ, fn); ok { + return v } return fn } @@ -111,7 +111,7 @@ func scanner(typ reflect.Type) ScannerFunc { } if kind != reflect.Ptr { - ptr := reflect.PtrTo(typ) + ptr := reflect.PointerTo(typ) if ptr.Implements(scannerType) { return addrScanner(scanScanner) } diff --git a/vendor/github.com/uptrace/bun/schema/table.go b/vendor/github.com/uptrace/bun/schema/table.go index 0a23156a2..c8e71e38f 100644 --- a/vendor/github.com/uptrace/bun/schema/table.go +++ b/vendor/github.com/uptrace/bun/schema/table.go @@ -74,16 +74,7 @@ type structField struct { Table *Table } -func newTable( - dialect Dialect, typ reflect.Type, seen map[reflect.Type]*Table, canAddr bool, -) *Table { - if table, ok := seen[typ]; ok { - return table - } - - table := new(Table) - seen[typ] = table - +func (table *Table) init(dialect Dialect, typ reflect.Type, canAddr bool) { table.dialect = dialect table.Type = typ table.ZeroValue = reflect.New(table.Type).Elem() @@ -97,7 +88,7 @@ func newTable( table.Fields = make([]*Field, 0, typ.NumField()) table.FieldMap = make(map[string]*Field, typ.NumField()) - table.processFields(typ, seen, canAddr) + table.processFields(typ, canAddr) hooks := []struct { typ reflect.Type @@ -109,28 +100,15 @@ func newTable( {afterScanRowHookType, afterScanRowHookFlag}, } - typ = reflect.PtrTo(table.Type) + typ = reflect.PointerTo(table.Type) for _, hook := range hooks { if typ.Implements(hook.typ) { table.flags = table.flags.Set(hook.flag) } } - - return table -} - -func (t *Table) init() { - for _, field := range t.relFields { - t.processRelation(field) - } - t.relFields = nil } -func (t *Table) processFields( - typ reflect.Type, - seen map[reflect.Type]*Table, - canAddr bool, -) { +func (t *Table) processFields(typ reflect.Type, canAddr bool) { type embeddedField struct { prefix string index []int @@ -172,7 +150,7 @@ func (t *Table) processFields( continue } - subtable := newTable(t.dialect, sfType, seen, canAddr) + subtable := t.dialect.Tables().InProgress(sfType) for _, subfield := range subtable.allFields { embedded = append(embedded, embeddedField{ @@ -206,7 +184,7 @@ func (t *Table) processFields( t.TypeName, sf.Name, fieldType.Kind())) } - subtable := newTable(t.dialect, fieldType, seen, canAddr) + subtable := t.dialect.Tables().InProgress(fieldType) for _, subfield := range subtable.allFields { embedded = append(embedded, embeddedField{ prefix: prefix, @@ -229,7 +207,7 @@ func (t *Table) processFields( } t.StructMap[field.Name] = &structField{ Index: field.Index, - Table: newTable(t.dialect, field.IndirectType, seen, canAddr), + Table: t.dialect.Tables().InProgress(field.IndirectType), } } } @@ -423,6 +401,10 @@ func (t *Table) newField(sf reflect.StructField, tag tagparser.Tag) *Field { sqlName = tag.Name } + if s, ok := tag.Option("column"); ok { + sqlName = s + } + for name := range tag.Options { if !isKnownFieldOption(name) { internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, sf.Name, name) @@ -490,6 +472,13 @@ func (t *Table) newField(sf reflect.StructField, tag tagparser.Tag) *Field { //--------------------------------------------------------------------------------------- +func (t *Table) initRelations() { + for _, field := range t.relFields { + t.processRelation(field) + } + t.relFields = nil +} + func (t *Table) processRelation(field *Field) { if rel, ok := field.Tag.Option("rel"); ok { t.initRelation(field, rel) @@ -577,7 +566,7 @@ func (t *Table) belongsToRelation(field *Field) *Relation { joinColumn := joinColumns[i] if f := t.FieldMap[baseColumn]; f != nil { - rel.BaseFields = append(rel.BaseFields, f) + rel.BasePKs = append(rel.BasePKs, f) } else { panic(fmt.Errorf( "bun: %s belongs-to %s: %s must have column %s", @@ -586,7 +575,7 @@ func (t *Table) belongsToRelation(field *Field) *Relation { } if f := joinTable.FieldMap[joinColumn]; f != nil { - rel.JoinFields = append(rel.JoinFields, f) + rel.JoinPKs = append(rel.JoinPKs, f) } else { panic(fmt.Errorf( "bun: %s belongs-to %s: %s must have column %s", @@ -597,17 +586,17 @@ func (t *Table) belongsToRelation(field *Field) *Relation { return rel } - rel.JoinFields = joinTable.PKs + rel.JoinPKs = joinTable.PKs fkPrefix := internal.Underscore(field.GoName) + "_" for _, joinPK := range joinTable.PKs { fkName := fkPrefix + joinPK.Name if fk := t.FieldMap[fkName]; fk != nil { - rel.BaseFields = append(rel.BaseFields, fk) + rel.BasePKs = append(rel.BasePKs, fk) continue } if fk := t.FieldMap[joinPK.Name]; fk != nil { - rel.BaseFields = append(rel.BaseFields, fk) + rel.BasePKs = append(rel.BasePKs, fk) continue } @@ -640,7 +629,7 @@ func (t *Table) hasOneRelation(field *Field) *Relation { baseColumns, joinColumns := parseRelationJoin(join) for i, baseColumn := range baseColumns { if f := t.FieldMap[baseColumn]; f != nil { - rel.BaseFields = append(rel.BaseFields, f) + rel.BasePKs = append(rel.BasePKs, f) } else { panic(fmt.Errorf( "bun: %s has-one %s: %s must have column %s", @@ -650,7 +639,7 @@ func (t *Table) hasOneRelation(field *Field) *Relation { joinColumn := joinColumns[i] if f := joinTable.FieldMap[joinColumn]; f != nil { - rel.JoinFields = append(rel.JoinFields, f) + rel.JoinPKs = append(rel.JoinPKs, f) } else { panic(fmt.Errorf( "bun: %s has-one %s: %s must have column %s", @@ -661,17 +650,17 @@ func (t *Table) hasOneRelation(field *Field) *Relation { return rel } - rel.BaseFields = t.PKs + rel.BasePKs = t.PKs fkPrefix := internal.Underscore(t.ModelName) + "_" for _, pk := range t.PKs { fkName := fkPrefix + pk.Name if f := joinTable.FieldMap[fkName]; f != nil { - rel.JoinFields = append(rel.JoinFields, f) + rel.JoinPKs = append(rel.JoinPKs, f) continue } if f := joinTable.FieldMap[pk.Name]; f != nil { - rel.JoinFields = append(rel.JoinFields, f) + rel.JoinPKs = append(rel.JoinPKs, f) continue } @@ -720,7 +709,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation { } if f := t.FieldMap[baseColumn]; f != nil { - rel.BaseFields = append(rel.BaseFields, f) + rel.BasePKs = append(rel.BasePKs, f) } else { panic(fmt.Errorf( "bun: %s has-many %s: %s must have column %s", @@ -729,7 +718,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation { } if f := joinTable.FieldMap[joinColumn]; f != nil { - rel.JoinFields = append(rel.JoinFields, f) + rel.JoinPKs = append(rel.JoinPKs, f) } else { panic(fmt.Errorf( "bun: %s has-many %s: %s must have column %s", @@ -738,7 +727,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation { } } } else { - rel.BaseFields = t.PKs + rel.BasePKs = t.PKs fkPrefix := internal.Underscore(t.ModelName) + "_" if isPolymorphic { polymorphicColumn = fkPrefix + "type" @@ -747,12 +736,12 @@ func (t *Table) hasManyRelation(field *Field) *Relation { for _, pk := range t.PKs { joinColumn := fkPrefix + pk.Name if fk := joinTable.FieldMap[joinColumn]; fk != nil { - rel.JoinFields = append(rel.JoinFields, fk) + rel.JoinPKs = append(rel.JoinPKs, fk) continue } if fk := joinTable.FieldMap[pk.Name]; fk != nil { - rel.JoinFields = append(rel.JoinFields, fk) + rel.JoinPKs = append(rel.JoinPKs, fk) continue } @@ -852,12 +841,12 @@ func (t *Table) m2mRelation(field *Field) *Relation { } leftRel := m2mTable.belongsToRelation(leftField) - rel.BaseFields = leftRel.JoinFields - rel.M2MBaseFields = leftRel.BaseFields + rel.BasePKs = leftRel.JoinPKs + rel.M2MBasePKs = leftRel.BasePKs rightRel := m2mTable.belongsToRelation(rightField) - rel.JoinFields = rightRel.JoinFields - rel.M2MJoinFields = rightRel.BaseFields + rel.JoinPKs = rightRel.JoinPKs + rel.M2MJoinPKs = rightRel.BasePKs return rel } @@ -918,6 +907,7 @@ func isKnownFieldOption(name string) bool { "array", "hstore", "composite", + "multirange", "json_use_number", "msgpack", "notnull", diff --git a/vendor/github.com/uptrace/bun/schema/tables.go b/vendor/github.com/uptrace/bun/schema/tables.go index 19aff8606..985093421 100644 --- a/vendor/github.com/uptrace/bun/schema/tables.go +++ b/vendor/github.com/uptrace/bun/schema/tables.go @@ -4,22 +4,24 @@ import ( "fmt" "reflect" "sync" + + "github.com/puzpuzpuz/xsync/v3" ) type Tables struct { dialect Dialect - tables sync.Map - mu sync.RWMutex - seen map[reflect.Type]*Table - inProgress map[reflect.Type]*tableInProgress + mu sync.Mutex + tables *xsync.MapOf[reflect.Type, *Table] + + inProgress map[reflect.Type]*Table } func NewTables(dialect Dialect) *Tables { return &Tables{ dialect: dialect, - seen: make(map[reflect.Type]*Table), - inProgress: make(map[reflect.Type]*tableInProgress), + tables: xsync.NewMapOf[reflect.Type, *Table](), + inProgress: make(map[reflect.Type]*Table), } } @@ -30,58 +32,26 @@ func (t *Tables) Register(models ...interface{}) { } func (t *Tables) Get(typ reflect.Type) *Table { - return t.table(typ, false) -} - -func (t *Tables) InProgress(typ reflect.Type) *Table { - return t.table(typ, true) -} - -func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table { typ = indirectType(typ) if typ.Kind() != reflect.Struct { panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct)) } if v, ok := t.tables.Load(typ); ok { - return v.(*Table) + return v } t.mu.Lock() + defer t.mu.Unlock() if v, ok := t.tables.Load(typ); ok { - t.mu.Unlock() - return v.(*Table) - } - - var table *Table - - inProgress := t.inProgress[typ] - if inProgress == nil { - table = newTable(t.dialect, typ, t.seen, false) - inProgress = newTableInProgress(table) - t.inProgress[typ] = inProgress - } else { - table = inProgress.table + return v } - t.mu.Unlock() - - if allowInProgress { - return table - } - - if !inProgress.init() { - return table - } - - t.mu.Lock() - delete(t.inProgress, typ) - t.tables.Store(typ, table) - t.mu.Unlock() + table := t.InProgress(typ) + table.initRelations() t.dialect.OnTable(table) - for _, field := range table.FieldMap { if field.UserSQLType == "" { field.UserSQLType = field.DiscoveredSQLType @@ -91,15 +61,27 @@ func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table { } } + t.tables.Store(typ, table) + return table +} + +func (t *Tables) InProgress(typ reflect.Type) *Table { + if table, ok := t.inProgress[typ]; ok { + return table + } + + table := new(Table) + t.inProgress[typ] = table + table.init(t.dialect, typ, false) + return table } func (t *Tables) ByModel(name string) *Table { var found *Table - t.tables.Range(func(key, value interface{}) bool { - t := value.(*Table) - if t.TypeName == name { - found = t + t.tables.Range(func(typ reflect.Type, table *Table) bool { + if table.TypeName == name { + found = table return false } return true @@ -109,34 +91,12 @@ func (t *Tables) ByModel(name string) *Table { func (t *Tables) ByName(name string) *Table { var found *Table - t.tables.Range(func(key, value interface{}) bool { - t := value.(*Table) - if t.Name == name { - found = t + t.tables.Range(func(typ reflect.Type, table *Table) bool { + if table.Name == name { + found = table return false } return true }) return found } - -type tableInProgress struct { - table *Table - - initOnce sync.Once -} - -func newTableInProgress(table *Table) *tableInProgress { - return &tableInProgress{ - table: table, - } -} - -func (inp *tableInProgress) init() bool { - var inited bool - inp.initOnce.Do(func() { - inp.table.init() - inited = true - }) - return inited -} diff --git a/vendor/github.com/uptrace/bun/schema/zerochecker.go b/vendor/github.com/uptrace/bun/schema/zerochecker.go index f24e51d30..7c1f088c1 100644 --- a/vendor/github.com/uptrace/bun/schema/zerochecker.go +++ b/vendor/github.com/uptrace/bun/schema/zerochecker.go @@ -60,7 +60,7 @@ func zeroChecker(typ reflect.Type) IsZeroerFunc { kind := typ.Kind() if kind != reflect.Ptr { - ptr := reflect.PtrTo(typ) + ptr := reflect.PointerTo(typ) if ptr.Implements(isZeroerType) { return addrChecker(isZeroInterface) } diff --git a/vendor/github.com/uptrace/bun/version.go b/vendor/github.com/uptrace/bun/version.go index be6c67f30..7f23c12c3 100644 --- a/vendor/github.com/uptrace/bun/version.go +++ b/vendor/github.com/uptrace/bun/version.go @@ -2,5 +2,5 @@ package bun // Version is the current release version. func Version() string { - return "1.2.1" + return "1.2.5" } |