diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema/scan.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/schema/scan.go | 519 |
1 files changed, 0 insertions, 519 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/scan.go b/vendor/github.com/uptrace/bun/schema/scan.go deleted file mode 100644 index 9db46cd6f..000000000 --- a/vendor/github.com/uptrace/bun/schema/scan.go +++ /dev/null @@ -1,519 +0,0 @@ -package schema - -import ( - "bytes" - "database/sql" - "fmt" - "net" - "reflect" - "strconv" - "strings" - "time" - - "github.com/puzpuzpuz/xsync/v3" - "github.com/vmihailenco/msgpack/v5" - - "github.com/uptrace/bun/dialect/sqltype" - "github.com/uptrace/bun/extra/bunjson" - "github.com/uptrace/bun/internal" -) - -var scannerType = reflect.TypeFor[sql.Scanner]() - -type ScannerFunc func(dest reflect.Value, src interface{}) error - -var scanners []ScannerFunc - -func init() { - scanners = []ScannerFunc{ - reflect.Bool: scanBool, - reflect.Int: scanInt64, - reflect.Int8: scanInt64, - reflect.Int16: scanInt64, - reflect.Int32: scanInt64, - reflect.Int64: scanInt64, - reflect.Uint: scanUint64, - reflect.Uint8: scanUint64, - reflect.Uint16: scanUint64, - reflect.Uint32: scanUint64, - reflect.Uint64: scanUint64, - reflect.Uintptr: scanUint64, - reflect.Float32: scanFloat, - reflect.Float64: scanFloat, - reflect.Complex64: nil, - reflect.Complex128: nil, - reflect.Array: nil, - reflect.Interface: scanInterface, - reflect.Map: scanJSON, - reflect.Ptr: nil, - reflect.Slice: scanJSON, - reflect.String: scanString, - reflect.Struct: scanJSON, - reflect.UnsafePointer: nil, - } -} - -var scannerCache = xsync.NewMapOf[reflect.Type, ScannerFunc]() - -func FieldScanner(dialect Dialect, field *Field) ScannerFunc { - if field.Tag.HasOption("msgpack") { - return scanMsgpack - } - if field.Tag.HasOption("json_use_number") { - return scanJSONUseNumber - } - if field.StructField.Type.Kind() == reflect.Interface { - switch strings.ToUpper(field.UserSQLType) { - case sqltype.JSON, sqltype.JSONB: - return scanJSONIntoInterface - } - } - return Scanner(field.StructField.Type) -} - -func Scanner(typ reflect.Type) ScannerFunc { - if v, ok := scannerCache.Load(typ); ok { - return v - } - - fn := scanner(typ) - - if v, ok := scannerCache.LoadOrStore(typ, fn); ok { - return v - } - return fn -} - -func scanner(typ reflect.Type) ScannerFunc { - kind := typ.Kind() - - if kind == reflect.Ptr { - if fn := Scanner(typ.Elem()); fn != nil { - return PtrScanner(fn) - } - } - - switch typ { - case bytesType: - return scanBytes - case timeType: - return scanTime - case ipType: - return scanIP - case ipNetType: - return scanIPNet - case jsonRawMessageType: - return scanBytes - } - - if typ.Implements(scannerType) { - return scanScanner - } - - if kind != reflect.Ptr { - ptr := reflect.PointerTo(typ) - if ptr.Implements(scannerType) { - return addrScanner(scanScanner) - } - } - - if typ.Kind() == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 { - return scanBytes - } - - return scanners[kind] -} - -func scanBool(dest reflect.Value, src interface{}) error { - switch src := src.(type) { - case nil: - dest.SetBool(false) - return nil - case bool: - dest.SetBool(src) - return nil - case int64: - dest.SetBool(src != 0) - return nil - case []byte: - f, err := strconv.ParseBool(internal.String(src)) - if err != nil { - return err - } - dest.SetBool(f) - return nil - case string: - f, err := strconv.ParseBool(src) - if err != nil { - return err - } - dest.SetBool(f) - return nil - default: - return scanError(dest.Type(), src) - } -} - -func scanInt64(dest reflect.Value, src interface{}) error { - switch src := src.(type) { - case nil: - dest.SetInt(0) - return nil - case int64: - dest.SetInt(src) - return nil - case uint64: - dest.SetInt(int64(src)) - return nil - case []byte: - n, err := strconv.ParseInt(internal.String(src), 10, 64) - if err != nil { - return err - } - dest.SetInt(n) - return nil - case string: - n, err := strconv.ParseInt(src, 10, 64) - if err != nil { - return err - } - dest.SetInt(n) - return nil - default: - return scanError(dest.Type(), src) - } -} - -func scanUint64(dest reflect.Value, src interface{}) error { - switch src := src.(type) { - case nil: - dest.SetUint(0) - return nil - case uint64: - dest.SetUint(src) - return nil - case int64: - dest.SetUint(uint64(src)) - return nil - case []byte: - n, err := strconv.ParseUint(internal.String(src), 10, 64) - if err != nil { - return err - } - dest.SetUint(n) - return nil - case string: - n, err := strconv.ParseUint(src, 10, 64) - if err != nil { - return err - } - dest.SetUint(n) - return nil - default: - return scanError(dest.Type(), src) - } -} - -func scanFloat(dest reflect.Value, src interface{}) error { - switch src := src.(type) { - case nil: - dest.SetFloat(0) - return nil - case float32: - dest.SetFloat(float64(src)) - return nil - case float64: - dest.SetFloat(src) - return nil - case []byte: - f, err := strconv.ParseFloat(internal.String(src), 64) - if err != nil { - return err - } - dest.SetFloat(f) - return nil - case string: - f, err := strconv.ParseFloat(src, 64) - if err != nil { - return err - } - dest.SetFloat(f) - return nil - default: - return scanError(dest.Type(), src) - } -} - -func scanString(dest reflect.Value, src interface{}) error { - switch src := src.(type) { - case nil: - dest.SetString("") - return nil - case string: - dest.SetString(src) - return nil - case []byte: - dest.SetString(string(src)) - return nil - case time.Time: - dest.SetString(src.Format(time.RFC3339Nano)) - return nil - case int64: - dest.SetString(strconv.FormatInt(src, 10)) - return nil - case uint64: - dest.SetString(strconv.FormatUint(src, 10)) - return nil - case float64: - dest.SetString(strconv.FormatFloat(src, 'G', -1, 64)) - return nil - default: - return scanError(dest.Type(), src) - } -} - -func scanBytes(dest reflect.Value, src interface{}) error { - switch src := src.(type) { - case nil: - dest.SetBytes(nil) - return nil - case string: - dest.SetBytes([]byte(src)) - return nil - case []byte: - clone := make([]byte, len(src)) - copy(clone, src) - - dest.SetBytes(clone) - return nil - default: - return scanError(dest.Type(), src) - } -} - -func scanTime(dest reflect.Value, src interface{}) error { - switch src := src.(type) { - case nil: - destTime := dest.Addr().Interface().(*time.Time) - *destTime = time.Time{} - return nil - case time.Time: - destTime := dest.Addr().Interface().(*time.Time) - *destTime = src - return nil - case string: - srcTime, err := internal.ParseTime(src) - if err != nil { - return err - } - destTime := dest.Addr().Interface().(*time.Time) - *destTime = srcTime - return nil - case []byte: - srcTime, err := internal.ParseTime(internal.String(src)) - if err != nil { - return err - } - destTime := dest.Addr().Interface().(*time.Time) - *destTime = srcTime - return nil - default: - return scanError(dest.Type(), src) - } -} - -func scanScanner(dest reflect.Value, src interface{}) error { - return dest.Interface().(sql.Scanner).Scan(src) -} - -func scanMsgpack(dest reflect.Value, src interface{}) error { - if src == nil { - return scanNull(dest) - } - - b, err := toBytes(src) - if err != nil { - return err - } - - dec := msgpack.GetDecoder() - defer msgpack.PutDecoder(dec) - - dec.Reset(bytes.NewReader(b)) - return dec.DecodeValue(dest) -} - -func scanJSON(dest reflect.Value, src interface{}) error { - if src == nil { - return scanNull(dest) - } - - b, err := toBytes(src) - if err != nil { - return err - } - - return bunjson.Unmarshal(b, dest.Addr().Interface()) -} - -func scanJSONUseNumber(dest reflect.Value, src interface{}) error { - if src == nil { - return scanNull(dest) - } - - b, err := toBytes(src) - if err != nil { - return err - } - - dec := bunjson.NewDecoder(bytes.NewReader(b)) - dec.UseNumber() - return dec.Decode(dest.Addr().Interface()) -} - -func scanIP(dest reflect.Value, src interface{}) error { - if src == nil { - return scanNull(dest) - } - - b, err := toBytes(src) - if err != nil { - return err - } - - ip := net.ParseIP(internal.String(b)) - if ip == nil { - return fmt.Errorf("bun: invalid ip: %q", b) - } - - ptr := dest.Addr().Interface().(*net.IP) - *ptr = ip - - return nil -} - -func scanIPNet(dest reflect.Value, src interface{}) error { - if src == nil { - return scanNull(dest) - } - - b, err := toBytes(src) - if err != nil { - return err - } - - _, ipnet, err := net.ParseCIDR(internal.String(b)) - if err != nil { - return err - } - - ptr := dest.Addr().Interface().(*net.IPNet) - *ptr = *ipnet - - return nil -} - -func addrScanner(fn ScannerFunc) ScannerFunc { - return func(dest reflect.Value, src interface{}) error { - if !dest.CanAddr() { - return fmt.Errorf("bun: Scan(nonaddressable %T)", dest.Interface()) - } - return fn(dest.Addr(), src) - } -} - -func toBytes(src interface{}) ([]byte, error) { - switch src := src.(type) { - case string: - return internal.Bytes(src), nil - case []byte: - return src, nil - default: - return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src) - } -} - -func PtrScanner(fn ScannerFunc) ScannerFunc { - return func(dest reflect.Value, src interface{}) error { - if src == nil { - if !dest.CanAddr() { - if dest.IsNil() { - return nil - } - return fn(dest.Elem(), src) - } - - if !dest.IsNil() { - dest.Set(reflect.New(dest.Type().Elem())) - } - return nil - } - - if dest.IsNil() { - dest.Set(reflect.New(dest.Type().Elem())) - } - - if dest.Kind() == reflect.Map { - return fn(dest, src) - } - - return fn(dest.Elem(), src) - } -} - -func scanNull(dest reflect.Value) error { - if nilable(dest.Kind()) && dest.IsNil() { - return nil - } - dest.Set(reflect.New(dest.Type()).Elem()) - return nil -} - -func scanJSONIntoInterface(dest reflect.Value, src interface{}) error { - if dest.IsNil() { - if src == nil { - return nil - } - - b, err := toBytes(src) - if err != nil { - return err - } - - return bunjson.Unmarshal(b, dest.Addr().Interface()) - } - - dest = dest.Elem() - if fn := Scanner(dest.Type()); fn != nil { - return fn(dest, src) - } - return scanError(dest.Type(), src) -} - -func scanInterface(dest reflect.Value, src interface{}) error { - if dest.IsNil() { - if src == nil { - return nil - } - dest.Set(reflect.ValueOf(src)) - return nil - } - - dest = dest.Elem() - if fn := Scanner(dest.Type()); fn != nil { - return fn(dest, src) - } - return scanError(dest.Type(), src) -} - -func nilable(kind reflect.Kind) bool { - switch kind { - case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: - return true - } - return false -} - -func scanError(dest reflect.Type, src interface{}) error { - return fmt.Errorf("bun: can't scan %#v (%T) into %s", src, src, dest.String()) -} |