summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/schema/scan.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema/scan.go')
-rw-r--r--vendor/github.com/uptrace/bun/schema/scan.go121
1 files changed, 93 insertions, 28 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/scan.go b/vendor/github.com/uptrace/bun/schema/scan.go
index 0e66a860f..85ba62a01 100644
--- a/vendor/github.com/uptrace/bun/schema/scan.go
+++ b/vendor/github.com/uptrace/bun/schema/scan.go
@@ -7,10 +7,12 @@ import (
"net"
"reflect"
"strconv"
+ "strings"
"time"
"github.com/vmihailenco/msgpack/v5"
+ "github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/extra/bunjson"
"github.com/uptrace/bun/internal"
)
@@ -19,32 +21,35 @@ var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
type ScannerFunc func(dest reflect.Value, src interface{}) error
-var scanners = []ScannerFunc{
- reflect.Bool: scanBool,
- reflect.Int: scanInt64,
- reflect.Int8: scanInt64,
- reflect.Int16: scanInt64,
- reflect.Int32: scanInt64,
- reflect.Int64: scanInt64,
- reflect.Uint: scanUint64,
- reflect.Uint8: scanUint64,
- reflect.Uint16: scanUint64,
- reflect.Uint32: scanUint64,
- reflect.Uint64: scanUint64,
- reflect.Uintptr: scanUint64,
- reflect.Float32: scanFloat64,
- reflect.Float64: scanFloat64,
- reflect.Complex64: nil,
- reflect.Complex128: nil,
- reflect.Array: nil,
- reflect.Chan: nil,
- reflect.Func: nil,
- reflect.Map: scanJSON,
- reflect.Ptr: nil,
- reflect.Slice: scanJSON,
- reflect.String: scanString,
- reflect.Struct: scanJSON,
- reflect.UnsafePointer: nil,
+var scanners []ScannerFunc
+
+func init() {
+ scanners = []ScannerFunc{
+ reflect.Bool: scanBool,
+ reflect.Int: scanInt64,
+ reflect.Int8: scanInt64,
+ reflect.Int16: scanInt64,
+ reflect.Int32: scanInt64,
+ reflect.Int64: scanInt64,
+ reflect.Uint: scanUint64,
+ reflect.Uint8: scanUint64,
+ reflect.Uint16: scanUint64,
+ reflect.Uint32: scanUint64,
+ reflect.Uint64: scanUint64,
+ reflect.Uintptr: scanUint64,
+ reflect.Float32: scanFloat64,
+ reflect.Float64: scanFloat64,
+ reflect.Complex64: nil,
+ reflect.Complex128: nil,
+ reflect.Array: nil,
+ reflect.Interface: scanInterface,
+ reflect.Map: scanJSON,
+ reflect.Ptr: nil,
+ reflect.Slice: scanJSON,
+ reflect.String: scanString,
+ reflect.Struct: scanJSON,
+ reflect.UnsafePointer: nil,
+ }
}
func FieldScanner(dialect Dialect, field *Field) ScannerFunc {
@@ -54,6 +59,12 @@ func FieldScanner(dialect Dialect, field *Field) ScannerFunc {
if field.Tag.HasOption("json_use_number") {
return scanJSONUseNumber
}
+ if field.StructField.Type.Kind() == reflect.Interface {
+ switch strings.ToUpper(field.UserSQLType) {
+ case sqltype.JSON, sqltype.JSONB:
+ return scanJSONIntoInterface
+ }
+ }
return dialect.Scanner(field.StructField.Type)
}
@@ -62,7 +73,7 @@ func Scanner(typ reflect.Type) ScannerFunc {
if kind == reflect.Ptr {
if fn := Scanner(typ.Elem()); fn != nil {
- return ptrScanner(fn)
+ return PtrScanner(fn)
}
}
@@ -84,6 +95,8 @@ func Scanner(typ reflect.Type) ScannerFunc {
return scanIP
case ipNetType:
return scanIPNet
+ case bytesType:
+ return scanBytes
case jsonRawMessageType:
return scanJSONRawMessage
}
@@ -196,6 +209,21 @@ func scanString(dest reflect.Value, src interface{}) error {
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
+func scanBytes(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ dest.SetBytes(nil)
+ return nil
+ case string:
+ dest.SetBytes([]byte(src))
+ return nil
+ case []byte:
+ dest.SetBytes(src)
+ return nil
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
func scanTime(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
@@ -352,7 +380,7 @@ func toBytes(src interface{}) ([]byte, error) {
}
}
-func ptrScanner(fn ScannerFunc) ScannerFunc {
+func PtrScanner(fn ScannerFunc) ScannerFunc {
return func(dest reflect.Value, src interface{}) error {
if src == nil {
if !dest.CanAddr() {
@@ -383,6 +411,43 @@ func scanNull(dest reflect.Value) error {
return nil
}
+func scanJSONIntoInterface(dest reflect.Value, src interface{}) error {
+ if dest.IsNil() {
+ if src == nil {
+ return nil
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ return bunjson.Unmarshal(b, dest.Addr().Interface())
+ }
+
+ dest = dest.Elem()
+ if fn := Scanner(dest.Type()); fn != nil {
+ return fn(dest, src)
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanInterface(dest reflect.Value, src interface{}) error {
+ if dest.IsNil() {
+ if src == nil {
+ return nil
+ }
+ dest.Set(reflect.ValueOf(src))
+ return nil
+ }
+
+ dest = dest.Elem()
+ if fn := Scanner(dest.Type()); fn != nil {
+ return fn(dest, src)
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
func nilable(kind reflect.Kind) bool {
switch kind {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: