diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema/scan.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/schema/scan.go | 392 |
1 files changed, 392 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/scan.go b/vendor/github.com/uptrace/bun/schema/scan.go new file mode 100644 index 000000000..0e66a860f --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/scan.go @@ -0,0 +1,392 @@ +package schema + +import ( + "bytes" + "database/sql" + "fmt" + "net" + "reflect" + "strconv" + "time" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/uptrace/bun/extra/bunjson" + "github.com/uptrace/bun/internal" +) + +var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +type ScannerFunc func(dest reflect.Value, src interface{}) error + +var scanners = []ScannerFunc{ + reflect.Bool: scanBool, + reflect.Int: scanInt64, + reflect.Int8: scanInt64, + reflect.Int16: scanInt64, + reflect.Int32: scanInt64, + reflect.Int64: scanInt64, + reflect.Uint: scanUint64, + reflect.Uint8: scanUint64, + reflect.Uint16: scanUint64, + reflect.Uint32: scanUint64, + reflect.Uint64: scanUint64, + reflect.Uintptr: scanUint64, + reflect.Float32: scanFloat64, + reflect.Float64: scanFloat64, + reflect.Complex64: nil, + reflect.Complex128: nil, + reflect.Array: nil, + reflect.Chan: nil, + reflect.Func: nil, + reflect.Map: scanJSON, + reflect.Ptr: nil, + reflect.Slice: scanJSON, + reflect.String: scanString, + reflect.Struct: scanJSON, + reflect.UnsafePointer: nil, +} + +func FieldScanner(dialect Dialect, field *Field) ScannerFunc { + if field.Tag.HasOption("msgpack") { + return scanMsgpack + } + if field.Tag.HasOption("json_use_number") { + return scanJSONUseNumber + } + return dialect.Scanner(field.StructField.Type) +} + +func Scanner(typ reflect.Type) ScannerFunc { + kind := typ.Kind() + + if kind == reflect.Ptr { + if fn := Scanner(typ.Elem()); fn != nil { + return ptrScanner(fn) + } + } + + if typ.Implements(scannerType) { + return scanScanner + } + + if kind != reflect.Ptr { + ptr := reflect.PtrTo(typ) + if ptr.Implements(scannerType) { + return addrScanner(scanScanner) + } + } + + switch typ { + case timeType: + return scanTime + case ipType: + return scanIP + case ipNetType: + return scanIPNet + case jsonRawMessageType: + return scanJSONRawMessage + } + + return scanners[kind] +} + +func scanBool(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetBool(false) + return nil + case bool: + dest.SetBool(src) + return nil + case int64: + dest.SetBool(src != 0) + return nil + case []byte: + if len(src) == 1 { + dest.SetBool(src[0] != '0') + return nil + } + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanInt64(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetInt(0) + return nil + case int64: + dest.SetInt(src) + return nil + case uint64: + dest.SetInt(int64(src)) + return nil + case []byte: + n, err := strconv.ParseInt(internal.String(src), 10, 64) + if err != nil { + return err + } + dest.SetInt(n) + return nil + case string: + n, err := strconv.ParseInt(src, 10, 64) + if err != nil { + return err + } + dest.SetInt(n) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanUint64(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetUint(0) + return nil + case uint64: + dest.SetUint(src) + return nil + case int64: + dest.SetUint(uint64(src)) + return nil + case []byte: + n, err := strconv.ParseUint(internal.String(src), 10, 64) + if err != nil { + return err + } + dest.SetUint(n) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanFloat64(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetFloat(0) + return nil + case float64: + dest.SetFloat(src) + return nil + case []byte: + f, err := strconv.ParseFloat(internal.String(src), 64) + if err != nil { + return err + } + dest.SetFloat(f) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanString(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetString("") + return nil + case string: + dest.SetString(src) + return nil + case []byte: + dest.SetString(string(src)) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanTime(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + destTime := dest.Addr().Interface().(*time.Time) + *destTime = time.Time{} + return nil + case time.Time: + destTime := dest.Addr().Interface().(*time.Time) + *destTime = src + return nil + case string: + srcTime, err := internal.ParseTime(src) + if err != nil { + return err + } + destTime := dest.Addr().Interface().(*time.Time) + *destTime = srcTime + return nil + case []byte: + srcTime, err := internal.ParseTime(internal.String(src)) + if err != nil { + return err + } + destTime := dest.Addr().Interface().(*time.Time) + *destTime = srcTime + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanScanner(dest reflect.Value, src interface{}) error { + return dest.Interface().(sql.Scanner).Scan(src) +} + +func scanMsgpack(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + dec := msgpack.GetDecoder() + defer msgpack.PutDecoder(dec) + + dec.Reset(bytes.NewReader(b)) + return dec.DecodeValue(dest) +} + +func scanJSON(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + return bunjson.Unmarshal(b, dest.Addr().Interface()) +} + +func scanJSONUseNumber(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + dec := bunjson.NewDecoder(bytes.NewReader(b)) + dec.UseNumber() + return dec.Decode(dest.Addr().Interface()) +} + +func scanIP(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + ip := net.ParseIP(internal.String(b)) + if ip == nil { + return fmt.Errorf("bun: invalid ip: %q", b) + } + + ptr := dest.Addr().Interface().(*net.IP) + *ptr = ip + + return nil +} + +func scanIPNet(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + _, ipnet, err := net.ParseCIDR(internal.String(b)) + if err != nil { + return err + } + + ptr := dest.Addr().Interface().(*net.IPNet) + *ptr = *ipnet + + return nil +} + +func scanJSONRawMessage(dest reflect.Value, src interface{}) error { + if src == nil { + dest.SetBytes(nil) + return nil + } + + b, err := toBytes(src) + if err != nil { + return err + } + + dest.SetBytes(b) + return nil +} + +func addrScanner(fn ScannerFunc) ScannerFunc { + return func(dest reflect.Value, src interface{}) error { + if !dest.CanAddr() { + return fmt.Errorf("bun: Scan(nonaddressable %T)", dest.Interface()) + } + return fn(dest.Addr(), src) + } +} + +func toBytes(src interface{}) ([]byte, error) { + switch src := src.(type) { + case string: + return internal.Bytes(src), nil + case []byte: + return src, nil + default: + return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src) + } +} + +func ptrScanner(fn ScannerFunc) ScannerFunc { + return func(dest reflect.Value, src interface{}) error { + if src == nil { + if !dest.CanAddr() { + if dest.IsNil() { + return nil + } + return fn(dest.Elem(), src) + } + + if !dest.IsNil() { + dest.Set(reflect.New(dest.Type().Elem())) + } + return nil + } + + if dest.IsNil() { + dest.Set(reflect.New(dest.Type().Elem())) + } + return fn(dest.Elem(), src) + } +} + +func scanNull(dest reflect.Value) error { + if nilable(dest.Kind()) && dest.IsNil() { + return nil + } + dest.Set(reflect.New(dest.Type()).Elem()) + return nil +} + +func nilable(kind reflect.Kind) bool { + switch kind { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return true + } + return false +} |