diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema')
-rw-r--r-- | vendor/github.com/uptrace/bun/schema/append_value.go | 22 | ||||
-rw-r--r-- | vendor/github.com/uptrace/bun/schema/formatter.go | 8 | ||||
-rw-r--r-- | vendor/github.com/uptrace/bun/schema/reflect.go (renamed from vendor/github.com/uptrace/bun/schema/util.go) | 19 | ||||
-rw-r--r-- | vendor/github.com/uptrace/bun/schema/scan.go | 121 | ||||
-rw-r--r-- | vendor/github.com/uptrace/bun/schema/sqlfmt.go | 2 | ||||
-rw-r--r-- | vendor/github.com/uptrace/bun/schema/sqltype.go | 49 | ||||
-rw-r--r-- | vendor/github.com/uptrace/bun/schema/table.go | 104 | ||||
-rw-r--r-- | vendor/github.com/uptrace/bun/schema/tables.go | 1 |
8 files changed, 191 insertions, 135 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/append_value.go b/vendor/github.com/uptrace/bun/schema/append_value.go index 0c4677069..948ff86af 100644 --- a/vendor/github.com/uptrace/bun/schema/append_value.go +++ b/vendor/github.com/uptrace/bun/schema/append_value.go @@ -2,7 +2,6 @@ package schema import ( "database/sql/driver" - "encoding/json" "fmt" "net" "reflect" @@ -14,16 +13,6 @@ import ( "github.com/uptrace/bun/internal" ) -var ( - timeType = reflect.TypeOf((*time.Time)(nil)).Elem() - ipType = reflect.TypeOf((*net.IP)(nil)).Elem() - ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() - jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() - - driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() - queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem() -) - type ( AppenderFunc func(fmter Formatter, b []byte, v reflect.Value) []byte CustomAppender func(typ reflect.Type) AppenderFunc @@ -60,6 +49,8 @@ var appenders = []AppenderFunc{ func Appender(typ reflect.Type, custom CustomAppender) AppenderFunc { switch typ { + case bytesType: + return appendBytesValue case timeType: return appendTimeValue case ipType: @@ -93,7 +84,9 @@ func Appender(typ reflect.Type, custom CustomAppender) AppenderFunc { case reflect.Interface: return ifaceAppenderFunc(typ, custom) case reflect.Ptr: - return ptrAppenderFunc(typ, custom) + if fn := Appender(typ.Elem(), custom); fn != nil { + return PtrAppender(fn) + } case reflect.Slice: if typ.Elem().Kind() == reflect.Uint8 { return appendBytesValue @@ -123,13 +116,12 @@ func ifaceAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) } } -func ptrAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc { - appender := Appender(typ.Elem(), custom) +func PtrAppender(fn AppenderFunc) AppenderFunc { return func(fmter Formatter, b []byte, v reflect.Value) []byte { if v.IsNil() { return dialect.AppendNull(b) } - return appender(fmter, b, v.Elem()) + return fn(fmter, b, v.Elem()) } } diff --git a/vendor/github.com/uptrace/bun/schema/formatter.go b/vendor/github.com/uptrace/bun/schema/formatter.go index 7b26fbaca..45a246307 100644 --- a/vendor/github.com/uptrace/bun/schema/formatter.go +++ b/vendor/github.com/uptrace/bun/schema/formatter.go @@ -89,10 +89,10 @@ func (f Formatter) AppendQuery(dst []byte, query string, args ...interface{}) [] func (f Formatter) append(dst []byte, p *parser.Parser, args []interface{}) []byte { var namedArgs NamedArgAppender if len(args) == 1 { - var ok bool - namedArgs, ok = args[0].(NamedArgAppender) - if !ok { - namedArgs, _ = newStructArgs(f, args[0]) + if v, ok := args[0].(NamedArgAppender); ok { + namedArgs = v + } else if v, ok := newStructArgs(f, args[0]); ok { + namedArgs = v } } diff --git a/vendor/github.com/uptrace/bun/schema/util.go b/vendor/github.com/uptrace/bun/schema/reflect.go index 6d474e4cc..5b20b1964 100644 --- a/vendor/github.com/uptrace/bun/schema/util.go +++ b/vendor/github.com/uptrace/bun/schema/reflect.go @@ -1,6 +1,23 @@ package schema -import "reflect" +import ( + "database/sql/driver" + "encoding/json" + "net" + "reflect" + "time" +) + +var ( + bytesType = reflect.TypeOf((*[]byte)(nil)).Elem() + timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + ipType = reflect.TypeOf((*net.IP)(nil)).Elem() + ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() + jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() + + driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem() +) func indirectType(t reflect.Type) reflect.Type { if t.Kind() == reflect.Ptr { diff --git a/vendor/github.com/uptrace/bun/schema/scan.go b/vendor/github.com/uptrace/bun/schema/scan.go index 0e66a860f..85ba62a01 100644 --- a/vendor/github.com/uptrace/bun/schema/scan.go +++ b/vendor/github.com/uptrace/bun/schema/scan.go @@ -7,10 +7,12 @@ import ( "net" "reflect" "strconv" + "strings" "time" "github.com/vmihailenco/msgpack/v5" + "github.com/uptrace/bun/dialect/sqltype" "github.com/uptrace/bun/extra/bunjson" "github.com/uptrace/bun/internal" ) @@ -19,32 +21,35 @@ var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() type ScannerFunc func(dest reflect.Value, src interface{}) error -var scanners = []ScannerFunc{ - reflect.Bool: scanBool, - reflect.Int: scanInt64, - reflect.Int8: scanInt64, - reflect.Int16: scanInt64, - reflect.Int32: scanInt64, - reflect.Int64: scanInt64, - reflect.Uint: scanUint64, - reflect.Uint8: scanUint64, - reflect.Uint16: scanUint64, - reflect.Uint32: scanUint64, - reflect.Uint64: scanUint64, - reflect.Uintptr: scanUint64, - reflect.Float32: scanFloat64, - reflect.Float64: scanFloat64, - reflect.Complex64: nil, - reflect.Complex128: nil, - reflect.Array: nil, - reflect.Chan: nil, - reflect.Func: nil, - reflect.Map: scanJSON, - reflect.Ptr: nil, - reflect.Slice: scanJSON, - reflect.String: scanString, - reflect.Struct: scanJSON, - reflect.UnsafePointer: nil, +var scanners []ScannerFunc + +func init() { + scanners = []ScannerFunc{ + reflect.Bool: scanBool, + reflect.Int: scanInt64, + reflect.Int8: scanInt64, + reflect.Int16: scanInt64, + reflect.Int32: scanInt64, + reflect.Int64: scanInt64, + reflect.Uint: scanUint64, + reflect.Uint8: scanUint64, + reflect.Uint16: scanUint64, + reflect.Uint32: scanUint64, + reflect.Uint64: scanUint64, + reflect.Uintptr: scanUint64, + reflect.Float32: scanFloat64, + reflect.Float64: scanFloat64, + reflect.Complex64: nil, + reflect.Complex128: nil, + reflect.Array: nil, + reflect.Interface: scanInterface, + reflect.Map: scanJSON, + reflect.Ptr: nil, + reflect.Slice: scanJSON, + reflect.String: scanString, + reflect.Struct: scanJSON, + reflect.UnsafePointer: nil, + } } func FieldScanner(dialect Dialect, field *Field) ScannerFunc { @@ -54,6 +59,12 @@ func FieldScanner(dialect Dialect, field *Field) ScannerFunc { if field.Tag.HasOption("json_use_number") { return scanJSONUseNumber } + if field.StructField.Type.Kind() == reflect.Interface { + switch strings.ToUpper(field.UserSQLType) { + case sqltype.JSON, sqltype.JSONB: + return scanJSONIntoInterface + } + } return dialect.Scanner(field.StructField.Type) } @@ -62,7 +73,7 @@ func Scanner(typ reflect.Type) ScannerFunc { if kind == reflect.Ptr { if fn := Scanner(typ.Elem()); fn != nil { - return ptrScanner(fn) + return PtrScanner(fn) } } @@ -84,6 +95,8 @@ func Scanner(typ reflect.Type) ScannerFunc { return scanIP case ipNetType: return scanIPNet + case bytesType: + return scanBytes case jsonRawMessageType: return scanJSONRawMessage } @@ -196,6 +209,21 @@ func scanString(dest reflect.Value, src interface{}) error { return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) } +func scanBytes(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetBytes(nil) + return nil + case string: + dest.SetBytes([]byte(src)) + return nil + case []byte: + dest.SetBytes(src) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + func scanTime(dest reflect.Value, src interface{}) error { switch src := src.(type) { case nil: @@ -352,7 +380,7 @@ func toBytes(src interface{}) ([]byte, error) { } } -func ptrScanner(fn ScannerFunc) ScannerFunc { +func PtrScanner(fn ScannerFunc) ScannerFunc { return func(dest reflect.Value, src interface{}) error { if src == nil { if !dest.CanAddr() { @@ -383,6 +411,43 @@ func scanNull(dest reflect.Value) error { return nil } +func scanJSONIntoInterface(dest reflect.Value, src interface{}) error { + if dest.IsNil() { + if src == nil { + return nil + } + + b, err := toBytes(src) + if err != nil { + return err + } + + return bunjson.Unmarshal(b, dest.Addr().Interface()) + } + + dest = dest.Elem() + if fn := Scanner(dest.Type()); fn != nil { + return fn(dest, src) + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanInterface(dest reflect.Value, src interface{}) error { + if dest.IsNil() { + if src == nil { + return nil + } + dest.Set(reflect.ValueOf(src)) + return nil + } + + dest = dest.Elem() + if fn := Scanner(dest.Type()); fn != nil { + return fn(dest, src) + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + func nilable(kind reflect.Kind) bool { switch kind { case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: diff --git a/vendor/github.com/uptrace/bun/schema/sqlfmt.go b/vendor/github.com/uptrace/bun/schema/sqlfmt.go index 7b538cd0c..bbdb0a01f 100644 --- a/vendor/github.com/uptrace/bun/schema/sqlfmt.go +++ b/vendor/github.com/uptrace/bun/schema/sqlfmt.go @@ -40,7 +40,7 @@ type QueryWithArgs struct { var _ QueryAppender = QueryWithArgs{} func SafeQuery(query string, args []interface{}) QueryWithArgs { - if query != "" && args == nil { + if args == nil { args = make([]interface{}, 0) } return QueryWithArgs{Query: query, Args: args} diff --git a/vendor/github.com/uptrace/bun/schema/sqltype.go b/vendor/github.com/uptrace/bun/schema/sqltype.go index 560f695c2..23322a1e1 100644 --- a/vendor/github.com/uptrace/bun/schema/sqltype.go +++ b/vendor/github.com/uptrace/bun/schema/sqltype.go @@ -23,32 +23,29 @@ var ( ) var sqlTypes = []string{ - reflect.Bool: sqltype.Boolean, - reflect.Int: sqltype.BigInt, - reflect.Int8: sqltype.SmallInt, - reflect.Int16: sqltype.SmallInt, - reflect.Int32: sqltype.Integer, - reflect.Int64: sqltype.BigInt, - reflect.Uint: sqltype.BigInt, - reflect.Uint8: sqltype.SmallInt, - reflect.Uint16: sqltype.SmallInt, - reflect.Uint32: sqltype.Integer, - reflect.Uint64: sqltype.BigInt, - reflect.Uintptr: sqltype.BigInt, - reflect.Float32: sqltype.Real, - reflect.Float64: sqltype.DoublePrecision, - reflect.Complex64: "", - reflect.Complex128: "", - reflect.Array: "", - reflect.Chan: "", - reflect.Func: "", - reflect.Interface: "", - reflect.Map: sqltype.VarChar, - reflect.Ptr: "", - reflect.Slice: sqltype.VarChar, - reflect.String: sqltype.VarChar, - reflect.Struct: sqltype.VarChar, - reflect.UnsafePointer: "", + reflect.Bool: sqltype.Boolean, + reflect.Int: sqltype.BigInt, + reflect.Int8: sqltype.SmallInt, + reflect.Int16: sqltype.SmallInt, + reflect.Int32: sqltype.Integer, + reflect.Int64: sqltype.BigInt, + reflect.Uint: sqltype.BigInt, + reflect.Uint8: sqltype.SmallInt, + reflect.Uint16: sqltype.SmallInt, + reflect.Uint32: sqltype.Integer, + reflect.Uint64: sqltype.BigInt, + reflect.Uintptr: sqltype.BigInt, + reflect.Float32: sqltype.Real, + reflect.Float64: sqltype.DoublePrecision, + reflect.Complex64: "", + reflect.Complex128: "", + reflect.Array: "", + reflect.Interface: "", + reflect.Map: sqltype.VarChar, + reflect.Ptr: "", + reflect.Slice: sqltype.VarChar, + reflect.String: sqltype.VarChar, + reflect.Struct: sqltype.VarChar, } func DiscoverSQLType(typ reflect.Type) string { diff --git a/vendor/github.com/uptrace/bun/schema/table.go b/vendor/github.com/uptrace/bun/schema/table.go index eca18b781..7498a2bc8 100644 --- a/vendor/github.com/uptrace/bun/schema/table.go +++ b/vendor/github.com/uptrace/bun/schema/table.go @@ -60,10 +60,9 @@ type Table struct { Unique map[string][]*Field SoftDeleteField *Field - UpdateSoftDeleteField func(fv reflect.Value) error + UpdateSoftDeleteField func(fv reflect.Value, tm time.Time) error - allFields []*Field // read only - skippedFields []*Field + allFields []*Field // read only flags internal.Flag } @@ -104,9 +103,7 @@ func (t *Table) init1() { } func (t *Table) init2() { - t.initInlines() t.initRelations() - t.skippedFields = nil } func (t *Table) setName(name string) { @@ -207,15 +204,20 @@ func (t *Table) initFields() { func (t *Table) addFields(typ reflect.Type, baseIndex []int) { for i := 0; i < typ.NumField(); i++ { f := typ.Field(i) + unexported := f.PkgPath != "" - // Make a copy so slice is not shared between fields. + if unexported && !f.Anonymous { // unexported + continue + } + if f.Tag.Get("bun") == "-" { + continue + } + + // Make a copy so the slice is not shared between fields. index := make([]int, len(baseIndex)) copy(index, baseIndex) if f.Anonymous { - if f.Tag.Get("bun") == "-" { - continue - } if f.Name == "BaseModel" && f.Type == baseModelType { if len(index) == 0 { t.processBaseModelField(f) @@ -243,8 +245,7 @@ func (t *Table) addFields(typ reflect.Type, baseIndex []int) { continue } - field := t.newField(f, index) - if field != nil { + if field := t.newField(f, index); field != nil { t.addField(field) } } @@ -284,11 +285,10 @@ func (t *Table) processBaseModelField(f reflect.StructField) { func (t *Table) newField(f reflect.StructField, index []int) *Field { tag := tagparser.Parse(f.Tag.Get("bun")) - if f.PkgPath != "" { - return nil - } - sqlName := internal.Underscore(f.Name) + if tag.Name != "" { + sqlName = tag.Name + } if tag.Name != sqlName && isKnownFieldOption(tag.Name) { internal.Warn.Printf( @@ -303,11 +303,6 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field { } } - skip := tag.Name == "-" - if !skip && tag.Name != "" { - sqlName = tag.Name - } - index = append(index, f.Index...) if field := t.fieldWithLock(sqlName); field != nil { if indexEqual(field.Index, index) { @@ -371,9 +366,11 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field { } t.allFields = append(t.allFields, field) - if skip { - t.skippedFields = append(t.skippedFields, field) + if tag.HasOption("scanonly") { t.FieldMap[field.Name] = field + if field.IndirectType.Kind() == reflect.Struct { + t.inlineFields(field, nil) + } return nil } @@ -386,14 +383,6 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field { return field } -func (t *Table) initInlines() { - for _, f := range t.skippedFields { - if f.IndirectType.Kind() == reflect.Struct { - t.inlineFields(f, nil) - } - } -} - //--------------------------------------------------------------------------------------- func (t *Table) initRelations() { @@ -745,17 +734,15 @@ func (t *Table) m2mRelation(field *Field) *Relation { return rel } -func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) { - if path == nil { - path = map[reflect.Type]struct{}{ - t.Type: {}, - } +func (t *Table) inlineFields(field *Field, seen map[reflect.Type]struct{}) { + if seen == nil { + seen = map[reflect.Type]struct{}{t.Type: {}} } - if _, ok := path[field.IndirectType]; ok { + if _, ok := seen[field.IndirectType]; ok { return } - path[field.IndirectType] = struct{}{} + seen[field.IndirectType] = struct{}{} joinTable := t.dialect.Tables().Ref(field.IndirectType) for _, f := range joinTable.allFields { @@ -775,18 +762,15 @@ func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) { continue } - if _, ok := path[f.IndirectType]; !ok { - t.inlineFields(f, path) + if _, ok := seen[f.IndirectType]; !ok { + t.inlineFields(f, seen) } } } //------------------------------------------------------------------------------ -func (t *Table) Dialect() Dialect { return t.dialect } - -//------------------------------------------------------------------------------ - +func (t *Table) Dialect() Dialect { return t.dialect } func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) } func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) } @@ -845,6 +829,7 @@ func isKnownFieldOption(name string) bool { "default", "unique", "soft_delete", + "scanonly", "pk", "autoincrement", @@ -883,35 +868,35 @@ func parseRelationJoin(join string) ([]string, []string) { //------------------------------------------------------------------------------ -func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error { +func softDeleteFieldUpdater(field *Field) func(fv reflect.Value, tm time.Time) error { typ := field.StructField.Type switch typ { case timeType: - return func(fv reflect.Value) error { + return func(fv reflect.Value, tm time.Time) error { ptr := fv.Addr().Interface().(*time.Time) - *ptr = time.Now() + *ptr = tm return nil } case nullTimeType: - return func(fv reflect.Value) error { + return func(fv reflect.Value, tm time.Time) error { ptr := fv.Addr().Interface().(*sql.NullTime) - *ptr = sql.NullTime{Time: time.Now()} + *ptr = sql.NullTime{Time: tm} return nil } case nullIntType: - return func(fv reflect.Value) error { + return func(fv reflect.Value, tm time.Time) error { ptr := fv.Addr().Interface().(*sql.NullInt64) - *ptr = sql.NullInt64{Int64: time.Now().UnixNano()} + *ptr = sql.NullInt64{Int64: tm.UnixNano()} return nil } } switch field.IndirectType.Kind() { case reflect.Int64: - return func(fv reflect.Value) error { + return func(fv reflect.Value, tm time.Time) error { ptr := fv.Addr().Interface().(*int64) - *ptr = time.Now().UnixNano() + *ptr = tm.UnixNano() return nil } case reflect.Ptr: @@ -922,17 +907,16 @@ func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error { switch typ { //nolint:gocritic case timeType: - return func(fv reflect.Value) error { - now := time.Now() - fv.Set(reflect.ValueOf(&now)) + return func(fv reflect.Value, tm time.Time) error { + fv.Set(reflect.ValueOf(&tm)) return nil } } switch typ.Kind() { //nolint:gocritic case reflect.Int64: - return func(fv reflect.Value) error { - utime := time.Now().UnixNano() + return func(fv reflect.Value, tm time.Time) error { + utime := tm.UnixNano() fv.Set(reflect.ValueOf(&utime)) return nil } @@ -941,8 +925,8 @@ func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error { return softDeleteFieldUpdaterFallback(field) } -func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value) error { - return func(fv reflect.Value) error { - return field.ScanWithCheck(fv, time.Now()) +func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value, tm time.Time) error { + return func(fv reflect.Value, tm time.Time) error { + return field.ScanWithCheck(fv, tm) } } diff --git a/vendor/github.com/uptrace/bun/schema/tables.go b/vendor/github.com/uptrace/bun/schema/tables.go index d82d08f59..4be856b34 100644 --- a/vendor/github.com/uptrace/bun/schema/tables.go +++ b/vendor/github.com/uptrace/bun/schema/tables.go @@ -67,6 +67,7 @@ func (t *Tables) Ref(typ reflect.Type) *Table { } func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table { + typ = indirectType(typ) if typ.Kind() != reflect.Struct { panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct)) } |