diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema/scan.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/schema/scan.go | 121 |
1 files changed, 93 insertions, 28 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/scan.go b/vendor/github.com/uptrace/bun/schema/scan.go index 0e66a860f..85ba62a01 100644 --- a/vendor/github.com/uptrace/bun/schema/scan.go +++ b/vendor/github.com/uptrace/bun/schema/scan.go @@ -7,10 +7,12 @@ import ( "net" "reflect" "strconv" + "strings" "time" "github.com/vmihailenco/msgpack/v5" + "github.com/uptrace/bun/dialect/sqltype" "github.com/uptrace/bun/extra/bunjson" "github.com/uptrace/bun/internal" ) @@ -19,32 +21,35 @@ var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() type ScannerFunc func(dest reflect.Value, src interface{}) error -var scanners = []ScannerFunc{ - reflect.Bool: scanBool, - reflect.Int: scanInt64, - reflect.Int8: scanInt64, - reflect.Int16: scanInt64, - reflect.Int32: scanInt64, - reflect.Int64: scanInt64, - reflect.Uint: scanUint64, - reflect.Uint8: scanUint64, - reflect.Uint16: scanUint64, - reflect.Uint32: scanUint64, - reflect.Uint64: scanUint64, - reflect.Uintptr: scanUint64, - reflect.Float32: scanFloat64, - reflect.Float64: scanFloat64, - reflect.Complex64: nil, - reflect.Complex128: nil, - reflect.Array: nil, - reflect.Chan: nil, - reflect.Func: nil, - reflect.Map: scanJSON, - reflect.Ptr: nil, - reflect.Slice: scanJSON, - reflect.String: scanString, - reflect.Struct: scanJSON, - reflect.UnsafePointer: nil, +var scanners []ScannerFunc + +func init() { + scanners = []ScannerFunc{ + reflect.Bool: scanBool, + reflect.Int: scanInt64, + reflect.Int8: scanInt64, + reflect.Int16: scanInt64, + reflect.Int32: scanInt64, + reflect.Int64: scanInt64, + reflect.Uint: scanUint64, + reflect.Uint8: scanUint64, + reflect.Uint16: scanUint64, + reflect.Uint32: scanUint64, + reflect.Uint64: scanUint64, + reflect.Uintptr: scanUint64, + reflect.Float32: scanFloat64, + reflect.Float64: scanFloat64, + reflect.Complex64: nil, + reflect.Complex128: nil, + reflect.Array: nil, + reflect.Interface: scanInterface, + reflect.Map: scanJSON, + reflect.Ptr: nil, + reflect.Slice: scanJSON, + reflect.String: scanString, + reflect.Struct: scanJSON, + reflect.UnsafePointer: nil, + } } func FieldScanner(dialect Dialect, field *Field) ScannerFunc { @@ -54,6 +59,12 @@ func FieldScanner(dialect Dialect, field *Field) ScannerFunc { if field.Tag.HasOption("json_use_number") { return scanJSONUseNumber } + if field.StructField.Type.Kind() == reflect.Interface { + switch strings.ToUpper(field.UserSQLType) { + case sqltype.JSON, sqltype.JSONB: + return scanJSONIntoInterface + } + } return dialect.Scanner(field.StructField.Type) } @@ -62,7 +73,7 @@ func Scanner(typ reflect.Type) ScannerFunc { if kind == reflect.Ptr { if fn := Scanner(typ.Elem()); fn != nil { - return ptrScanner(fn) + return PtrScanner(fn) } } @@ -84,6 +95,8 @@ func Scanner(typ reflect.Type) ScannerFunc { return scanIP case ipNetType: return scanIPNet + case bytesType: + return scanBytes case jsonRawMessageType: return scanJSONRawMessage } @@ -196,6 +209,21 @@ func scanString(dest reflect.Value, src interface{}) error { return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) } +func scanBytes(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetBytes(nil) + return nil + case string: + dest.SetBytes([]byte(src)) + return nil + case []byte: + dest.SetBytes(src) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + func scanTime(dest reflect.Value, src interface{}) error { switch src := src.(type) { case nil: @@ -352,7 +380,7 @@ func toBytes(src interface{}) ([]byte, error) { } } -func ptrScanner(fn ScannerFunc) ScannerFunc { +func PtrScanner(fn ScannerFunc) ScannerFunc { return func(dest reflect.Value, src interface{}) error { if src == nil { if !dest.CanAddr() { @@ -383,6 +411,43 @@ func scanNull(dest reflect.Value) error { return nil } +func scanJSONIntoInterface(dest reflect.Value, src interface{}) error { + if dest.IsNil() { + if src == nil { + return nil + } + + b, err := toBytes(src) + if err != nil { + return err + } + + return bunjson.Unmarshal(b, dest.Addr().Interface()) + } + + dest = dest.Elem() + if fn := Scanner(dest.Type()); fn != nil { + return fn(dest, src) + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanInterface(dest reflect.Value, src interface{}) error { + if dest.IsNil() { + if src == nil { + return nil + } + dest.Set(reflect.ValueOf(src)) + return nil + } + + dest = dest.Elem() + if fn := Scanner(dest.Type()); fn != nil { + return fn(dest, src) + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + func nilable(kind reflect.Kind) bool { switch kind { case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: |