diff options
| author | 2021-08-25 15:34:33 +0200 | |
|---|---|---|
| committer | 2021-08-25 15:34:33 +0200 | |
| commit | 2dc9fc1626507bb54417fc4a1920b847cafb27a2 (patch) | |
| tree | 4ddeac479b923db38090aac8bd9209f3646851c1 /vendor/github.com/uptrace/bun/schema | |
| parent | Manually approves followers (#146) (diff) | |
| download | gotosocial-2dc9fc1626507bb54417fc4a1920b847cafb27a2.tar.xz | |
Pg to bun (#148)
* start moving to bun
* changing more stuff
* more
* and yet more
* tests passing
* seems stable now
* more big changes
* small fix
* little fixes
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema')
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/append.go | 93 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/append_value.go | 237 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/dialect.go | 99 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/field.go | 117 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/formatter.go | 248 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/hook.go | 20 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/relation.go | 32 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/scan.go | 392 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/sqlfmt.go | 76 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/sqltype.go | 129 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/table.go | 948 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/tables.go | 148 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/util.go | 53 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/zerochecker.go | 126 |
14 files changed, 2718 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/append.go b/vendor/github.com/uptrace/bun/schema/append.go new file mode 100644 index 000000000..68f7071c8 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/append.go @@ -0,0 +1,93 @@ +package schema + +import ( + "reflect" + "strconv" + "strings" + "time" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/internal" +) + +func FieldAppender(dialect Dialect, field *Field) AppenderFunc { + if field.Tag.HasOption("msgpack") { + return appendMsgpack + } + + switch strings.ToUpper(field.UserSQLType) { + case sqltype.JSON, sqltype.JSONB: + return AppendJSONValue + } + + return dialect.Appender(field.StructField.Type) +} + +func Append(fmter Formatter, b []byte, v interface{}, custom CustomAppender) []byte { + switch v := v.(type) { + case nil: + return dialect.AppendNull(b) + case bool: + return dialect.AppendBool(b, v) + case int: + return strconv.AppendInt(b, int64(v), 10) + case int32: + return strconv.AppendInt(b, int64(v), 10) + case int64: + return strconv.AppendInt(b, v, 10) + case uint: + return strconv.AppendUint(b, uint64(v), 10) + case uint32: + return strconv.AppendUint(b, uint64(v), 10) + case uint64: + return strconv.AppendUint(b, v, 10) + case float32: + return dialect.AppendFloat32(b, v) + case float64: + return dialect.AppendFloat64(b, v) + case string: + return dialect.AppendString(b, v) + case time.Time: + return dialect.AppendTime(b, v) + case []byte: + return dialect.AppendBytes(b, v) + case QueryAppender: + return AppendQueryAppender(fmter, b, v) + default: + vv := reflect.ValueOf(v) + if vv.Kind() == reflect.Ptr && vv.IsNil() { + return dialect.AppendNull(b) + } + appender := Appender(vv.Type(), custom) + return appender(fmter, b, vv) + } +} + +func appendMsgpack(fmter Formatter, b []byte, v reflect.Value) []byte { + hexEnc := internal.NewHexEncoder(b) + + enc := msgpack.GetEncoder() + defer msgpack.PutEncoder(enc) + + enc.Reset(hexEnc) + if err := enc.EncodeValue(v); err != nil { + return dialect.AppendError(b, err) + } + + if err := hexEnc.Close(); err != nil { + return dialect.AppendError(b, err) + } + + return hexEnc.Bytes() +} + +func AppendQueryAppender(fmter Formatter, b []byte, app QueryAppender) []byte { + bb, err := app.AppendQuery(fmter, b) + if err != nil { + return dialect.AppendError(b, err) + } + return bb +} diff --git a/vendor/github.com/uptrace/bun/schema/append_value.go b/vendor/github.com/uptrace/bun/schema/append_value.go new file mode 100644 index 000000000..0c4677069 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/append_value.go @@ -0,0 +1,237 @@ +package schema + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "net" + "reflect" + "strconv" + "time" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/extra/bunjson" + "github.com/uptrace/bun/internal" +) + +var ( + timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + ipType = reflect.TypeOf((*net.IP)(nil)).Elem() + ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() + jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() + + driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem() +) + +type ( + AppenderFunc func(fmter Formatter, b []byte, v reflect.Value) []byte + CustomAppender func(typ reflect.Type) AppenderFunc +) + +var appenders = []AppenderFunc{ + reflect.Bool: AppendBoolValue, + reflect.Int: AppendIntValue, + reflect.Int8: AppendIntValue, + reflect.Int16: AppendIntValue, + reflect.Int32: AppendIntValue, + reflect.Int64: AppendIntValue, + reflect.Uint: AppendUintValue, + reflect.Uint8: AppendUintValue, + reflect.Uint16: AppendUintValue, + reflect.Uint32: AppendUintValue, + reflect.Uint64: AppendUintValue, + reflect.Uintptr: nil, + reflect.Float32: AppendFloat32Value, + reflect.Float64: AppendFloat64Value, + reflect.Complex64: nil, + reflect.Complex128: nil, + reflect.Array: AppendJSONValue, + reflect.Chan: nil, + reflect.Func: nil, + reflect.Interface: nil, + reflect.Map: AppendJSONValue, + reflect.Ptr: nil, + reflect.Slice: AppendJSONValue, + reflect.String: AppendStringValue, + reflect.Struct: AppendJSONValue, + reflect.UnsafePointer: nil, +} + +func Appender(typ reflect.Type, custom CustomAppender) AppenderFunc { + switch typ { + case timeType: + return appendTimeValue + case ipType: + return appendIPValue + case ipNetType: + return appendIPNetValue + case jsonRawMessageType: + return appendJSONRawMessageValue + } + + if typ.Implements(queryAppenderType) { + return appendQueryAppenderValue + } + if typ.Implements(driverValuerType) { + return driverValueAppender(custom) + } + + kind := typ.Kind() + + if kind != reflect.Ptr { + ptr := reflect.PtrTo(typ) + if ptr.Implements(queryAppenderType) { + return addrAppender(appendQueryAppenderValue, custom) + } + if ptr.Implements(driverValuerType) { + return addrAppender(driverValueAppender(custom), custom) + } + } + + switch kind { + case reflect.Interface: + return ifaceAppenderFunc(typ, custom) + case reflect.Ptr: + return ptrAppenderFunc(typ, custom) + case reflect.Slice: + if typ.Elem().Kind() == reflect.Uint8 { + return appendBytesValue + } + case reflect.Array: + if typ.Elem().Kind() == reflect.Uint8 { + return appendArrayBytesValue + } + } + + if custom != nil { + if fn := custom(typ); fn != nil { + return fn + } + } + return appenders[typ.Kind()] +} + +func ifaceAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc { + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + if v.IsNil() { + return dialect.AppendNull(b) + } + elem := v.Elem() + appender := Appender(elem.Type(), custom) + return appender(fmter, b, elem) + } +} + +func ptrAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc { + appender := Appender(typ.Elem(), custom) + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + if v.IsNil() { + return dialect.AppendNull(b) + } + return appender(fmter, b, v.Elem()) + } +} + +func AppendBoolValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendBool(b, v.Bool()) +} + +func AppendIntValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return strconv.AppendInt(b, v.Int(), 10) +} + +func AppendUintValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return strconv.AppendUint(b, v.Uint(), 10) +} + +func AppendFloat32Value(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendFloat32(b, float32(v.Float())) +} + +func AppendFloat64Value(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendFloat64(b, float64(v.Float())) +} + +func appendBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendBytes(b, v.Bytes()) +} + +func appendArrayBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte { + if v.CanAddr() { + return dialect.AppendBytes(b, v.Slice(0, v.Len()).Bytes()) + } + + tmp := make([]byte, v.Len()) + reflect.Copy(reflect.ValueOf(tmp), v) + b = dialect.AppendBytes(b, tmp) + return b +} + +func AppendStringValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendString(b, v.String()) +} + +func AppendJSONValue(fmter Formatter, b []byte, v reflect.Value) []byte { + bb, err := bunjson.Marshal(v.Interface()) + if err != nil { + return dialect.AppendError(b, err) + } + + if len(bb) > 0 && bb[len(bb)-1] == '\n' { + bb = bb[:len(bb)-1] + } + + return dialect.AppendJSON(b, bb) +} + +func appendTimeValue(fmter Formatter, b []byte, v reflect.Value) []byte { + tm := v.Interface().(time.Time) + return dialect.AppendTime(b, tm) +} + +func appendIPValue(fmter Formatter, b []byte, v reflect.Value) []byte { + ip := v.Interface().(net.IP) + return dialect.AppendString(b, ip.String()) +} + +func appendIPNetValue(fmter Formatter, b []byte, v reflect.Value) []byte { + ipnet := v.Interface().(net.IPNet) + return dialect.AppendString(b, ipnet.String()) +} + +func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byte { + bytes := v.Bytes() + if bytes == nil { + return dialect.AppendNull(b) + } + return dialect.AppendString(b, internal.String(bytes)) +} + +func appendQueryAppenderValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return AppendQueryAppender(fmter, b, v.Interface().(QueryAppender)) +} + +func driverValueAppender(custom CustomAppender) AppenderFunc { + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + return appendDriverValue(fmter, b, v.Interface().(driver.Valuer), custom) + } +} + +func appendDriverValue(fmter Formatter, b []byte, v driver.Valuer, custom CustomAppender) []byte { + value, err := v.Value() + if err != nil { + return dialect.AppendError(b, err) + } + return Append(fmter, b, value, custom) +} + +func addrAppender(fn AppenderFunc, custom CustomAppender) AppenderFunc { + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + if !v.CanAddr() { + err := fmt.Errorf("bun: Append(nonaddressable %T)", v.Interface()) + return dialect.AppendError(b, err) + } + return fn(fmter, b, v.Addr()) + } +} diff --git a/vendor/github.com/uptrace/bun/schema/dialect.go b/vendor/github.com/uptrace/bun/schema/dialect.go new file mode 100644 index 000000000..c50de715a --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/dialect.go @@ -0,0 +1,99 @@ +package schema + +import ( + "database/sql" + "reflect" + "sync" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/feature" +) + +type Dialect interface { + Init(db *sql.DB) + + Name() dialect.Name + Features() feature.Feature + + Tables() *Tables + OnTable(table *Table) + + IdentQuote() byte + Append(fmter Formatter, b []byte, v interface{}) []byte + Appender(typ reflect.Type) AppenderFunc + FieldAppender(field *Field) AppenderFunc + Scanner(typ reflect.Type) ScannerFunc +} + +//------------------------------------------------------------------------------ + +type nopDialect struct { + tables *Tables + features feature.Feature + + appenderMap sync.Map + scannerMap sync.Map +} + +func newNopDialect() *nopDialect { + d := new(nopDialect) + d.tables = NewTables(d) + d.features = feature.Returning + return d +} + +func (d *nopDialect) Init(*sql.DB) {} + +func (d *nopDialect) Name() dialect.Name { + return dialect.Invalid +} + +func (d *nopDialect) Features() feature.Feature { + return d.features +} + +func (d *nopDialect) Tables() *Tables { + return d.tables +} + +func (d *nopDialect) OnField(field *Field) {} + +func (d *nopDialect) OnTable(table *Table) {} + +func (d *nopDialect) IdentQuote() byte { + return '"' +} + +func (d *nopDialect) Append(fmter Formatter, b []byte, v interface{}) []byte { + return Append(fmter, b, v, nil) +} + +func (d *nopDialect) Appender(typ reflect.Type) AppenderFunc { + if v, ok := d.appenderMap.Load(typ); ok { + return v.(AppenderFunc) + } + + fn := Appender(typ, nil) + + if v, ok := d.appenderMap.LoadOrStore(typ, fn); ok { + return v.(AppenderFunc) + } + return fn +} + +func (d *nopDialect) FieldAppender(field *Field) AppenderFunc { + return FieldAppender(d, field) +} + +func (d *nopDialect) Scanner(typ reflect.Type) ScannerFunc { + if v, ok := d.scannerMap.Load(typ); ok { + return v.(ScannerFunc) + } + + fn := Scanner(typ) + + if v, ok := d.scannerMap.LoadOrStore(typ, fn); ok { + return v.(ScannerFunc) + } + return fn +} diff --git a/vendor/github.com/uptrace/bun/schema/field.go b/vendor/github.com/uptrace/bun/schema/field.go new file mode 100644 index 000000000..1e069b82f --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/field.go @@ -0,0 +1,117 @@ +package schema + +import ( + "fmt" + "reflect" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/internal/tagparser" +) + +type Field struct { + StructField reflect.StructField + + Tag tagparser.Tag + IndirectType reflect.Type + Index []int + + Name string // SQL name, .e.g. id + SQLName Safe // escaped SQL name, e.g. "id" + GoName string // struct field name, e.g. Id + + DiscoveredSQLType string + UserSQLType string + CreateTableSQLType string + SQLDefault string + + OnDelete string + OnUpdate string + + IsPK bool + NotNull bool + NullZero bool + AutoIncrement bool + + Append AppenderFunc + Scan ScannerFunc + IsZero IsZeroerFunc +} + +func (f *Field) String() string { + return f.Name +} + +func (f *Field) Clone() *Field { + cp := *f + cp.Index = cp.Index[:len(f.Index):len(f.Index)] + return &cp +} + +func (f *Field) Value(strct reflect.Value) reflect.Value { + return fieldByIndexAlloc(strct, f.Index) +} + +func (f *Field) HasZeroValue(v reflect.Value) bool { + for _, idx := range f.Index { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return true + } + v = v.Elem() + } + v = v.Field(idx) + } + return f.IsZero(v) +} + +func (f *Field) AppendValue(fmter Formatter, b []byte, strct reflect.Value) []byte { + fv, ok := fieldByIndex(strct, f.Index) + if !ok { + return dialect.AppendNull(b) + } + + if f.NullZero && f.IsZero(fv) { + return dialect.AppendNull(b) + } + if f.Append == nil { + panic(fmt.Errorf("bun: AppendValue(unsupported %s)", fv.Type())) + } + return f.Append(fmter, b, fv) +} + +func (f *Field) ScanWithCheck(fv reflect.Value, src interface{}) error { + if f.Scan == nil { + return fmt.Errorf("bun: Scan(unsupported %s)", f.IndirectType) + } + return f.Scan(fv, src) +} + +func (f *Field) ScanValue(strct reflect.Value, src interface{}) error { + if src == nil { + if fv, ok := fieldByIndex(strct, f.Index); ok { + return f.ScanWithCheck(fv, src) + } + return nil + } + + fv := fieldByIndexAlloc(strct, f.Index) + return f.ScanWithCheck(fv, src) +} + +func (f *Field) markAsPK() { + f.IsPK = true + f.NotNull = true + f.NullZero = true +} + +func indexEqual(ind1, ind2 []int) bool { + if len(ind1) != len(ind2) { + return false + } + for i, ind := range ind1 { + if ind != ind2[i] { + return false + } + } + return true +} diff --git a/vendor/github.com/uptrace/bun/schema/formatter.go b/vendor/github.com/uptrace/bun/schema/formatter.go new file mode 100644 index 000000000..7b26fbaca --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/formatter.go @@ -0,0 +1,248 @@ +package schema + +import ( + "reflect" + "strconv" + "strings" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/internal/parser" +) + +var nopFormatter = Formatter{ + dialect: newNopDialect(), +} + +type Formatter struct { + dialect Dialect + args *namedArgList +} + +func NewFormatter(dialect Dialect) Formatter { + return Formatter{ + dialect: dialect, + } +} + +func NewNopFormatter() Formatter { + return nopFormatter +} + +func (f Formatter) IsNop() bool { + return f.dialect.Name() == dialect.Invalid +} + +func (f Formatter) Dialect() Dialect { + return f.dialect +} + +func (f Formatter) IdentQuote() byte { + return f.dialect.IdentQuote() +} + +func (f Formatter) AppendIdent(b []byte, ident string) []byte { + return dialect.AppendIdent(b, ident, f.IdentQuote()) +} + +func (f Formatter) AppendValue(b []byte, v reflect.Value) []byte { + if v.Kind() == reflect.Ptr && v.IsNil() { + return dialect.AppendNull(b) + } + appender := f.dialect.Appender(v.Type()) + return appender(f, b, v) +} + +func (f Formatter) HasFeature(feature feature.Feature) bool { + return f.dialect.Features().Has(feature) +} + +func (f Formatter) WithArg(arg NamedArgAppender) Formatter { + return Formatter{ + dialect: f.dialect, + args: f.args.WithArg(arg), + } +} + +func (f Formatter) WithNamedArg(name string, value interface{}) Formatter { + return Formatter{ + dialect: f.dialect, + args: f.args.WithArg(&namedArg{name: name, value: value}), + } +} + +func (f Formatter) FormatQuery(query string, args ...interface{}) string { + if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 { + return query + } + return internal.String(f.AppendQuery(nil, query, args...)) +} + +func (f Formatter) AppendQuery(dst []byte, query string, args ...interface{}) []byte { + if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 { + return append(dst, query...) + } + return f.append(dst, parser.NewString(query), args) +} + +func (f Formatter) append(dst []byte, p *parser.Parser, args []interface{}) []byte { + var namedArgs NamedArgAppender + if len(args) == 1 { + var ok bool + namedArgs, ok = args[0].(NamedArgAppender) + if !ok { + namedArgs, _ = newStructArgs(f, args[0]) + } + } + + var argIndex int + for p.Valid() { + b, ok := p.ReadSep('?') + if !ok { + dst = append(dst, b...) + continue + } + if len(b) > 0 && b[len(b)-1] == '\\' { + dst = append(dst, b[:len(b)-1]...) + dst = append(dst, '?') + continue + } + dst = append(dst, b...) + + name, numeric := p.ReadIdentifier() + if name != "" { + if numeric { + idx, err := strconv.Atoi(name) + if err != nil { + goto restore_arg + } + + if idx >= len(args) { + goto restore_arg + } + + dst = f.appendArg(dst, args[idx]) + continue + } + + if namedArgs != nil { + dst, ok = namedArgs.AppendNamedArg(f, dst, name) + if ok { + continue + } + } + + dst, ok = f.args.AppendNamedArg(f, dst, name) + if ok { + continue + } + + restore_arg: + dst = append(dst, '?') + dst = append(dst, name...) + continue + } + + if argIndex >= len(args) { + dst = append(dst, '?') + continue + } + + arg := args[argIndex] + argIndex++ + + dst = f.appendArg(dst, arg) + } + + return dst +} + +func (f Formatter) appendArg(b []byte, arg interface{}) []byte { + switch arg := arg.(type) { + case QueryAppender: + bb, err := arg.AppendQuery(f, b) + if err != nil { + return dialect.AppendError(b, err) + } + return bb + default: + return f.dialect.Append(f, b, arg) + } +} + +//------------------------------------------------------------------------------ + +type NamedArgAppender interface { + AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) +} + +//------------------------------------------------------------------------------ + +type namedArgList struct { + arg NamedArgAppender + next *namedArgList +} + +func (l *namedArgList) WithArg(arg NamedArgAppender) *namedArgList { + return &namedArgList{ + arg: arg, + next: l, + } +} + +func (l *namedArgList) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) { + for l != nil && l.arg != nil { + if b, ok := l.arg.AppendNamedArg(fmter, b, name); ok { + return b, true + } + l = l.next + } + return b, false +} + +//------------------------------------------------------------------------------ + +type namedArg struct { + name string + value interface{} +} + +var _ NamedArgAppender = (*namedArg)(nil) + +func (a *namedArg) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) { + if a.name == name { + return fmter.appendArg(b, a.value), true + } + return b, false +} + +//------------------------------------------------------------------------------ + +var _ NamedArgAppender = (*structArgs)(nil) + +type structArgs struct { + table *Table + strct reflect.Value +} + +func newStructArgs(fmter Formatter, strct interface{}) (*structArgs, bool) { + v := reflect.ValueOf(strct) + if !v.IsValid() { + return nil, false + } + + v = reflect.Indirect(v) + if v.Kind() != reflect.Struct { + return nil, false + } + + return &structArgs{ + table: fmter.Dialect().Tables().Get(v.Type()), + strct: v, + }, true +} + +func (m *structArgs) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) { + return m.table.AppendNamedArg(fmter, b, name, m.strct) +} diff --git a/vendor/github.com/uptrace/bun/schema/hook.go b/vendor/github.com/uptrace/bun/schema/hook.go new file mode 100644 index 000000000..5391981d5 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/hook.go @@ -0,0 +1,20 @@ +package schema + +import ( + "context" + "reflect" +) + +type BeforeScanHook interface { + BeforeScan(context.Context) error +} + +var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem() + +//------------------------------------------------------------------------------ + +type AfterScanHook interface { + AfterScan(context.Context) error +} + +var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem() diff --git a/vendor/github.com/uptrace/bun/schema/relation.go b/vendor/github.com/uptrace/bun/schema/relation.go new file mode 100644 index 000000000..8d1baeb3f --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/relation.go @@ -0,0 +1,32 @@ +package schema + +import ( + "fmt" +) + +const ( + InvalidRelation = iota + HasOneRelation + BelongsToRelation + HasManyRelation + ManyToManyRelation +) + +type Relation struct { + Type int + Field *Field + JoinTable *Table + BaseFields []*Field + JoinFields []*Field + + PolymorphicField *Field + PolymorphicValue string + + M2MTable *Table + M2MBaseFields []*Field + M2MJoinFields []*Field +} + +func (r *Relation) String() string { + return fmt.Sprintf("relation=%s", r.Field.GoName) +} diff --git a/vendor/github.com/uptrace/bun/schema/scan.go b/vendor/github.com/uptrace/bun/schema/scan.go new file mode 100644 index 000000000..0e66a860f --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/scan.go @@ -0,0 +1,392 @@ +package schema + +import ( + "bytes" + "database/sql" + "fmt" + "net" + "reflect" + "strconv" + "time" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/uptrace/bun/extra/bunjson" + "github.com/uptrace/bun/internal" +) + +var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +type ScannerFunc func(dest reflect.Value, src interface{}) error + +var scanners = []ScannerFunc{ + reflect.Bool: scanBool, + reflect.Int: scanInt64, + reflect.Int8: scanInt64, + reflect.Int16: scanInt64, + reflect.Int32: scanInt64, + reflect.Int64: scanInt64, + reflect.Uint: scanUint64, + reflect.Uint8: scanUint64, + reflect.Uint16: scanUint64, + reflect.Uint32: scanUint64, + reflect.Uint64: scanUint64, + reflect.Uintptr: scanUint64, + reflect.Float32: scanFloat64, + reflect.Float64: scanFloat64, + reflect.Complex64: nil, + reflect.Complex128: nil, + reflect.Array: nil, + reflect.Chan: nil, + reflect.Func: nil, + reflect.Map: scanJSON, + reflect.Ptr: nil, + reflect.Slice: scanJSON, + reflect.String: scanString, + reflect.Struct: scanJSON, + reflect.UnsafePointer: nil, +} + +func FieldScanner(dialect Dialect, field *Field) ScannerFunc { + if field.Tag.HasOption("msgpack") { + return scanMsgpack + } + if field.Tag.HasOption("json_use_number") { + return scanJSONUseNumber + } + return dialect.Scanner(field.StructField.Type) +} + +func Scanner(typ reflect.Type) ScannerFunc { + kind := typ.Kind() + + if kind == reflect.Ptr { + if fn := Scanner(typ.Elem()); fn != nil { + return ptrScanner(fn) + } + } + + if typ.Implements(scannerType) { + return scanScanner + } + + if kind != reflect.Ptr { + ptr := reflect.PtrTo(typ) + if ptr.Implements(scannerType) { + return addrScanner(scanScanner) + } + } + + switch typ { + case timeType: + return scanTime + case ipType: + return scanIP + case ipNetType: + return scanIPNet + case jsonRawMessageType: + return scanJSONRawMessage + } + + return scanners[kind] +} + +func scanBool(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetBool(false) + return nil + case bool: + dest.SetBool(src) + return nil + case int64: + dest.SetBool(src != 0) + return nil + case []byte: + if len(src) == 1 { + dest.SetBool(src[0] != '0') + return nil + } + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanInt64(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetInt(0) + return nil + case int64: + dest.SetInt(src) + return nil + case uint64: + dest.SetInt(int64(src)) + return nil + case []byte: + n, err := strconv.ParseInt(internal.String(src), 10, 64) + if err != nil { + return err + } + dest.SetInt(n) + return nil + case string: + n, err := strconv.ParseInt(src, 10, 64) + if err != nil { + return err + } + dest.SetInt(n) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanUint64(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetUint(0) + return nil + case uint64: + dest.SetUint(src) + return nil + case int64: + dest.SetUint(uint64(src)) + return nil + case []byte: + n, err := strconv.ParseUint(internal.String(src), 10, 64) + if err != nil { + return err + } + dest.SetUint(n) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanFloat64(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetFloat(0) + return nil + case float64: + dest.SetFloat(src) + return nil + case []byte: + f, err := strconv.ParseFloat(internal.String(src), 64) + if err != nil { + return err + } + dest.SetFloat(f) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanString(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetString("") + return nil + case string: + dest.SetString(src) + return nil + case []byte: + dest.SetString(string(src)) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanTime(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + destTime := dest.Addr().Interface().(*time.Time) + *destTime = time.Time{} + return nil + case time.Time: + destTime := dest.Addr().Interface().(*time.Time) + *destTime = src + return nil + case string: + srcTime, err := internal.ParseTime(src) + if err != nil { + return err + } + destTime := dest.Addr().Interface().(*time.Time) + *destTime = srcTime + return nil + case []byte: + srcTime, err := internal.ParseTime(internal.String(src)) + if err != nil { + return err + } + destTime := dest.Addr().Interface().(*time.Time) + *destTime = srcTime + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanScanner(dest reflect.Value, src interface{}) error { + return dest.Interface().(sql.Scanner).Scan(src) +} + +func scanMsgpack(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + dec := msgpack.GetDecoder() + defer msgpack.PutDecoder(dec) + + dec.Reset(bytes.NewReader(b)) + return dec.DecodeValue(dest) +} + +func scanJSON(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + return bunjson.Unmarshal(b, dest.Addr().Interface()) +} + +func scanJSONUseNumber(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + dec := bunjson.NewDecoder(bytes.NewReader(b)) + dec.UseNumber() + return dec.Decode(dest.Addr().Interface()) +} + +func scanIP(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + ip := net.ParseIP(internal.String(b)) + if ip == nil { + return fmt.Errorf("bun: invalid ip: %q", b) + } + + ptr := dest.Addr().Interface().(*net.IP) + *ptr = ip + + return nil +} + +func scanIPNet(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + _, ipnet, err := net.ParseCIDR(internal.String(b)) + if err != nil { + return err + } + + ptr := dest.Addr().Interface().(*net.IPNet) + *ptr = *ipnet + + return nil +} + +func scanJSONRawMessage(dest reflect.Value, src interface{}) error { + if src == nil { + dest.SetBytes(nil) + return nil + } + + b, err := toBytes(src) + if err != nil { + return err + } + + dest.SetBytes(b) + return nil +} + +func addrScanner(fn ScannerFunc) ScannerFunc { + return func(dest reflect.Value, src interface{}) error { + if !dest.CanAddr() { + return fmt.Errorf("bun: Scan(nonaddressable %T)", dest.Interface()) + } + return fn(dest.Addr(), src) + } +} + +func toBytes(src interface{}) ([]byte, error) { + switch src := src.(type) { + case string: + return internal.Bytes(src), nil + case []byte: + return src, nil + default: + return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src) + } +} + +func ptrScanner(fn ScannerFunc) ScannerFunc { + return func(dest reflect.Value, src interface{}) error { + if src == nil { + if !dest.CanAddr() { + if dest.IsNil() { + return nil + } + return fn(dest.Elem(), src) + } + + if !dest.IsNil() { + dest.Set(reflect.New(dest.Type().Elem())) + } + return nil + } + + if dest.IsNil() { + dest.Set(reflect.New(dest.Type().Elem())) + } + return fn(dest.Elem(), src) + } +} + +func scanNull(dest reflect.Value) error { + if nilable(dest.Kind()) && dest.IsNil() { + return nil + } + dest.Set(reflect.New(dest.Type()).Elem()) + return nil +} + +func nilable(kind reflect.Kind) bool { + switch kind { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return true + } + return false +} diff --git a/vendor/github.com/uptrace/bun/schema/sqlfmt.go b/vendor/github.com/uptrace/bun/schema/sqlfmt.go new file mode 100644 index 000000000..7b538cd0c --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/sqlfmt.go @@ -0,0 +1,76 @@ +package schema + +type QueryAppender interface { + AppendQuery(fmter Formatter, b []byte) ([]byte, error) +} + +type ColumnsAppender interface { + AppendColumns(fmter Formatter, b []byte) ([]byte, error) +} + +//------------------------------------------------------------------------------ + +// Safe represents a safe SQL query. +type Safe string + +var _ QueryAppender = (*Safe)(nil) + +func (s Safe) AppendQuery(fmter Formatter, b []byte) ([]byte, error) { + return append(b, s...), nil +} + +//------------------------------------------------------------------------------ + +// Ident represents a SQL identifier, for example, table or column name. +type Ident string + +var _ QueryAppender = (*Ident)(nil) + +func (s Ident) AppendQuery(fmter Formatter, b []byte) ([]byte, error) { + return fmter.AppendIdent(b, string(s)), nil +} + +//------------------------------------------------------------------------------ + +type QueryWithArgs struct { + Query string + Args []interface{} +} + +var _ QueryAppender = QueryWithArgs{} + +func SafeQuery(query string, args []interface{}) QueryWithArgs { + if query != "" && args == nil { + args = make([]interface{}, 0) + } + return QueryWithArgs{Query: query, Args: args} +} + +func UnsafeIdent(ident string) QueryWithArgs { + return QueryWithArgs{Query: ident} +} + +func (q QueryWithArgs) IsZero() bool { + return q.Query == "" && q.Args == nil +} + +func (q QueryWithArgs) AppendQuery(fmter Formatter, b []byte) ([]byte, error) { + if q.Args == nil { + return fmter.AppendIdent(b, q.Query), nil + } + return fmter.AppendQuery(b, q.Query, q.Args...), nil +} + +//------------------------------------------------------------------------------ + +type QueryWithSep struct { + QueryWithArgs + Sep string +} + +func SafeQueryWithSep(query string, args []interface{}, sep string) QueryWithSep { + return QueryWithSep{ + QueryWithArgs: SafeQuery(query, args), + Sep: sep, + } +} diff --git a/vendor/github.com/uptrace/bun/schema/sqltype.go b/vendor/github.com/uptrace/bun/schema/sqltype.go new file mode 100644 index 000000000..560f695c2 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/sqltype.go @@ -0,0 +1,129 @@ +package schema + +import ( + "bytes" + "database/sql" + "encoding/json" + "fmt" + "reflect" + "time" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/internal" +) + +var ( + bunNullTimeType = reflect.TypeOf((*NullTime)(nil)).Elem() + nullTimeType = reflect.TypeOf((*sql.NullTime)(nil)).Elem() + nullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem() + nullFloatType = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem() + nullIntType = reflect.TypeOf((*sql.NullInt64)(nil)).Elem() + nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem() +) + +var sqlTypes = []string{ + reflect.Bool: sqltype.Boolean, + reflect.Int: sqltype.BigInt, + reflect.Int8: sqltype.SmallInt, + reflect.Int16: sqltype.SmallInt, + reflect.Int32: sqltype.Integer, + reflect.Int64: sqltype.BigInt, + reflect.Uint: sqltype.BigInt, + reflect.Uint8: sqltype.SmallInt, + reflect.Uint16: sqltype.SmallInt, + reflect.Uint32: sqltype.Integer, + reflect.Uint64: sqltype.BigInt, + reflect.Uintptr: sqltype.BigInt, + reflect.Float32: sqltype.Real, + reflect.Float64: sqltype.DoublePrecision, + reflect.Complex64: "", + reflect.Complex128: "", + reflect.Array: "", + reflect.Chan: "", + reflect.Func: "", + reflect.Interface: "", + reflect.Map: sqltype.VarChar, + reflect.Ptr: "", + reflect.Slice: sqltype.VarChar, + reflect.String: sqltype.VarChar, + reflect.Struct: sqltype.VarChar, + reflect.UnsafePointer: "", +} + +func DiscoverSQLType(typ reflect.Type) string { + switch typ { + case timeType, nullTimeType, bunNullTimeType: + return sqltype.Timestamp + case nullBoolType: + return sqltype.Boolean + case nullFloatType: + return sqltype.DoublePrecision + case nullIntType: + return sqltype.BigInt + case nullStringType: + return sqltype.VarChar + } + return sqlTypes[typ.Kind()] +} + +//------------------------------------------------------------------------------ + +var jsonNull = []byte("null") + +// NullTime is a time.Time wrapper that marshals zero time as JSON null and SQL NULL. +type NullTime struct { + time.Time +} + +var ( + _ json.Marshaler = (*NullTime)(nil) + _ json.Unmarshaler = (*NullTime)(nil) + _ sql.Scanner = (*NullTime)(nil) + _ QueryAppender = (*NullTime)(nil) +) + +func (tm NullTime) MarshalJSON() ([]byte, error) { + if tm.IsZero() { + return jsonNull, nil + } + return tm.Time.MarshalJSON() +} + +func (tm *NullTime) UnmarshalJSON(b []byte) error { + if bytes.Equal(b, jsonNull) { + tm.Time = time.Time{} + return nil + } + return tm.Time.UnmarshalJSON(b) +} + +func (tm NullTime) AppendQuery(fmter Formatter, b []byte) ([]byte, error) { + if tm.IsZero() { + return dialect.AppendNull(b), nil + } + return dialect.AppendTime(b, tm.Time), nil +} + +func (tm *NullTime) Scan(src interface{}) error { + if src == nil { + tm.Time = time.Time{} + return nil + } + + switch src := src.(type) { + case []byte: + newtm, err := internal.ParseTime(internal.String(src)) + if err != nil { + return err + } + + tm.Time = newtm + return nil + case time.Time: + tm.Time = src + return nil + default: + return fmt.Errorf("bun: can't scan %#v into NullTime", src) + } +} diff --git a/vendor/github.com/uptrace/bun/schema/table.go b/vendor/github.com/uptrace/bun/schema/table.go new file mode 100644 index 000000000..eca18b781 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/table.go @@ -0,0 +1,948 @@ +package schema + +import ( + "database/sql" + "fmt" + "reflect" + "strings" + "sync" + "time" + + "github.com/jinzhu/inflection" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/internal/tagparser" +) + +const ( + beforeScanHookFlag internal.Flag = 1 << iota + afterScanHookFlag +) + +var ( + baseModelType = reflect.TypeOf((*BaseModel)(nil)).Elem() + tableNameInflector = inflection.Plural +) + +type BaseModel struct{} + +// SetTableNameInflector overrides the default func that pluralizes +// model name to get table name, e.g. my_article becomes my_articles. +func SetTableNameInflector(fn func(string) string) { + tableNameInflector = fn +} + +// Table represents a SQL table created from Go struct. +type Table struct { + dialect Dialect + + Type reflect.Type + ZeroValue reflect.Value // reflect.Struct + ZeroIface interface{} // struct pointer + + TypeName string + ModelName string + + Name string + SQLName Safe + SQLNameForSelects Safe + Alias string + SQLAlias Safe + + Fields []*Field // PKs + DataFields + PKs []*Field + DataFields []*Field + + fieldsMapMu sync.RWMutex + FieldMap map[string]*Field + + Relations map[string]*Relation + Unique map[string][]*Field + + SoftDeleteField *Field + UpdateSoftDeleteField func(fv reflect.Value) error + + allFields []*Field // read only + skippedFields []*Field + + flags internal.Flag +} + +func newTable(dialect Dialect, typ reflect.Type) *Table { + t := new(Table) + t.dialect = dialect + t.Type = typ + t.ZeroValue = reflect.New(t.Type).Elem() + t.ZeroIface = reflect.New(t.Type).Interface() + t.TypeName = internal.ToExported(t.Type.Name()) + t.ModelName = internal.Underscore(t.Type.Name()) + tableName := tableNameInflector(t.ModelName) + t.setName(tableName) + t.Alias = t.ModelName + t.SQLAlias = t.quoteIdent(t.ModelName) + + hooks := []struct { + typ reflect.Type + flag internal.Flag + }{ + {beforeScanHookType, beforeScanHookFlag}, + {afterScanHookType, afterScanHookFlag}, + } + + typ = reflect.PtrTo(t.Type) + for _, hook := range hooks { + if typ.Implements(hook.typ) { + t.flags = t.flags.Set(hook.flag) + } + } + + return t +} + +func (t *Table) init1() { + t.initFields() +} + +func (t *Table) init2() { + t.initInlines() + t.initRelations() + t.skippedFields = nil +} + +func (t *Table) setName(name string) { + t.Name = name + t.SQLName = t.quoteIdent(name) + t.SQLNameForSelects = t.quoteIdent(name) + if t.SQLAlias == "" { + t.Alias = name + t.SQLAlias = t.quoteIdent(name) + } +} + +func (t *Table) String() string { + return "model=" + t.TypeName +} + +func (t *Table) CheckPKs() error { + if len(t.PKs) == 0 { + return fmt.Errorf("bun: %s does not have primary keys", t) + } + return nil +} + +func (t *Table) addField(field *Field) { + t.Fields = append(t.Fields, field) + if field.IsPK { + t.PKs = append(t.PKs, field) + } else { + t.DataFields = append(t.DataFields, field) + } + t.FieldMap[field.Name] = field +} + +func (t *Table) removeField(field *Field) { + t.Fields = removeField(t.Fields, field) + if field.IsPK { + t.PKs = removeField(t.PKs, field) + } else { + t.DataFields = removeField(t.DataFields, field) + } + delete(t.FieldMap, field.Name) +} + +func (t *Table) fieldWithLock(name string) *Field { + t.fieldsMapMu.RLock() + field := t.FieldMap[name] + t.fieldsMapMu.RUnlock() + return field +} + +func (t *Table) HasField(name string) bool { + _, ok := t.FieldMap[name] + return ok +} + +func (t *Table) Field(name string) (*Field, error) { + field, ok := t.FieldMap[name] + if !ok { + return nil, fmt.Errorf("bun: %s does not have column=%s", t, name) + } + return field, nil +} + +func (t *Table) fieldByGoName(name string) *Field { + for _, f := range t.allFields { + if f.GoName == name { + return f + } + } + return nil +} + +func (t *Table) initFields() { + t.Fields = make([]*Field, 0, t.Type.NumField()) + t.FieldMap = make(map[string]*Field, t.Type.NumField()) + t.addFields(t.Type, nil) + + if len(t.PKs) > 0 { + return + } + for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} { + if field, ok := t.FieldMap[name]; ok { + field.markAsPK() + t.PKs = []*Field{field} + t.DataFields = removeField(t.DataFields, field) + break + } + } + if len(t.PKs) == 1 { + switch t.PKs[0].IndirectType.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + t.PKs[0].AutoIncrement = true + } + } +} + +func (t *Table) addFields(typ reflect.Type, baseIndex []int) { + for i := 0; i < typ.NumField(); i++ { + f := typ.Field(i) + + // Make a copy so slice is not shared between fields. + index := make([]int, len(baseIndex)) + copy(index, baseIndex) + + if f.Anonymous { + if f.Tag.Get("bun") == "-" { + continue + } + if f.Name == "BaseModel" && f.Type == baseModelType { + if len(index) == 0 { + t.processBaseModelField(f) + } + continue + } + + fieldType := indirectType(f.Type) + if fieldType.Kind() != reflect.Struct { + continue + } + t.addFields(fieldType, append(index, f.Index...)) + + tag := tagparser.Parse(f.Tag.Get("bun")) + if _, inherit := tag.Options["inherit"]; inherit { + embeddedTable := t.dialect.Tables().Ref(fieldType) + t.TypeName = embeddedTable.TypeName + t.SQLName = embeddedTable.SQLName + t.SQLNameForSelects = embeddedTable.SQLNameForSelects + t.Alias = embeddedTable.Alias + t.SQLAlias = embeddedTable.SQLAlias + t.ModelName = embeddedTable.ModelName + } + + continue + } + + field := t.newField(f, index) + if field != nil { + t.addField(field) + } + } +} + +func (t *Table) processBaseModelField(f reflect.StructField) { + tag := tagparser.Parse(f.Tag.Get("bun")) + + if isKnownTableOption(tag.Name) { + internal.Warn.Printf( + "%s.%s tag name %q is also an option name; is it a mistake?", + t.TypeName, f.Name, tag.Name, + ) + } + + for name := range tag.Options { + if !isKnownTableOption(name) { + internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) + } + } + + if tag.Name != "" { + t.setName(tag.Name) + } + + if s, ok := tag.Options["select"]; ok { + t.SQLNameForSelects = t.quoteTableName(s) + } + + if s, ok := tag.Options["alias"]; ok { + t.Alias = s + t.SQLAlias = t.quoteIdent(s) + } +} + +//nolint +func (t *Table) newField(f reflect.StructField, index []int) *Field { + tag := tagparser.Parse(f.Tag.Get("bun")) + + if f.PkgPath != "" { + return nil + } + + sqlName := internal.Underscore(f.Name) + + if tag.Name != sqlName && isKnownFieldOption(tag.Name) { + internal.Warn.Printf( + "%s.%s tag name %q is also an option name; is it a mistake?", + t.TypeName, f.Name, tag.Name, + ) + } + + for name := range tag.Options { + if !isKnownFieldOption(name) { + internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) + } + } + + skip := tag.Name == "-" + if !skip && tag.Name != "" { + sqlName = tag.Name + } + + index = append(index, f.Index...) + if field := t.fieldWithLock(sqlName); field != nil { + if indexEqual(field.Index, index) { + return field + } + t.removeField(field) + } + + field := &Field{ + StructField: f, + + Tag: tag, + IndirectType: indirectType(f.Type), + Index: index, + + Name: sqlName, + GoName: f.Name, + SQLName: t.quoteIdent(sqlName), + } + + field.NotNull = tag.HasOption("notnull") + field.NullZero = tag.HasOption("nullzero") + field.AutoIncrement = tag.HasOption("autoincrement") + if tag.HasOption("pk") { + field.markAsPK() + } + if tag.HasOption("allowzero") { + if tag.HasOption("nullzero") { + internal.Warn.Printf( + "%s.%s: nullzero and allowzero options are mutually exclusive", + t.TypeName, f.Name, + ) + } + field.NullZero = false + } + + if v, ok := tag.Options["unique"]; ok { + // Split the value by comma, this will allow multiple names to be specified. + // We can use this to create multiple named unique constraints where a single column + // might be included in multiple constraints. + for _, uniqueName := range strings.Split(v, ",") { + if t.Unique == nil { + t.Unique = make(map[string][]*Field) + } + t.Unique[uniqueName] = append(t.Unique[uniqueName], field) + } + } + if s, ok := tag.Options["default"]; ok { + field.SQLDefault = s + } + if s, ok := field.Tag.Options["type"]; ok { + field.UserSQLType = s + } + field.DiscoveredSQLType = DiscoverSQLType(field.IndirectType) + field.Append = t.dialect.FieldAppender(field) + field.Scan = FieldScanner(t.dialect, field) + field.IsZero = FieldZeroChecker(field) + + if v, ok := tag.Options["alt"]; ok { + t.FieldMap[v] = field + } + + t.allFields = append(t.allFields, field) + if skip { + t.skippedFields = append(t.skippedFields, field) + t.FieldMap[field.Name] = field + return nil + } + + if _, ok := tag.Options["soft_delete"]; ok { + field.NullZero = true + t.SoftDeleteField = field + t.UpdateSoftDeleteField = softDeleteFieldUpdater(field) + } + + return field +} + +func (t *Table) initInlines() { + for _, f := range t.skippedFields { + if f.IndirectType.Kind() == reflect.Struct { + t.inlineFields(f, nil) + } + } +} + +//--------------------------------------------------------------------------------------- + +func (t *Table) initRelations() { + for i := 0; i < len(t.Fields); { + f := t.Fields[i] + if t.tryRelation(f) { + t.Fields = removeField(t.Fields, f) + t.DataFields = removeField(t.DataFields, f) + } else { + i++ + } + + if f.IndirectType.Kind() == reflect.Struct { + t.inlineFields(f, nil) + } + } +} + +func (t *Table) tryRelation(field *Field) bool { + if rel, ok := field.Tag.Options["rel"]; ok { + t.initRelation(field, rel) + return true + } + if field.Tag.HasOption("m2m") { + t.addRelation(t.m2mRelation(field)) + return true + } + + if field.Tag.HasOption("join") { + internal.Warn.Printf( + `%s.%s option "join" requires a relation type`, + t.TypeName, field.GoName, + ) + } + + return false +} + +func (t *Table) initRelation(field *Field, rel string) { + switch rel { + case "belongs-to": + t.addRelation(t.belongsToRelation(field)) + case "has-one": + t.addRelation(t.hasOneRelation(field)) + case "has-many": + t.addRelation(t.hasManyRelation(field)) + default: + panic(fmt.Errorf("bun: unknown relation=%s on field=%s", rel, field.GoName)) + } +} + +func (t *Table) addRelation(rel *Relation) { + if t.Relations == nil { + t.Relations = make(map[string]*Relation) + } + _, ok := t.Relations[rel.Field.GoName] + if ok { + panic(fmt.Errorf("%s already has %s", t, rel)) + } + t.Relations[rel.Field.GoName] = rel +} + +func (t *Table) belongsToRelation(field *Field) *Relation { + joinTable := t.dialect.Tables().Ref(field.IndirectType) + if err := joinTable.CheckPKs(); err != nil { + panic(err) + } + + rel := &Relation{ + Type: HasOneRelation, + Field: field, + JoinTable: joinTable, + } + + if join, ok := field.Tag.Options["join"]; ok { + baseColumns, joinColumns := parseRelationJoin(join) + for i, baseColumn := range baseColumns { + joinColumn := joinColumns[i] + + if f := t.fieldWithLock(baseColumn); f != nil { + rel.BaseFields = append(rel.BaseFields, f) + } else { + panic(fmt.Errorf( + "bun: %s belongs-to %s: %s must have column %s", + t.TypeName, field.GoName, t.TypeName, baseColumn, + )) + } + + if f := joinTable.fieldWithLock(joinColumn); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + } else { + panic(fmt.Errorf( + "bun: %s belongs-to %s: %s must have column %s", + t.TypeName, field.GoName, t.TypeName, baseColumn, + )) + } + } + return rel + } + + rel.JoinFields = joinTable.PKs + fkPrefix := internal.Underscore(field.GoName) + "_" + for _, joinPK := range joinTable.PKs { + fkName := fkPrefix + joinPK.Name + if fk := t.fieldWithLock(fkName); fk != nil { + rel.BaseFields = append(rel.BaseFields, fk) + continue + } + + if fk := t.fieldWithLock(joinPK.Name); fk != nil { + rel.BaseFields = append(rel.BaseFields, fk) + continue + } + + panic(fmt.Errorf( + "bun: %s belongs-to %s: %s must have column %s "+ + "(to override, use join:base_column=join_column tag on %s field)", + t.TypeName, field.GoName, t.TypeName, fkName, field.GoName, + )) + } + return rel +} + +func (t *Table) hasOneRelation(field *Field) *Relation { + if err := t.CheckPKs(); err != nil { + panic(err) + } + + joinTable := t.dialect.Tables().Ref(field.IndirectType) + rel := &Relation{ + Type: BelongsToRelation, + Field: field, + JoinTable: joinTable, + } + + if join, ok := field.Tag.Options["join"]; ok { + baseColumns, joinColumns := parseRelationJoin(join) + for i, baseColumn := range baseColumns { + if f := t.fieldWithLock(baseColumn); f != nil { + rel.BaseFields = append(rel.BaseFields, f) + } else { + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s", + field.GoName, t.TypeName, joinTable.TypeName, baseColumn, + )) + } + + joinColumn := joinColumns[i] + if f := joinTable.fieldWithLock(joinColumn); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + } else { + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s", + field.GoName, t.TypeName, joinTable.TypeName, baseColumn, + )) + } + } + return rel + } + + rel.BaseFields = t.PKs + fkPrefix := internal.Underscore(t.ModelName) + "_" + for _, pk := range t.PKs { + fkName := fkPrefix + pk.Name + if f := joinTable.fieldWithLock(fkName); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + continue + } + + if f := joinTable.fieldWithLock(pk.Name); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + continue + } + + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s "+ + "(to override, use join:base_column=join_column tag on %s field)", + field.GoName, t.TypeName, joinTable.TypeName, fkName, field.GoName, + )) + } + return rel +} + +func (t *Table) hasManyRelation(field *Field) *Relation { + if err := t.CheckPKs(); err != nil { + panic(err) + } + if field.IndirectType.Kind() != reflect.Slice { + panic(fmt.Errorf( + "bun: %s.%s has-many relation requires slice, got %q", + t.TypeName, field.GoName, field.IndirectType.Kind(), + )) + } + + joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem())) + polymorphicValue, isPolymorphic := field.Tag.Options["polymorphic"] + rel := &Relation{ + Type: HasManyRelation, + Field: field, + JoinTable: joinTable, + } + var polymorphicColumn string + + if join, ok := field.Tag.Options["join"]; ok { + baseColumns, joinColumns := parseRelationJoin(join) + for i, baseColumn := range baseColumns { + joinColumn := joinColumns[i] + + if isPolymorphic && baseColumn == "type" { + polymorphicColumn = joinColumn + continue + } + + if f := t.fieldWithLock(baseColumn); f != nil { + rel.BaseFields = append(rel.BaseFields, f) + } else { + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s", + t.TypeName, field.GoName, t.TypeName, baseColumn, + )) + } + + if f := joinTable.fieldWithLock(joinColumn); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + } else { + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s", + t.TypeName, field.GoName, t.TypeName, baseColumn, + )) + } + } + } else { + rel.BaseFields = t.PKs + fkPrefix := internal.Underscore(t.ModelName) + "_" + if isPolymorphic { + polymorphicColumn = fkPrefix + "type" + } + + for _, pk := range t.PKs { + joinColumn := fkPrefix + pk.Name + if fk := joinTable.fieldWithLock(joinColumn); fk != nil { + rel.JoinFields = append(rel.JoinFields, fk) + continue + } + + if fk := joinTable.fieldWithLock(pk.Name); fk != nil { + rel.JoinFields = append(rel.JoinFields, fk) + continue + } + + panic(fmt.Errorf( + "bun: %s has-many %s: %s must have column %s "+ + "(to override, use join:base_column=join_column tag on the field %s)", + t.TypeName, field.GoName, joinTable.TypeName, joinColumn, field.GoName, + )) + } + } + + if isPolymorphic { + rel.PolymorphicField = joinTable.fieldWithLock(polymorphicColumn) + if rel.PolymorphicField == nil { + panic(fmt.Errorf( + "bun: %s has-many %s: %s must have polymorphic column %s", + t.TypeName, field.GoName, joinTable.TypeName, polymorphicColumn, + )) + } + + if polymorphicValue == "" { + polymorphicValue = t.ModelName + } + rel.PolymorphicValue = polymorphicValue + } + + return rel +} + +func (t *Table) m2mRelation(field *Field) *Relation { + if field.IndirectType.Kind() != reflect.Slice { + panic(fmt.Errorf( + "bun: %s.%s m2m relation requires slice, got %q", + t.TypeName, field.GoName, field.IndirectType.Kind(), + )) + } + joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem())) + + if err := t.CheckPKs(); err != nil { + panic(err) + } + if err := joinTable.CheckPKs(); err != nil { + panic(err) + } + + m2mTableName, ok := field.Tag.Options["m2m"] + if !ok { + panic(fmt.Errorf("bun: %s must have m2m tag option", field.GoName)) + } + + m2mTable := t.dialect.Tables().ByName(m2mTableName) + if m2mTable == nil { + panic(fmt.Errorf( + "bun: can't find m2m %s table (use db.RegisterModel)", + m2mTableName, + )) + } + + rel := &Relation{ + Type: ManyToManyRelation, + Field: field, + JoinTable: joinTable, + M2MTable: m2mTable, + } + var leftColumn, rightColumn string + + if join, ok := field.Tag.Options["join"]; ok { + left, right := parseRelationJoin(join) + leftColumn = left[0] + rightColumn = right[0] + } else { + leftColumn = t.TypeName + rightColumn = joinTable.TypeName + } + + leftField := m2mTable.fieldByGoName(leftColumn) + if leftField == nil { + panic(fmt.Errorf( + "bun: %s many-to-many %s: %s must have field %s "+ + "(to override, use tag join:LeftField=RightField on field %s.%s", + t.TypeName, field.GoName, m2mTable.TypeName, leftColumn, t.TypeName, field.GoName, + )) + } + + rightField := m2mTable.fieldByGoName(rightColumn) + if rightField == nil { + panic(fmt.Errorf( + "bun: %s many-to-many %s: %s must have field %s "+ + "(to override, use tag join:LeftField=RightField on field %s.%s", + t.TypeName, field.GoName, m2mTable.TypeName, rightColumn, t.TypeName, field.GoName, + )) + } + + leftRel := m2mTable.belongsToRelation(leftField) + rel.BaseFields = leftRel.JoinFields + rel.M2MBaseFields = leftRel.BaseFields + + rightRel := m2mTable.belongsToRelation(rightField) + rel.JoinFields = rightRel.JoinFields + rel.M2MJoinFields = rightRel.BaseFields + + return rel +} + +func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) { + if path == nil { + path = map[reflect.Type]struct{}{ + t.Type: {}, + } + } + + if _, ok := path[field.IndirectType]; ok { + return + } + path[field.IndirectType] = struct{}{} + + joinTable := t.dialect.Tables().Ref(field.IndirectType) + for _, f := range joinTable.allFields { + f = f.Clone() + f.GoName = field.GoName + "_" + f.GoName + f.Name = field.Name + "__" + f.Name + f.SQLName = t.quoteIdent(f.Name) + f.Index = appendNew(field.Index, f.Index...) + + t.fieldsMapMu.Lock() + if _, ok := t.FieldMap[f.Name]; !ok { + t.FieldMap[f.Name] = f + } + t.fieldsMapMu.Unlock() + + if f.IndirectType.Kind() != reflect.Struct { + continue + } + + if _, ok := path[f.IndirectType]; !ok { + t.inlineFields(f, path) + } + } +} + +//------------------------------------------------------------------------------ + +func (t *Table) Dialect() Dialect { return t.dialect } + +//------------------------------------------------------------------------------ + +func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) } +func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) } + +//------------------------------------------------------------------------------ + +func (t *Table) AppendNamedArg( + fmter Formatter, b []byte, name string, strct reflect.Value, +) ([]byte, bool) { + if field, ok := t.FieldMap[name]; ok { + return fmter.appendArg(b, field.Value(strct).Interface()), true + } + return b, false +} + +func (t *Table) quoteTableName(s string) Safe { + // Don't quote if table name contains placeholder (?) or parentheses. + if strings.IndexByte(s, '?') >= 0 || + strings.IndexByte(s, '(') >= 0 || + strings.IndexByte(s, ')') >= 0 { + return Safe(s) + } + return t.quoteIdent(s) +} + +func (t *Table) quoteIdent(s string) Safe { + return Safe(NewFormatter(t.dialect).AppendIdent(nil, s)) +} + +func appendNew(dst []int, src ...int) []int { + cp := make([]int, len(dst)+len(src)) + copy(cp, dst) + copy(cp[len(dst):], src) + return cp +} + +func isKnownTableOption(name string) bool { + switch name { + case "alias", "select": + return true + } + return false +} + +func isKnownFieldOption(name string) bool { + switch name { + case "alias", + "type", + "array", + "hstore", + "composite", + "json_use_number", + "msgpack", + "notnull", + "nullzero", + "allowzero", + "default", + "unique", + "soft_delete", + + "pk", + "autoincrement", + "rel", + "join", + "m2m", + "polymorphic": + return true + } + return false +} + +func removeField(fields []*Field, field *Field) []*Field { + for i, f := range fields { + if f == field { + return append(fields[:i], fields[i+1:]...) + } + } + return fields +} + +func parseRelationJoin(join string) ([]string, []string) { + ss := strings.Split(join, ",") + baseColumns := make([]string, len(ss)) + joinColumns := make([]string, len(ss)) + for i, s := range ss { + ss := strings.Split(strings.TrimSpace(s), "=") + if len(ss) != 2 { + panic(fmt.Errorf("can't parse relation join: %q", join)) + } + baseColumns[i] = ss[0] + joinColumns[i] = ss[1] + } + return baseColumns, joinColumns +} + +//------------------------------------------------------------------------------ + +func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error { + typ := field.StructField.Type + + switch typ { + case timeType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*time.Time) + *ptr = time.Now() + return nil + } + case nullTimeType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*sql.NullTime) + *ptr = sql.NullTime{Time: time.Now()} + return nil + } + case nullIntType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*sql.NullInt64) + *ptr = sql.NullInt64{Int64: time.Now().UnixNano()} + return nil + } + } + + switch field.IndirectType.Kind() { + case reflect.Int64: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*int64) + *ptr = time.Now().UnixNano() + return nil + } + case reflect.Ptr: + typ = typ.Elem() + default: + return softDeleteFieldUpdaterFallback(field) + } + + switch typ { //nolint:gocritic + case timeType: + return func(fv reflect.Value) error { + now := time.Now() + fv.Set(reflect.ValueOf(&now)) + return nil + } + } + + switch typ.Kind() { //nolint:gocritic + case reflect.Int64: + return func(fv reflect.Value) error { + utime := time.Now().UnixNano() + fv.Set(reflect.ValueOf(&utime)) + return nil + } + } + + return softDeleteFieldUpdaterFallback(field) +} + +func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value) error { + return func(fv reflect.Value) error { + return field.ScanWithCheck(fv, time.Now()) + } +} diff --git a/vendor/github.com/uptrace/bun/schema/tables.go b/vendor/github.com/uptrace/bun/schema/tables.go new file mode 100644 index 000000000..d82d08f59 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/tables.go @@ -0,0 +1,148 @@ +package schema + +import ( + "fmt" + "reflect" + "sync" +) + +type tableInProgress struct { + table *Table + + init1Once sync.Once + init2Once sync.Once +} + +func newTableInProgress(table *Table) *tableInProgress { + return &tableInProgress{ + table: table, + } +} + +func (inp *tableInProgress) init1() bool { + var inited bool + inp.init1Once.Do(func() { + inp.table.init1() + inited = true + }) + return inited +} + +func (inp *tableInProgress) init2() bool { + var inited bool + inp.init2Once.Do(func() { + inp.table.init2() + inited = true + }) + return inited +} + +type Tables struct { + dialect Dialect + tables sync.Map + + mu sync.RWMutex + inProgress map[reflect.Type]*tableInProgress +} + +func NewTables(dialect Dialect) *Tables { + return &Tables{ + dialect: dialect, + inProgress: make(map[reflect.Type]*tableInProgress), + } +} + +func (t *Tables) Register(models ...interface{}) { + for _, model := range models { + _ = t.Get(reflect.TypeOf(model).Elem()) + } +} + +func (t *Tables) Get(typ reflect.Type) *Table { + return t.table(typ, false) +} + +func (t *Tables) Ref(typ reflect.Type) *Table { + return t.table(typ, true) +} + +func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table { + if typ.Kind() != reflect.Struct { + panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct)) + } + + if v, ok := t.tables.Load(typ); ok { + return v.(*Table) + } + + t.mu.Lock() + + if v, ok := t.tables.Load(typ); ok { + t.mu.Unlock() + return v.(*Table) + } + + var table *Table + + inProgress := t.inProgress[typ] + if inProgress == nil { + table = newTable(t.dialect, typ) + inProgress = newTableInProgress(table) + t.inProgress[typ] = inProgress + } else { + table = inProgress.table + } + + t.mu.Unlock() + + inProgress.init1() + if allowInProgress { + return table + } + + if inProgress.init2() { + t.mu.Lock() + delete(t.inProgress, typ) + t.tables.Store(typ, table) + t.mu.Unlock() + } + + t.dialect.OnTable(table) + + for _, field := range table.FieldMap { + if field.UserSQLType == "" { + field.UserSQLType = field.DiscoveredSQLType + } + if field.CreateTableSQLType == "" { + field.CreateTableSQLType = field.UserSQLType + } + } + + return table +} + +func (t *Tables) ByModel(name string) *Table { + var found *Table + t.tables.Range(func(key, value interface{}) bool { + t := value.(*Table) + if t.TypeName == name { + found = t + return false + } + return true + }) + return found +} + +func (t *Tables) ByName(name string) *Table { + var found *Table + t.tables.Range(func(key, value interface{}) bool { + t := value.(*Table) + if t.Name == name { + found = t + return false + } + return true + }) + return found +} diff --git a/vendor/github.com/uptrace/bun/schema/util.go b/vendor/github.com/uptrace/bun/schema/util.go new file mode 100644 index 000000000..6d474e4cc --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/util.go @@ -0,0 +1,53 @@ +package schema + +import "reflect" + +func indirectType(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, ok bool) { + if len(index) == 1 { + return v.Field(index[0]), true + } + + for i, idx := range index { + if i > 0 { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return v, false + } + v = v.Elem() + } + } + v = v.Field(idx) + } + return v, true +} + +func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value { + if len(index) == 1 { + return v.Field(index[0]) + } + + for i, idx := range index { + if i > 0 { + v = indirectNil(v) + } + v = v.Field(idx) + } + return v +} + +func indirectNil(v reflect.Value) reflect.Value { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + return v +} diff --git a/vendor/github.com/uptrace/bun/schema/zerochecker.go b/vendor/github.com/uptrace/bun/schema/zerochecker.go new file mode 100644 index 000000000..95efeee6b --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/zerochecker.go @@ -0,0 +1,126 @@ +package schema + +import ( + "database/sql/driver" + "reflect" +) + +var isZeroerType = reflect.TypeOf((*isZeroer)(nil)).Elem() + +type isZeroer interface { + IsZero() bool +} + +type IsZeroerFunc func(reflect.Value) bool + +func FieldZeroChecker(field *Field) IsZeroerFunc { + return zeroChecker(field.IndirectType) +} + +func zeroChecker(typ reflect.Type) IsZeroerFunc { + if typ.Implements(isZeroerType) { + return isZeroInterface + } + + kind := typ.Kind() + + if kind != reflect.Ptr { + ptr := reflect.PtrTo(typ) + if ptr.Implements(isZeroerType) { + return addrChecker(isZeroInterface) + } + } + + switch kind { + case reflect.Array: + if typ.Elem().Kind() == reflect.Uint8 { + return isZeroBytes + } + return isZeroLen + case reflect.String: + return isZeroLen + case reflect.Bool: + return isZeroBool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return isZeroInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return isZeroUint + case reflect.Float32, reflect.Float64: + return isZeroFloat + case reflect.Interface, reflect.Ptr, reflect.Slice, reflect.Map: + return isNil + } + + if typ.Implements(driverValuerType) { + return isZeroDriverValue + } + + return notZero +} + +func addrChecker(fn IsZeroerFunc) IsZeroerFunc { + return func(v reflect.Value) bool { + if !v.CanAddr() { + return false + } + return fn(v.Addr()) + } +} + +func isZeroInterface(v reflect.Value) bool { + if v.Kind() == reflect.Ptr && v.IsNil() { + return true + } + return v.Interface().(isZeroer).IsZero() +} + +func isZeroDriverValue(v reflect.Value) bool { + if v.Kind() == reflect.Ptr { + return v.IsNil() + } + + valuer := v.Interface().(driver.Valuer) + value, err := valuer.Value() + if err != nil { + return false + } + return value == nil +} + +func isZeroLen(v reflect.Value) bool { + return v.Len() == 0 +} + +func isNil(v reflect.Value) bool { + return v.IsNil() +} + +func isZeroBool(v reflect.Value) bool { + return !v.Bool() +} + +func isZeroInt(v reflect.Value) bool { + return v.Int() == 0 +} + +func isZeroUint(v reflect.Value) bool { + return v.Uint() == 0 +} + +func isZeroFloat(v reflect.Value) bool { + return v.Float() == 0 +} + +func isZeroBytes(v reflect.Value) bool { + b := v.Slice(0, v.Len()).Bytes() + for _, c := range b { + if c != 0 { + return false + } + } + return true +} + +func notZero(v reflect.Value) bool { + return false +} |
