summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/schema/scan.go
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2021-08-25 15:34:33 +0200
committerLibravatar GitHub <noreply@github.com>2021-08-25 15:34:33 +0200
commit2dc9fc1626507bb54417fc4a1920b847cafb27a2 (patch)
tree4ddeac479b923db38090aac8bd9209f3646851c1 /vendor/github.com/uptrace/bun/schema/scan.go
parentManually approves followers (#146) (diff)
downloadgotosocial-2dc9fc1626507bb54417fc4a1920b847cafb27a2.tar.xz
Pg to bun (#148)
* start moving to bun * changing more stuff * more * and yet more * tests passing * seems stable now * more big changes * small fix * little fixes
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema/scan.go')
-rw-r--r--vendor/github.com/uptrace/bun/schema/scan.go392
1 files changed, 392 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/scan.go b/vendor/github.com/uptrace/bun/schema/scan.go
new file mode 100644
index 000000000..0e66a860f
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/scan.go
@@ -0,0 +1,392 @@
+package schema
+
+import (
+ "bytes"
+ "database/sql"
+ "fmt"
+ "net"
+ "reflect"
+ "strconv"
+ "time"
+
+ "github.com/vmihailenco/msgpack/v5"
+
+ "github.com/uptrace/bun/extra/bunjson"
+ "github.com/uptrace/bun/internal"
+)
+
+var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
+
+type ScannerFunc func(dest reflect.Value, src interface{}) error
+
+var scanners = []ScannerFunc{
+ reflect.Bool: scanBool,
+ reflect.Int: scanInt64,
+ reflect.Int8: scanInt64,
+ reflect.Int16: scanInt64,
+ reflect.Int32: scanInt64,
+ reflect.Int64: scanInt64,
+ reflect.Uint: scanUint64,
+ reflect.Uint8: scanUint64,
+ reflect.Uint16: scanUint64,
+ reflect.Uint32: scanUint64,
+ reflect.Uint64: scanUint64,
+ reflect.Uintptr: scanUint64,
+ reflect.Float32: scanFloat64,
+ reflect.Float64: scanFloat64,
+ reflect.Complex64: nil,
+ reflect.Complex128: nil,
+ reflect.Array: nil,
+ reflect.Chan: nil,
+ reflect.Func: nil,
+ reflect.Map: scanJSON,
+ reflect.Ptr: nil,
+ reflect.Slice: scanJSON,
+ reflect.String: scanString,
+ reflect.Struct: scanJSON,
+ reflect.UnsafePointer: nil,
+}
+
+func FieldScanner(dialect Dialect, field *Field) ScannerFunc {
+ if field.Tag.HasOption("msgpack") {
+ return scanMsgpack
+ }
+ if field.Tag.HasOption("json_use_number") {
+ return scanJSONUseNumber
+ }
+ return dialect.Scanner(field.StructField.Type)
+}
+
+func Scanner(typ reflect.Type) ScannerFunc {
+ kind := typ.Kind()
+
+ if kind == reflect.Ptr {
+ if fn := Scanner(typ.Elem()); fn != nil {
+ return ptrScanner(fn)
+ }
+ }
+
+ if typ.Implements(scannerType) {
+ return scanScanner
+ }
+
+ if kind != reflect.Ptr {
+ ptr := reflect.PtrTo(typ)
+ if ptr.Implements(scannerType) {
+ return addrScanner(scanScanner)
+ }
+ }
+
+ switch typ {
+ case timeType:
+ return scanTime
+ case ipType:
+ return scanIP
+ case ipNetType:
+ return scanIPNet
+ case jsonRawMessageType:
+ return scanJSONRawMessage
+ }
+
+ return scanners[kind]
+}
+
+func scanBool(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ dest.SetBool(false)
+ return nil
+ case bool:
+ dest.SetBool(src)
+ return nil
+ case int64:
+ dest.SetBool(src != 0)
+ return nil
+ case []byte:
+ if len(src) == 1 {
+ dest.SetBool(src[0] != '0')
+ return nil
+ }
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanInt64(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ dest.SetInt(0)
+ return nil
+ case int64:
+ dest.SetInt(src)
+ return nil
+ case uint64:
+ dest.SetInt(int64(src))
+ return nil
+ case []byte:
+ n, err := strconv.ParseInt(internal.String(src), 10, 64)
+ if err != nil {
+ return err
+ }
+ dest.SetInt(n)
+ return nil
+ case string:
+ n, err := strconv.ParseInt(src, 10, 64)
+ if err != nil {
+ return err
+ }
+ dest.SetInt(n)
+ return nil
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanUint64(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ dest.SetUint(0)
+ return nil
+ case uint64:
+ dest.SetUint(src)
+ return nil
+ case int64:
+ dest.SetUint(uint64(src))
+ return nil
+ case []byte:
+ n, err := strconv.ParseUint(internal.String(src), 10, 64)
+ if err != nil {
+ return err
+ }
+ dest.SetUint(n)
+ return nil
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanFloat64(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ dest.SetFloat(0)
+ return nil
+ case float64:
+ dest.SetFloat(src)
+ return nil
+ case []byte:
+ f, err := strconv.ParseFloat(internal.String(src), 64)
+ if err != nil {
+ return err
+ }
+ dest.SetFloat(f)
+ return nil
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanString(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ dest.SetString("")
+ return nil
+ case string:
+ dest.SetString(src)
+ return nil
+ case []byte:
+ dest.SetString(string(src))
+ return nil
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanTime(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ destTime := dest.Addr().Interface().(*time.Time)
+ *destTime = time.Time{}
+ return nil
+ case time.Time:
+ destTime := dest.Addr().Interface().(*time.Time)
+ *destTime = src
+ return nil
+ case string:
+ srcTime, err := internal.ParseTime(src)
+ if err != nil {
+ return err
+ }
+ destTime := dest.Addr().Interface().(*time.Time)
+ *destTime = srcTime
+ return nil
+ case []byte:
+ srcTime, err := internal.ParseTime(internal.String(src))
+ if err != nil {
+ return err
+ }
+ destTime := dest.Addr().Interface().(*time.Time)
+ *destTime = srcTime
+ return nil
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanScanner(dest reflect.Value, src interface{}) error {
+ return dest.Interface().(sql.Scanner).Scan(src)
+}
+
+func scanMsgpack(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ return scanNull(dest)
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ dec := msgpack.GetDecoder()
+ defer msgpack.PutDecoder(dec)
+
+ dec.Reset(bytes.NewReader(b))
+ return dec.DecodeValue(dest)
+}
+
+func scanJSON(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ return scanNull(dest)
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ return bunjson.Unmarshal(b, dest.Addr().Interface())
+}
+
+func scanJSONUseNumber(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ return scanNull(dest)
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ dec := bunjson.NewDecoder(bytes.NewReader(b))
+ dec.UseNumber()
+ return dec.Decode(dest.Addr().Interface())
+}
+
+func scanIP(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ return scanNull(dest)
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ ip := net.ParseIP(internal.String(b))
+ if ip == nil {
+ return fmt.Errorf("bun: invalid ip: %q", b)
+ }
+
+ ptr := dest.Addr().Interface().(*net.IP)
+ *ptr = ip
+
+ return nil
+}
+
+func scanIPNet(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ return scanNull(dest)
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ _, ipnet, err := net.ParseCIDR(internal.String(b))
+ if err != nil {
+ return err
+ }
+
+ ptr := dest.Addr().Interface().(*net.IPNet)
+ *ptr = *ipnet
+
+ return nil
+}
+
+func scanJSONRawMessage(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ dest.SetBytes(nil)
+ return nil
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ dest.SetBytes(b)
+ return nil
+}
+
+func addrScanner(fn ScannerFunc) ScannerFunc {
+ return func(dest reflect.Value, src interface{}) error {
+ if !dest.CanAddr() {
+ return fmt.Errorf("bun: Scan(nonaddressable %T)", dest.Interface())
+ }
+ return fn(dest.Addr(), src)
+ }
+}
+
+func toBytes(src interface{}) ([]byte, error) {
+ switch src := src.(type) {
+ case string:
+ return internal.Bytes(src), nil
+ case []byte:
+ return src, nil
+ default:
+ return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src)
+ }
+}
+
+func ptrScanner(fn ScannerFunc) ScannerFunc {
+ return func(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ if !dest.CanAddr() {
+ if dest.IsNil() {
+ return nil
+ }
+ return fn(dest.Elem(), src)
+ }
+
+ if !dest.IsNil() {
+ dest.Set(reflect.New(dest.Type().Elem()))
+ }
+ return nil
+ }
+
+ if dest.IsNil() {
+ dest.Set(reflect.New(dest.Type().Elem()))
+ }
+ return fn(dest.Elem(), src)
+ }
+}
+
+func scanNull(dest reflect.Value) error {
+ if nilable(dest.Kind()) && dest.IsNil() {
+ return nil
+ }
+ dest.Set(reflect.New(dest.Type()).Elem())
+ return nil
+}
+
+func nilable(kind reflect.Kind) bool {
+ switch kind {
+ case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
+ return true
+ }
+ return false
+}