diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/dialect/pgdialect/append.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/dialect/pgdialect/append.go | 303 |
1 files changed, 303 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go new file mode 100644 index 000000000..475621197 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go @@ -0,0 +1,303 @@ +package pgdialect + +import ( + "database/sql/driver" + "fmt" + "reflect" + "strconv" + "time" + "unicode/utf8" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/schema" +) + +var ( + driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + + stringType = reflect.TypeOf((*string)(nil)).Elem() + sliceStringType = reflect.TypeOf([]string(nil)) + + intType = reflect.TypeOf((*int)(nil)).Elem() + sliceIntType = reflect.TypeOf([]int(nil)) + + int64Type = reflect.TypeOf((*int64)(nil)).Elem() + sliceInt64Type = reflect.TypeOf([]int64(nil)) + + float64Type = reflect.TypeOf((*float64)(nil)).Elem() + sliceFloat64Type = reflect.TypeOf([]float64(nil)) +) + +func customAppender(typ reflect.Type) schema.AppenderFunc { + switch typ.Kind() { + case reflect.Uint32: + return appendUint32ValueAsInt + case reflect.Uint, reflect.Uint64: + return appendUint64ValueAsInt + } + return nil +} + +func appendUint32ValueAsInt(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return strconv.AppendInt(b, int64(int32(v.Uint())), 10) +} + +func appendUint64ValueAsInt(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return strconv.AppendInt(b, int64(v.Uint()), 10) +} + +//------------------------------------------------------------------------------ + +func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte { + switch v := v.(type) { + case int64: + return strconv.AppendInt(b, v, 10) + case float64: + return dialect.AppendFloat64(b, v) + case bool: + return dialect.AppendBool(b, v) + case []byte: + return dialect.AppendBytes(b, v) + case string: + return arrayAppendString(b, v) + case time.Time: + return dialect.AppendTime(b, v) + default: + err := fmt.Errorf("pgdialect: can't append %T", v) + return dialect.AppendError(b, err) + } +} + +func arrayElemAppender(typ reflect.Type) schema.AppenderFunc { + if typ.Kind() == reflect.String { + return arrayAppendStringValue + } + if typ.Implements(driverValuerType) { + return arrayAppendDriverValue + } + return schema.Appender(typ, customAppender) +} + +func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return arrayAppendString(b, v.String()) +} + +func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + iface, err := v.Interface().(driver.Valuer).Value() + if err != nil { + return dialect.AppendError(b, err) + } + return arrayAppend(fmter, b, iface) +} + +//------------------------------------------------------------------------------ + +func arrayAppender(typ reflect.Type) schema.AppenderFunc { + kind := typ.Kind() + if kind == reflect.Ptr { + typ = typ.Elem() + kind = typ.Kind() + } + + switch kind { + case reflect.Slice, reflect.Array: + // ok: + default: + return nil + } + + elemType := typ.Elem() + + if kind == reflect.Slice { + switch elemType { + case stringType: + return appendStringSliceValue + case intType: + return appendIntSliceValue + case int64Type: + return appendInt64SliceValue + case float64Type: + return appendFloat64SliceValue + } + } + + appendElem := arrayElemAppender(elemType) + if appendElem == nil { + panic(fmt.Errorf("pgdialect: %s is not supported", typ)) + } + + return func(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + kind := v.Kind() + switch kind { + case reflect.Ptr, reflect.Slice: + if v.IsNil() { + return dialect.AppendNull(b) + } + } + + if kind == reflect.Ptr { + v = v.Elem() + } + + b = append(b, '\'') + + b = append(b, '{') + for i := 0; i < v.Len(); i++ { + elem := v.Index(i) + b = appendElem(fmter, b, elem) + b = append(b, ',') + } + if v.Len() > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b + } +} + +func appendStringSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ss := v.Convert(sliceStringType).Interface().([]string) + return appendStringSlice(b, ss) +} + +func appendStringSlice(b []byte, ss []string) []byte { + if ss == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, s := range ss { + b = arrayAppendString(b, s) + b = append(b, ',') + } + if len(ss) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendIntSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ints := v.Convert(sliceIntType).Interface().([]int) + return appendIntSlice(b, ints) +} + +func appendIntSlice(b []byte, ints []int) []byte { + if ints == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range ints { + b = strconv.AppendInt(b, int64(n), 10) + b = append(b, ',') + } + if len(ints) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendInt64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ints := v.Convert(sliceInt64Type).Interface().([]int64) + return appendInt64Slice(b, ints) +} + +func appendInt64Slice(b []byte, ints []int64) []byte { + if ints == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range ints { + b = strconv.AppendInt(b, n, 10) + b = append(b, ',') + } + if len(ints) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendFloat64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + floats := v.Convert(sliceFloat64Type).Interface().([]float64) + return appendFloat64Slice(b, floats) +} + +func appendFloat64Slice(b []byte, floats []float64) []byte { + if floats == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range floats { + b = dialect.AppendFloat64(b, n) + b = append(b, ',') + } + if len(floats) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +//------------------------------------------------------------------------------ + +func arrayAppendString(b []byte, s string) []byte { + b = append(b, '"') + for _, r := range s { + switch r { + case 0: + // ignore + case '\'': + b = append(b, "'''"...) + case '"': + b = append(b, '\\', '"') + case '\\': + b = append(b, '\\', '\\') + default: + if r < utf8.RuneSelf { + b = append(b, byte(r)) + break + } + l := len(b) + if cap(b)-l < utf8.UTFMax { + b = append(b, make([]byte, utf8.UTFMax)...) + } + n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) + b = b[:l+n] + } + } + b = append(b, '"') + return b +} |