summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/schema
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2021-08-25 15:34:33 +0200
committerLibravatar GitHub <noreply@github.com>2021-08-25 15:34:33 +0200
commit2dc9fc1626507bb54417fc4a1920b847cafb27a2 (patch)
tree4ddeac479b923db38090aac8bd9209f3646851c1 /vendor/github.com/uptrace/bun/schema
parentManually approves followers (#146) (diff)
downloadgotosocial-2dc9fc1626507bb54417fc4a1920b847cafb27a2.tar.xz
Pg to bun (#148)
* start moving to bun * changing more stuff * more * and yet more * tests passing * seems stable now * more big changes * small fix * little fixes
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema')
-rw-r--r--vendor/github.com/uptrace/bun/schema/append.go93
-rw-r--r--vendor/github.com/uptrace/bun/schema/append_value.go237
-rw-r--r--vendor/github.com/uptrace/bun/schema/dialect.go99
-rw-r--r--vendor/github.com/uptrace/bun/schema/field.go117
-rw-r--r--vendor/github.com/uptrace/bun/schema/formatter.go248
-rw-r--r--vendor/github.com/uptrace/bun/schema/hook.go20
-rw-r--r--vendor/github.com/uptrace/bun/schema/relation.go32
-rw-r--r--vendor/github.com/uptrace/bun/schema/scan.go392
-rw-r--r--vendor/github.com/uptrace/bun/schema/sqlfmt.go76
-rw-r--r--vendor/github.com/uptrace/bun/schema/sqltype.go129
-rw-r--r--vendor/github.com/uptrace/bun/schema/table.go948
-rw-r--r--vendor/github.com/uptrace/bun/schema/tables.go148
-rw-r--r--vendor/github.com/uptrace/bun/schema/util.go53
-rw-r--r--vendor/github.com/uptrace/bun/schema/zerochecker.go126
14 files changed, 2718 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/append.go b/vendor/github.com/uptrace/bun/schema/append.go
new file mode 100644
index 000000000..68f7071c8
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/append.go
@@ -0,0 +1,93 @@
+package schema
+
+import (
+ "reflect"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/vmihailenco/msgpack/v5"
+
+ "github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/dialect/sqltype"
+ "github.com/uptrace/bun/internal"
+)
+
+func FieldAppender(dialect Dialect, field *Field) AppenderFunc {
+ if field.Tag.HasOption("msgpack") {
+ return appendMsgpack
+ }
+
+ switch strings.ToUpper(field.UserSQLType) {
+ case sqltype.JSON, sqltype.JSONB:
+ return AppendJSONValue
+ }
+
+ return dialect.Appender(field.StructField.Type)
+}
+
+func Append(fmter Formatter, b []byte, v interface{}, custom CustomAppender) []byte {
+ switch v := v.(type) {
+ case nil:
+ return dialect.AppendNull(b)
+ case bool:
+ return dialect.AppendBool(b, v)
+ case int:
+ return strconv.AppendInt(b, int64(v), 10)
+ case int32:
+ return strconv.AppendInt(b, int64(v), 10)
+ case int64:
+ return strconv.AppendInt(b, v, 10)
+ case uint:
+ return strconv.AppendUint(b, uint64(v), 10)
+ case uint32:
+ return strconv.AppendUint(b, uint64(v), 10)
+ case uint64:
+ return strconv.AppendUint(b, v, 10)
+ case float32:
+ return dialect.AppendFloat32(b, v)
+ case float64:
+ return dialect.AppendFloat64(b, v)
+ case string:
+ return dialect.AppendString(b, v)
+ case time.Time:
+ return dialect.AppendTime(b, v)
+ case []byte:
+ return dialect.AppendBytes(b, v)
+ case QueryAppender:
+ return AppendQueryAppender(fmter, b, v)
+ default:
+ vv := reflect.ValueOf(v)
+ if vv.Kind() == reflect.Ptr && vv.IsNil() {
+ return dialect.AppendNull(b)
+ }
+ appender := Appender(vv.Type(), custom)
+ return appender(fmter, b, vv)
+ }
+}
+
+func appendMsgpack(fmter Formatter, b []byte, v reflect.Value) []byte {
+ hexEnc := internal.NewHexEncoder(b)
+
+ enc := msgpack.GetEncoder()
+ defer msgpack.PutEncoder(enc)
+
+ enc.Reset(hexEnc)
+ if err := enc.EncodeValue(v); err != nil {
+ return dialect.AppendError(b, err)
+ }
+
+ if err := hexEnc.Close(); err != nil {
+ return dialect.AppendError(b, err)
+ }
+
+ return hexEnc.Bytes()
+}
+
+func AppendQueryAppender(fmter Formatter, b []byte, app QueryAppender) []byte {
+ bb, err := app.AppendQuery(fmter, b)
+ if err != nil {
+ return dialect.AppendError(b, err)
+ }
+ return bb
+}
diff --git a/vendor/github.com/uptrace/bun/schema/append_value.go b/vendor/github.com/uptrace/bun/schema/append_value.go
new file mode 100644
index 000000000..0c4677069
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/append_value.go
@@ -0,0 +1,237 @@
+package schema
+
+import (
+ "database/sql/driver"
+ "encoding/json"
+ "fmt"
+ "net"
+ "reflect"
+ "strconv"
+ "time"
+
+ "github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/extra/bunjson"
+ "github.com/uptrace/bun/internal"
+)
+
+var (
+ timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
+ ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
+ ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
+ jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
+
+ driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
+ queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem()
+)
+
+type (
+ AppenderFunc func(fmter Formatter, b []byte, v reflect.Value) []byte
+ CustomAppender func(typ reflect.Type) AppenderFunc
+)
+
+var appenders = []AppenderFunc{
+ reflect.Bool: AppendBoolValue,
+ reflect.Int: AppendIntValue,
+ reflect.Int8: AppendIntValue,
+ reflect.Int16: AppendIntValue,
+ reflect.Int32: AppendIntValue,
+ reflect.Int64: AppendIntValue,
+ reflect.Uint: AppendUintValue,
+ reflect.Uint8: AppendUintValue,
+ reflect.Uint16: AppendUintValue,
+ reflect.Uint32: AppendUintValue,
+ reflect.Uint64: AppendUintValue,
+ reflect.Uintptr: nil,
+ reflect.Float32: AppendFloat32Value,
+ reflect.Float64: AppendFloat64Value,
+ reflect.Complex64: nil,
+ reflect.Complex128: nil,
+ reflect.Array: AppendJSONValue,
+ reflect.Chan: nil,
+ reflect.Func: nil,
+ reflect.Interface: nil,
+ reflect.Map: AppendJSONValue,
+ reflect.Ptr: nil,
+ reflect.Slice: AppendJSONValue,
+ reflect.String: AppendStringValue,
+ reflect.Struct: AppendJSONValue,
+ reflect.UnsafePointer: nil,
+}
+
+func Appender(typ reflect.Type, custom CustomAppender) AppenderFunc {
+ switch typ {
+ case timeType:
+ return appendTimeValue
+ case ipType:
+ return appendIPValue
+ case ipNetType:
+ return appendIPNetValue
+ case jsonRawMessageType:
+ return appendJSONRawMessageValue
+ }
+
+ if typ.Implements(queryAppenderType) {
+ return appendQueryAppenderValue
+ }
+ if typ.Implements(driverValuerType) {
+ return driverValueAppender(custom)
+ }
+
+ kind := typ.Kind()
+
+ if kind != reflect.Ptr {
+ ptr := reflect.PtrTo(typ)
+ if ptr.Implements(queryAppenderType) {
+ return addrAppender(appendQueryAppenderValue, custom)
+ }
+ if ptr.Implements(driverValuerType) {
+ return addrAppender(driverValueAppender(custom), custom)
+ }
+ }
+
+ switch kind {
+ case reflect.Interface:
+ return ifaceAppenderFunc(typ, custom)
+ case reflect.Ptr:
+ return ptrAppenderFunc(typ, custom)
+ case reflect.Slice:
+ if typ.Elem().Kind() == reflect.Uint8 {
+ return appendBytesValue
+ }
+ case reflect.Array:
+ if typ.Elem().Kind() == reflect.Uint8 {
+ return appendArrayBytesValue
+ }
+ }
+
+ if custom != nil {
+ if fn := custom(typ); fn != nil {
+ return fn
+ }
+ }
+ return appenders[typ.Kind()]
+}
+
+func ifaceAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc {
+ return func(fmter Formatter, b []byte, v reflect.Value) []byte {
+ if v.IsNil() {
+ return dialect.AppendNull(b)
+ }
+ elem := v.Elem()
+ appender := Appender(elem.Type(), custom)
+ return appender(fmter, b, elem)
+ }
+}
+
+func ptrAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc {
+ appender := Appender(typ.Elem(), custom)
+ return func(fmter Formatter, b []byte, v reflect.Value) []byte {
+ if v.IsNil() {
+ return dialect.AppendNull(b)
+ }
+ return appender(fmter, b, v.Elem())
+ }
+}
+
+func AppendBoolValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return dialect.AppendBool(b, v.Bool())
+}
+
+func AppendIntValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return strconv.AppendInt(b, v.Int(), 10)
+}
+
+func AppendUintValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return strconv.AppendUint(b, v.Uint(), 10)
+}
+
+func AppendFloat32Value(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return dialect.AppendFloat32(b, float32(v.Float()))
+}
+
+func AppendFloat64Value(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return dialect.AppendFloat64(b, float64(v.Float()))
+}
+
+func appendBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return dialect.AppendBytes(b, v.Bytes())
+}
+
+func appendArrayBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ if v.CanAddr() {
+ return dialect.AppendBytes(b, v.Slice(0, v.Len()).Bytes())
+ }
+
+ tmp := make([]byte, v.Len())
+ reflect.Copy(reflect.ValueOf(tmp), v)
+ b = dialect.AppendBytes(b, tmp)
+ return b
+}
+
+func AppendStringValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return dialect.AppendString(b, v.String())
+}
+
+func AppendJSONValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ bb, err := bunjson.Marshal(v.Interface())
+ if err != nil {
+ return dialect.AppendError(b, err)
+ }
+
+ if len(bb) > 0 && bb[len(bb)-1] == '\n' {
+ bb = bb[:len(bb)-1]
+ }
+
+ return dialect.AppendJSON(b, bb)
+}
+
+func appendTimeValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ tm := v.Interface().(time.Time)
+ return dialect.AppendTime(b, tm)
+}
+
+func appendIPValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ ip := v.Interface().(net.IP)
+ return dialect.AppendString(b, ip.String())
+}
+
+func appendIPNetValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ ipnet := v.Interface().(net.IPNet)
+ return dialect.AppendString(b, ipnet.String())
+}
+
+func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ bytes := v.Bytes()
+ if bytes == nil {
+ return dialect.AppendNull(b)
+ }
+ return dialect.AppendString(b, internal.String(bytes))
+}
+
+func appendQueryAppenderValue(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return AppendQueryAppender(fmter, b, v.Interface().(QueryAppender))
+}
+
+func driverValueAppender(custom CustomAppender) AppenderFunc {
+ return func(fmter Formatter, b []byte, v reflect.Value) []byte {
+ return appendDriverValue(fmter, b, v.Interface().(driver.Valuer), custom)
+ }
+}
+
+func appendDriverValue(fmter Formatter, b []byte, v driver.Valuer, custom CustomAppender) []byte {
+ value, err := v.Value()
+ if err != nil {
+ return dialect.AppendError(b, err)
+ }
+ return Append(fmter, b, value, custom)
+}
+
+func addrAppender(fn AppenderFunc, custom CustomAppender) AppenderFunc {
+ return func(fmter Formatter, b []byte, v reflect.Value) []byte {
+ if !v.CanAddr() {
+ err := fmt.Errorf("bun: Append(nonaddressable %T)", v.Interface())
+ return dialect.AppendError(b, err)
+ }
+ return fn(fmter, b, v.Addr())
+ }
+}
diff --git a/vendor/github.com/uptrace/bun/schema/dialect.go b/vendor/github.com/uptrace/bun/schema/dialect.go
new file mode 100644
index 000000000..c50de715a
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/dialect.go
@@ -0,0 +1,99 @@
+package schema
+
+import (
+ "database/sql"
+ "reflect"
+ "sync"
+
+ "github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/dialect/feature"
+)
+
+type Dialect interface {
+ Init(db *sql.DB)
+
+ Name() dialect.Name
+ Features() feature.Feature
+
+ Tables() *Tables
+ OnTable(table *Table)
+
+ IdentQuote() byte
+ Append(fmter Formatter, b []byte, v interface{}) []byte
+ Appender(typ reflect.Type) AppenderFunc
+ FieldAppender(field *Field) AppenderFunc
+ Scanner(typ reflect.Type) ScannerFunc
+}
+
+//------------------------------------------------------------------------------
+
+type nopDialect struct {
+ tables *Tables
+ features feature.Feature
+
+ appenderMap sync.Map
+ scannerMap sync.Map
+}
+
+func newNopDialect() *nopDialect {
+ d := new(nopDialect)
+ d.tables = NewTables(d)
+ d.features = feature.Returning
+ return d
+}
+
+func (d *nopDialect) Init(*sql.DB) {}
+
+func (d *nopDialect) Name() dialect.Name {
+ return dialect.Invalid
+}
+
+func (d *nopDialect) Features() feature.Feature {
+ return d.features
+}
+
+func (d *nopDialect) Tables() *Tables {
+ return d.tables
+}
+
+func (d *nopDialect) OnField(field *Field) {}
+
+func (d *nopDialect) OnTable(table *Table) {}
+
+func (d *nopDialect) IdentQuote() byte {
+ return '"'
+}
+
+func (d *nopDialect) Append(fmter Formatter, b []byte, v interface{}) []byte {
+ return Append(fmter, b, v, nil)
+}
+
+func (d *nopDialect) Appender(typ reflect.Type) AppenderFunc {
+ if v, ok := d.appenderMap.Load(typ); ok {
+ return v.(AppenderFunc)
+ }
+
+ fn := Appender(typ, nil)
+
+ if v, ok := d.appenderMap.LoadOrStore(typ, fn); ok {
+ return v.(AppenderFunc)
+ }
+ return fn
+}
+
+func (d *nopDialect) FieldAppender(field *Field) AppenderFunc {
+ return FieldAppender(d, field)
+}
+
+func (d *nopDialect) Scanner(typ reflect.Type) ScannerFunc {
+ if v, ok := d.scannerMap.Load(typ); ok {
+ return v.(ScannerFunc)
+ }
+
+ fn := Scanner(typ)
+
+ if v, ok := d.scannerMap.LoadOrStore(typ, fn); ok {
+ return v.(ScannerFunc)
+ }
+ return fn
+}
diff --git a/vendor/github.com/uptrace/bun/schema/field.go b/vendor/github.com/uptrace/bun/schema/field.go
new file mode 100644
index 000000000..1e069b82f
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/field.go
@@ -0,0 +1,117 @@
+package schema
+
+import (
+ "fmt"
+ "reflect"
+
+ "github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/internal/tagparser"
+)
+
+type Field struct {
+ StructField reflect.StructField
+
+ Tag tagparser.Tag
+ IndirectType reflect.Type
+ Index []int
+
+ Name string // SQL name, .e.g. id
+ SQLName Safe // escaped SQL name, e.g. "id"
+ GoName string // struct field name, e.g. Id
+
+ DiscoveredSQLType string
+ UserSQLType string
+ CreateTableSQLType string
+ SQLDefault string
+
+ OnDelete string
+ OnUpdate string
+
+ IsPK bool
+ NotNull bool
+ NullZero bool
+ AutoIncrement bool
+
+ Append AppenderFunc
+ Scan ScannerFunc
+ IsZero IsZeroerFunc
+}
+
+func (f *Field) String() string {
+ return f.Name
+}
+
+func (f *Field) Clone() *Field {
+ cp := *f
+ cp.Index = cp.Index[:len(f.Index):len(f.Index)]
+ return &cp
+}
+
+func (f *Field) Value(strct reflect.Value) reflect.Value {
+ return fieldByIndexAlloc(strct, f.Index)
+}
+
+func (f *Field) HasZeroValue(v reflect.Value) bool {
+ for _, idx := range f.Index {
+ if v.Kind() == reflect.Ptr {
+ if v.IsNil() {
+ return true
+ }
+ v = v.Elem()
+ }
+ v = v.Field(idx)
+ }
+ return f.IsZero(v)
+}
+
+func (f *Field) AppendValue(fmter Formatter, b []byte, strct reflect.Value) []byte {
+ fv, ok := fieldByIndex(strct, f.Index)
+ if !ok {
+ return dialect.AppendNull(b)
+ }
+
+ if f.NullZero && f.IsZero(fv) {
+ return dialect.AppendNull(b)
+ }
+ if f.Append == nil {
+ panic(fmt.Errorf("bun: AppendValue(unsupported %s)", fv.Type()))
+ }
+ return f.Append(fmter, b, fv)
+}
+
+func (f *Field) ScanWithCheck(fv reflect.Value, src interface{}) error {
+ if f.Scan == nil {
+ return fmt.Errorf("bun: Scan(unsupported %s)", f.IndirectType)
+ }
+ return f.Scan(fv, src)
+}
+
+func (f *Field) ScanValue(strct reflect.Value, src interface{}) error {
+ if src == nil {
+ if fv, ok := fieldByIndex(strct, f.Index); ok {
+ return f.ScanWithCheck(fv, src)
+ }
+ return nil
+ }
+
+ fv := fieldByIndexAlloc(strct, f.Index)
+ return f.ScanWithCheck(fv, src)
+}
+
+func (f *Field) markAsPK() {
+ f.IsPK = true
+ f.NotNull = true
+ f.NullZero = true
+}
+
+func indexEqual(ind1, ind2 []int) bool {
+ if len(ind1) != len(ind2) {
+ return false
+ }
+ for i, ind := range ind1 {
+ if ind != ind2[i] {
+ return false
+ }
+ }
+ return true
+}
diff --git a/vendor/github.com/uptrace/bun/schema/formatter.go b/vendor/github.com/uptrace/bun/schema/formatter.go
new file mode 100644
index 000000000..7b26fbaca
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/formatter.go
@@ -0,0 +1,248 @@
+package schema
+
+import (
+ "reflect"
+ "strconv"
+ "strings"
+
+ "github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/dialect/feature"
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/internal/parser"
+)
+
+var nopFormatter = Formatter{
+ dialect: newNopDialect(),
+}
+
+type Formatter struct {
+ dialect Dialect
+ args *namedArgList
+}
+
+func NewFormatter(dialect Dialect) Formatter {
+ return Formatter{
+ dialect: dialect,
+ }
+}
+
+func NewNopFormatter() Formatter {
+ return nopFormatter
+}
+
+func (f Formatter) IsNop() bool {
+ return f.dialect.Name() == dialect.Invalid
+}
+
+func (f Formatter) Dialect() Dialect {
+ return f.dialect
+}
+
+func (f Formatter) IdentQuote() byte {
+ return f.dialect.IdentQuote()
+}
+
+func (f Formatter) AppendIdent(b []byte, ident string) []byte {
+ return dialect.AppendIdent(b, ident, f.IdentQuote())
+}
+
+func (f Formatter) AppendValue(b []byte, v reflect.Value) []byte {
+ if v.Kind() == reflect.Ptr && v.IsNil() {
+ return dialect.AppendNull(b)
+ }
+ appender := f.dialect.Appender(v.Type())
+ return appender(f, b, v)
+}
+
+func (f Formatter) HasFeature(feature feature.Feature) bool {
+ return f.dialect.Features().Has(feature)
+}
+
+func (f Formatter) WithArg(arg NamedArgAppender) Formatter {
+ return Formatter{
+ dialect: f.dialect,
+ args: f.args.WithArg(arg),
+ }
+}
+
+func (f Formatter) WithNamedArg(name string, value interface{}) Formatter {
+ return Formatter{
+ dialect: f.dialect,
+ args: f.args.WithArg(&namedArg{name: name, value: value}),
+ }
+}
+
+func (f Formatter) FormatQuery(query string, args ...interface{}) string {
+ if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 {
+ return query
+ }
+ return internal.String(f.AppendQuery(nil, query, args...))
+}
+
+func (f Formatter) AppendQuery(dst []byte, query string, args ...interface{}) []byte {
+ if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 {
+ return append(dst, query...)
+ }
+ return f.append(dst, parser.NewString(query), args)
+}
+
+func (f Formatter) append(dst []byte, p *parser.Parser, args []interface{}) []byte {
+ var namedArgs NamedArgAppender
+ if len(args) == 1 {
+ var ok bool
+ namedArgs, ok = args[0].(NamedArgAppender)
+ if !ok {
+ namedArgs, _ = newStructArgs(f, args[0])
+ }
+ }
+
+ var argIndex int
+ for p.Valid() {
+ b, ok := p.ReadSep('?')
+ if !ok {
+ dst = append(dst, b...)
+ continue
+ }
+ if len(b) > 0 && b[len(b)-1] == '\\' {
+ dst = append(dst, b[:len(b)-1]...)
+ dst = append(dst, '?')
+ continue
+ }
+ dst = append(dst, b...)
+
+ name, numeric := p.ReadIdentifier()
+ if name != "" {
+ if numeric {
+ idx, err := strconv.Atoi(name)
+ if err != nil {
+ goto restore_arg
+ }
+
+ if idx >= len(args) {
+ goto restore_arg
+ }
+
+ dst = f.appendArg(dst, args[idx])
+ continue
+ }
+
+ if namedArgs != nil {
+ dst, ok = namedArgs.AppendNamedArg(f, dst, name)
+ if ok {
+ continue
+ }
+ }
+
+ dst, ok = f.args.AppendNamedArg(f, dst, name)
+ if ok {
+ continue
+ }
+
+ restore_arg:
+ dst = append(dst, '?')
+ dst = append(dst, name...)
+ continue
+ }
+
+ if argIndex >= len(args) {
+ dst = append(dst, '?')
+ continue
+ }
+
+ arg := args[argIndex]
+ argIndex++
+
+ dst = f.appendArg(dst, arg)
+ }
+
+ return dst
+}
+
+func (f Formatter) appendArg(b []byte, arg interface{}) []byte {
+ switch arg := arg.(type) {
+ case QueryAppender:
+ bb, err := arg.AppendQuery(f, b)
+ if err != nil {
+ return dialect.AppendError(b, err)
+ }
+ return bb
+ default:
+ return f.dialect.Append(f, b, arg)
+ }
+}
+
+//------------------------------------------------------------------------------
+
+type NamedArgAppender interface {
+ AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool)
+}
+
+//------------------------------------------------------------------------------
+
+type namedArgList struct {
+ arg NamedArgAppender
+ next *namedArgList
+}
+
+func (l *namedArgList) WithArg(arg NamedArgAppender) *namedArgList {
+ return &namedArgList{
+ arg: arg,
+ next: l,
+ }
+}
+
+func (l *namedArgList) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
+ for l != nil && l.arg != nil {
+ if b, ok := l.arg.AppendNamedArg(fmter, b, name); ok {
+ return b, true
+ }
+ l = l.next
+ }
+ return b, false
+}
+
+//------------------------------------------------------------------------------
+
+type namedArg struct {
+ name string
+ value interface{}
+}
+
+var _ NamedArgAppender = (*namedArg)(nil)
+
+func (a *namedArg) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
+ if a.name == name {
+ return fmter.appendArg(b, a.value), true
+ }
+ return b, false
+}
+
+//------------------------------------------------------------------------------
+
+var _ NamedArgAppender = (*structArgs)(nil)
+
+type structArgs struct {
+ table *Table
+ strct reflect.Value
+}
+
+func newStructArgs(fmter Formatter, strct interface{}) (*structArgs, bool) {
+ v := reflect.ValueOf(strct)
+ if !v.IsValid() {
+ return nil, false
+ }
+
+ v = reflect.Indirect(v)
+ if v.Kind() != reflect.Struct {
+ return nil, false
+ }
+
+ return &structArgs{
+ table: fmter.Dialect().Tables().Get(v.Type()),
+ strct: v,
+ }, true
+}
+
+func (m *structArgs) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
+ return m.table.AppendNamedArg(fmter, b, name, m.strct)
+}
diff --git a/vendor/github.com/uptrace/bun/schema/hook.go b/vendor/github.com/uptrace/bun/schema/hook.go
new file mode 100644
index 000000000..5391981d5
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/hook.go
@@ -0,0 +1,20 @@
+package schema
+
+import (
+ "context"
+ "reflect"
+)
+
+type BeforeScanHook interface {
+ BeforeScan(context.Context) error
+}
+
+var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem()
+
+//------------------------------------------------------------------------------
+
+type AfterScanHook interface {
+ AfterScan(context.Context) error
+}
+
+var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem()
diff --git a/vendor/github.com/uptrace/bun/schema/relation.go b/vendor/github.com/uptrace/bun/schema/relation.go
new file mode 100644
index 000000000..8d1baeb3f
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/relation.go
@@ -0,0 +1,32 @@
+package schema
+
+import (
+ "fmt"
+)
+
+const (
+ InvalidRelation = iota
+ HasOneRelation
+ BelongsToRelation
+ HasManyRelation
+ ManyToManyRelation
+)
+
+type Relation struct {
+ Type int
+ Field *Field
+ JoinTable *Table
+ BaseFields []*Field
+ JoinFields []*Field
+
+ PolymorphicField *Field
+ PolymorphicValue string
+
+ M2MTable *Table
+ M2MBaseFields []*Field
+ M2MJoinFields []*Field
+}
+
+func (r *Relation) String() string {
+ return fmt.Sprintf("relation=%s", r.Field.GoName)
+}
diff --git a/vendor/github.com/uptrace/bun/schema/scan.go b/vendor/github.com/uptrace/bun/schema/scan.go
new file mode 100644
index 000000000..0e66a860f
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/scan.go
@@ -0,0 +1,392 @@
+package schema
+
+import (
+ "bytes"
+ "database/sql"
+ "fmt"
+ "net"
+ "reflect"
+ "strconv"
+ "time"
+
+ "github.com/vmihailenco/msgpack/v5"
+
+ "github.com/uptrace/bun/extra/bunjson"
+ "github.com/uptrace/bun/internal"
+)
+
+var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
+
+type ScannerFunc func(dest reflect.Value, src interface{}) error
+
+var scanners = []ScannerFunc{
+ reflect.Bool: scanBool,
+ reflect.Int: scanInt64,
+ reflect.Int8: scanInt64,
+ reflect.Int16: scanInt64,
+ reflect.Int32: scanInt64,
+ reflect.Int64: scanInt64,
+ reflect.Uint: scanUint64,
+ reflect.Uint8: scanUint64,
+ reflect.Uint16: scanUint64,
+ reflect.Uint32: scanUint64,
+ reflect.Uint64: scanUint64,
+ reflect.Uintptr: scanUint64,
+ reflect.Float32: scanFloat64,
+ reflect.Float64: scanFloat64,
+ reflect.Complex64: nil,
+ reflect.Complex128: nil,
+ reflect.Array: nil,
+ reflect.Chan: nil,
+ reflect.Func: nil,
+ reflect.Map: scanJSON,
+ reflect.Ptr: nil,
+ reflect.Slice: scanJSON,
+ reflect.String: scanString,
+ reflect.Struct: scanJSON,
+ reflect.UnsafePointer: nil,
+}
+
+func FieldScanner(dialect Dialect, field *Field) ScannerFunc {
+ if field.Tag.HasOption("msgpack") {
+ return scanMsgpack
+ }
+ if field.Tag.HasOption("json_use_number") {
+ return scanJSONUseNumber
+ }
+ return dialect.Scanner(field.StructField.Type)
+}
+
+func Scanner(typ reflect.Type) ScannerFunc {
+ kind := typ.Kind()
+
+ if kind == reflect.Ptr {
+ if fn := Scanner(typ.Elem()); fn != nil {
+ return ptrScanner(fn)
+ }
+ }
+
+ if typ.Implements(scannerType) {
+ return scanScanner
+ }
+
+ if kind != reflect.Ptr {
+ ptr := reflect.PtrTo(typ)
+ if ptr.Implements(scannerType) {
+ return addrScanner(scanScanner)
+ }
+ }
+
+ switch typ {
+ case timeType:
+ return scanTime
+ case ipType:
+ return scanIP
+ case ipNetType:
+ return scanIPNet
+ case jsonRawMessageType:
+ return scanJSONRawMessage
+ }
+
+ return scanners[kind]
+}
+
+func scanBool(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ dest.SetBool(false)
+ return nil
+ case bool:
+ dest.SetBool(src)
+ return nil
+ case int64:
+ dest.SetBool(src != 0)
+ return nil
+ case []byte:
+ if len(src) == 1 {
+ dest.SetBool(src[0] != '0')
+ return nil
+ }
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanInt64(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ dest.SetInt(0)
+ return nil
+ case int64:
+ dest.SetInt(src)
+ return nil
+ case uint64:
+ dest.SetInt(int64(src))
+ return nil
+ case []byte:
+ n, err := strconv.ParseInt(internal.String(src), 10, 64)
+ if err != nil {
+ return err
+ }
+ dest.SetInt(n)
+ return nil
+ case string:
+ n, err := strconv.ParseInt(src, 10, 64)
+ if err != nil {
+ return err
+ }
+ dest.SetInt(n)
+ return nil
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanUint64(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ dest.SetUint(0)
+ return nil
+ case uint64:
+ dest.SetUint(src)
+ return nil
+ case int64:
+ dest.SetUint(uint64(src))
+ return nil
+ case []byte:
+ n, err := strconv.ParseUint(internal.String(src), 10, 64)
+ if err != nil {
+ return err
+ }
+ dest.SetUint(n)
+ return nil
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanFloat64(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ dest.SetFloat(0)
+ return nil
+ case float64:
+ dest.SetFloat(src)
+ return nil
+ case []byte:
+ f, err := strconv.ParseFloat(internal.String(src), 64)
+ if err != nil {
+ return err
+ }
+ dest.SetFloat(f)
+ return nil
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanString(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ dest.SetString("")
+ return nil
+ case string:
+ dest.SetString(src)
+ return nil
+ case []byte:
+ dest.SetString(string(src))
+ return nil
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanTime(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ destTime := dest.Addr().Interface().(*time.Time)
+ *destTime = time.Time{}
+ return nil
+ case time.Time:
+ destTime := dest.Addr().Interface().(*time.Time)
+ *destTime = src
+ return nil
+ case string:
+ srcTime, err := internal.ParseTime(src)
+ if err != nil {
+ return err
+ }
+ destTime := dest.Addr().Interface().(*time.Time)
+ *destTime = srcTime
+ return nil
+ case []byte:
+ srcTime, err := internal.ParseTime(internal.String(src))
+ if err != nil {
+ return err
+ }
+ destTime := dest.Addr().Interface().(*time.Time)
+ *destTime = srcTime
+ return nil
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanScanner(dest reflect.Value, src interface{}) error {
+ return dest.Interface().(sql.Scanner).Scan(src)
+}
+
+func scanMsgpack(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ return scanNull(dest)
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ dec := msgpack.GetDecoder()
+ defer msgpack.PutDecoder(dec)
+
+ dec.Reset(bytes.NewReader(b))
+ return dec.DecodeValue(dest)
+}
+
+func scanJSON(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ return scanNull(dest)
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ return bunjson.Unmarshal(b, dest.Addr().Interface())
+}
+
+func scanJSONUseNumber(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ return scanNull(dest)
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ dec := bunjson.NewDecoder(bytes.NewReader(b))
+ dec.UseNumber()
+ return dec.Decode(dest.Addr().Interface())
+}
+
+func scanIP(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ return scanNull(dest)
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ ip := net.ParseIP(internal.String(b))
+ if ip == nil {
+ return fmt.Errorf("bun: invalid ip: %q", b)
+ }
+
+ ptr := dest.Addr().Interface().(*net.IP)
+ *ptr = ip
+
+ return nil
+}
+
+func scanIPNet(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ return scanNull(dest)
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ _, ipnet, err := net.ParseCIDR(internal.String(b))
+ if err != nil {
+ return err
+ }
+
+ ptr := dest.Addr().Interface().(*net.IPNet)
+ *ptr = *ipnet
+
+ return nil
+}
+
+func scanJSONRawMessage(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ dest.SetBytes(nil)
+ return nil
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ dest.SetBytes(b)
+ return nil
+}
+
+func addrScanner(fn ScannerFunc) ScannerFunc {
+ return func(dest reflect.Value, src interface{}) error {
+ if !dest.CanAddr() {
+ return fmt.Errorf("bun: Scan(nonaddressable %T)", dest.Interface())
+ }
+ return fn(dest.Addr(), src)
+ }
+}
+
+func toBytes(src interface{}) ([]byte, error) {
+ switch src := src.(type) {
+ case string:
+ return internal.Bytes(src), nil
+ case []byte:
+ return src, nil
+ default:
+ return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src)
+ }
+}
+
+func ptrScanner(fn ScannerFunc) ScannerFunc {
+ return func(dest reflect.Value, src interface{}) error {
+ if src == nil {
+ if !dest.CanAddr() {
+ if dest.IsNil() {
+ return nil
+ }
+ return fn(dest.Elem(), src)
+ }
+
+ if !dest.IsNil() {
+ dest.Set(reflect.New(dest.Type().Elem()))
+ }
+ return nil
+ }
+
+ if dest.IsNil() {
+ dest.Set(reflect.New(dest.Type().Elem()))
+ }
+ return fn(dest.Elem(), src)
+ }
+}
+
+func scanNull(dest reflect.Value) error {
+ if nilable(dest.Kind()) && dest.IsNil() {
+ return nil
+ }
+ dest.Set(reflect.New(dest.Type()).Elem())
+ return nil
+}
+
+func nilable(kind reflect.Kind) bool {
+ switch kind {
+ case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
+ return true
+ }
+ return false
+}
diff --git a/vendor/github.com/uptrace/bun/schema/sqlfmt.go b/vendor/github.com/uptrace/bun/schema/sqlfmt.go
new file mode 100644
index 000000000..7b538cd0c
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/sqlfmt.go
@@ -0,0 +1,76 @@
+package schema
+
+type QueryAppender interface {
+ AppendQuery(fmter Formatter, b []byte) ([]byte, error)
+}
+
+type ColumnsAppender interface {
+ AppendColumns(fmter Formatter, b []byte) ([]byte, error)
+}
+
+//------------------------------------------------------------------------------
+
+// Safe represents a safe SQL query.
+type Safe string
+
+var _ QueryAppender = (*Safe)(nil)
+
+func (s Safe) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
+ return append(b, s...), nil
+}
+
+//------------------------------------------------------------------------------
+
+// Ident represents a SQL identifier, for example, table or column name.
+type Ident string
+
+var _ QueryAppender = (*Ident)(nil)
+
+func (s Ident) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
+ return fmter.AppendIdent(b, string(s)), nil
+}
+
+//------------------------------------------------------------------------------
+
+type QueryWithArgs struct {
+ Query string
+ Args []interface{}
+}
+
+var _ QueryAppender = QueryWithArgs{}
+
+func SafeQuery(query string, args []interface{}) QueryWithArgs {
+ if query != "" && args == nil {
+ args = make([]interface{}, 0)
+ }
+ return QueryWithArgs{Query: query, Args: args}
+}
+
+func UnsafeIdent(ident string) QueryWithArgs {
+ return QueryWithArgs{Query: ident}
+}
+
+func (q QueryWithArgs) IsZero() bool {
+ return q.Query == "" && q.Args == nil
+}
+
+func (q QueryWithArgs) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
+ if q.Args == nil {
+ return fmter.AppendIdent(b, q.Query), nil
+ }
+ return fmter.AppendQuery(b, q.Query, q.Args...), nil
+}
+
+//------------------------------------------------------------------------------
+
+type QueryWithSep struct {
+ QueryWithArgs
+ Sep string
+}
+
+func SafeQueryWithSep(query string, args []interface{}, sep string) QueryWithSep {
+ return QueryWithSep{
+ QueryWithArgs: SafeQuery(query, args),
+ Sep: sep,
+ }
+}
diff --git a/vendor/github.com/uptrace/bun/schema/sqltype.go b/vendor/github.com/uptrace/bun/schema/sqltype.go
new file mode 100644
index 000000000..560f695c2
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/sqltype.go
@@ -0,0 +1,129 @@
+package schema
+
+import (
+ "bytes"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+ "reflect"
+ "time"
+
+ "github.com/uptrace/bun/dialect"
+ "github.com/uptrace/bun/dialect/sqltype"
+ "github.com/uptrace/bun/internal"
+)
+
+var (
+ bunNullTimeType = reflect.TypeOf((*NullTime)(nil)).Elem()
+ nullTimeType = reflect.TypeOf((*sql.NullTime)(nil)).Elem()
+ nullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem()
+ nullFloatType = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem()
+ nullIntType = reflect.TypeOf((*sql.NullInt64)(nil)).Elem()
+ nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem()
+)
+
+var sqlTypes = []string{
+ reflect.Bool: sqltype.Boolean,
+ reflect.Int: sqltype.BigInt,
+ reflect.Int8: sqltype.SmallInt,
+ reflect.Int16: sqltype.SmallInt,
+ reflect.Int32: sqltype.Integer,
+ reflect.Int64: sqltype.BigInt,
+ reflect.Uint: sqltype.BigInt,
+ reflect.Uint8: sqltype.SmallInt,
+ reflect.Uint16: sqltype.SmallInt,
+ reflect.Uint32: sqltype.Integer,
+ reflect.Uint64: sqltype.BigInt,
+ reflect.Uintptr: sqltype.BigInt,
+ reflect.Float32: sqltype.Real,
+ reflect.Float64: sqltype.DoublePrecision,
+ reflect.Complex64: "",
+ reflect.Complex128: "",
+ reflect.Array: "",
+ reflect.Chan: "",
+ reflect.Func: "",
+ reflect.Interface: "",
+ reflect.Map: sqltype.VarChar,
+ reflect.Ptr: "",
+ reflect.Slice: sqltype.VarChar,
+ reflect.String: sqltype.VarChar,
+ reflect.Struct: sqltype.VarChar,
+ reflect.UnsafePointer: "",
+}
+
+func DiscoverSQLType(typ reflect.Type) string {
+ switch typ {
+ case timeType, nullTimeType, bunNullTimeType:
+ return sqltype.Timestamp
+ case nullBoolType:
+ return sqltype.Boolean
+ case nullFloatType:
+ return sqltype.DoublePrecision
+ case nullIntType:
+ return sqltype.BigInt
+ case nullStringType:
+ return sqltype.VarChar
+ }
+ return sqlTypes[typ.Kind()]
+}
+
+//------------------------------------------------------------------------------
+
+var jsonNull = []byte("null")
+
+// NullTime is a time.Time wrapper that marshals zero time as JSON null and SQL NULL.
+type NullTime struct {
+ time.Time
+}
+
+var (
+ _ json.Marshaler = (*NullTime)(nil)
+ _ json.Unmarshaler = (*NullTime)(nil)
+ _ sql.Scanner = (*NullTime)(nil)
+ _ QueryAppender = (*NullTime)(nil)
+)
+
+func (tm NullTime) MarshalJSON() ([]byte, error) {
+ if tm.IsZero() {
+ return jsonNull, nil
+ }
+ return tm.Time.MarshalJSON()
+}
+
+func (tm *NullTime) UnmarshalJSON(b []byte) error {
+ if bytes.Equal(b, jsonNull) {
+ tm.Time = time.Time{}
+ return nil
+ }
+ return tm.Time.UnmarshalJSON(b)
+}
+
+func (tm NullTime) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
+ if tm.IsZero() {
+ return dialect.AppendNull(b), nil
+ }
+ return dialect.AppendTime(b, tm.Time), nil
+}
+
+func (tm *NullTime) Scan(src interface{}) error {
+ if src == nil {
+ tm.Time = time.Time{}
+ return nil
+ }
+
+ switch src := src.(type) {
+ case []byte:
+ newtm, err := internal.ParseTime(internal.String(src))
+ if err != nil {
+ return err
+ }
+
+ tm.Time = newtm
+ return nil
+ case time.Time:
+ tm.Time = src
+ return nil
+ default:
+ return fmt.Errorf("bun: can't scan %#v into NullTime", src)
+ }
+}
diff --git a/vendor/github.com/uptrace/bun/schema/table.go b/vendor/github.com/uptrace/bun/schema/table.go
new file mode 100644
index 000000000..eca18b781
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/table.go
@@ -0,0 +1,948 @@
+package schema
+
+import (
+ "database/sql"
+ "fmt"
+ "reflect"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/jinzhu/inflection"
+
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/internal/tagparser"
+)
+
+const (
+ beforeScanHookFlag internal.Flag = 1 << iota
+ afterScanHookFlag
+)
+
+var (
+ baseModelType = reflect.TypeOf((*BaseModel)(nil)).Elem()
+ tableNameInflector = inflection.Plural
+)
+
+type BaseModel struct{}
+
+// SetTableNameInflector overrides the default func that pluralizes
+// model name to get table name, e.g. my_article becomes my_articles.
+func SetTableNameInflector(fn func(string) string) {
+ tableNameInflector = fn
+}
+
+// Table represents a SQL table created from Go struct.
+type Table struct {
+ dialect Dialect
+
+ Type reflect.Type
+ ZeroValue reflect.Value // reflect.Struct
+ ZeroIface interface{} // struct pointer
+
+ TypeName string
+ ModelName string
+
+ Name string
+ SQLName Safe
+ SQLNameForSelects Safe
+ Alias string
+ SQLAlias Safe
+
+ Fields []*Field // PKs + DataFields
+ PKs []*Field
+ DataFields []*Field
+
+ fieldsMapMu sync.RWMutex
+ FieldMap map[string]*Field
+
+ Relations map[string]*Relation
+ Unique map[string][]*Field
+
+ SoftDeleteField *Field
+ UpdateSoftDeleteField func(fv reflect.Value) error
+
+ allFields []*Field // read only
+ skippedFields []*Field
+
+ flags internal.Flag
+}
+
+func newTable(dialect Dialect, typ reflect.Type) *Table {
+ t := new(Table)
+ t.dialect = dialect
+ t.Type = typ
+ t.ZeroValue = reflect.New(t.Type).Elem()
+ t.ZeroIface = reflect.New(t.Type).Interface()
+ t.TypeName = internal.ToExported(t.Type.Name())
+ t.ModelName = internal.Underscore(t.Type.Name())
+ tableName := tableNameInflector(t.ModelName)
+ t.setName(tableName)
+ t.Alias = t.ModelName
+ t.SQLAlias = t.quoteIdent(t.ModelName)
+
+ hooks := []struct {
+ typ reflect.Type
+ flag internal.Flag
+ }{
+ {beforeScanHookType, beforeScanHookFlag},
+ {afterScanHookType, afterScanHookFlag},
+ }
+
+ typ = reflect.PtrTo(t.Type)
+ for _, hook := range hooks {
+ if typ.Implements(hook.typ) {
+ t.flags = t.flags.Set(hook.flag)
+ }
+ }
+
+ return t
+}
+
+func (t *Table) init1() {
+ t.initFields()
+}
+
+func (t *Table) init2() {
+ t.initInlines()
+ t.initRelations()
+ t.skippedFields = nil
+}
+
+func (t *Table) setName(name string) {
+ t.Name = name
+ t.SQLName = t.quoteIdent(name)
+ t.SQLNameForSelects = t.quoteIdent(name)
+ if t.SQLAlias == "" {
+ t.Alias = name
+ t.SQLAlias = t.quoteIdent(name)
+ }
+}
+
+func (t *Table) String() string {
+ return "model=" + t.TypeName
+}
+
+func (t *Table) CheckPKs() error {
+ if len(t.PKs) == 0 {
+ return fmt.Errorf("bun: %s does not have primary keys", t)
+ }
+ return nil
+}
+
+func (t *Table) addField(field *Field) {
+ t.Fields = append(t.Fields, field)
+ if field.IsPK {
+ t.PKs = append(t.PKs, field)
+ } else {
+ t.DataFields = append(t.DataFields, field)
+ }
+ t.FieldMap[field.Name] = field
+}
+
+func (t *Table) removeField(field *Field) {
+ t.Fields = removeField(t.Fields, field)
+ if field.IsPK {
+ t.PKs = removeField(t.PKs, field)
+ } else {
+ t.DataFields = removeField(t.DataFields, field)
+ }
+ delete(t.FieldMap, field.Name)
+}
+
+func (t *Table) fieldWithLock(name string) *Field {
+ t.fieldsMapMu.RLock()
+ field := t.FieldMap[name]
+ t.fieldsMapMu.RUnlock()
+ return field
+}
+
+func (t *Table) HasField(name string) bool {
+ _, ok := t.FieldMap[name]
+ return ok
+}
+
+func (t *Table) Field(name string) (*Field, error) {
+ field, ok := t.FieldMap[name]
+ if !ok {
+ return nil, fmt.Errorf("bun: %s does not have column=%s", t, name)
+ }
+ return field, nil
+}
+
+func (t *Table) fieldByGoName(name string) *Field {
+ for _, f := range t.allFields {
+ if f.GoName == name {
+ return f
+ }
+ }
+ return nil
+}
+
+func (t *Table) initFields() {
+ t.Fields = make([]*Field, 0, t.Type.NumField())
+ t.FieldMap = make(map[string]*Field, t.Type.NumField())
+ t.addFields(t.Type, nil)
+
+ if len(t.PKs) > 0 {
+ return
+ }
+ for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} {
+ if field, ok := t.FieldMap[name]; ok {
+ field.markAsPK()
+ t.PKs = []*Field{field}
+ t.DataFields = removeField(t.DataFields, field)
+ break
+ }
+ }
+ if len(t.PKs) == 1 {
+ switch t.PKs[0].IndirectType.Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+ reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ t.PKs[0].AutoIncrement = true
+ }
+ }
+}
+
+func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
+ for i := 0; i < typ.NumField(); i++ {
+ f := typ.Field(i)
+
+ // Make a copy so slice is not shared between fields.
+ index := make([]int, len(baseIndex))
+ copy(index, baseIndex)
+
+ if f.Anonymous {
+ if f.Tag.Get("bun") == "-" {
+ continue
+ }
+ if f.Name == "BaseModel" && f.Type == baseModelType {
+ if len(index) == 0 {
+ t.processBaseModelField(f)
+ }
+ continue
+ }
+
+ fieldType := indirectType(f.Type)
+ if fieldType.Kind() != reflect.Struct {
+ continue
+ }
+ t.addFields(fieldType, append(index, f.Index...))
+
+ tag := tagparser.Parse(f.Tag.Get("bun"))
+ if _, inherit := tag.Options["inherit"]; inherit {
+ embeddedTable := t.dialect.Tables().Ref(fieldType)
+ t.TypeName = embeddedTable.TypeName
+ t.SQLName = embeddedTable.SQLName
+ t.SQLNameForSelects = embeddedTable.SQLNameForSelects
+ t.Alias = embeddedTable.Alias
+ t.SQLAlias = embeddedTable.SQLAlias
+ t.ModelName = embeddedTable.ModelName
+ }
+
+ continue
+ }
+
+ field := t.newField(f, index)
+ if field != nil {
+ t.addField(field)
+ }
+ }
+}
+
+func (t *Table) processBaseModelField(f reflect.StructField) {
+ tag := tagparser.Parse(f.Tag.Get("bun"))
+
+ if isKnownTableOption(tag.Name) {
+ internal.Warn.Printf(
+ "%s.%s tag name %q is also an option name; is it a mistake?",
+ t.TypeName, f.Name, tag.Name,
+ )
+ }
+
+ for name := range tag.Options {
+ if !isKnownTableOption(name) {
+ internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name)
+ }
+ }
+
+ if tag.Name != "" {
+ t.setName(tag.Name)
+ }
+
+ if s, ok := tag.Options["select"]; ok {
+ t.SQLNameForSelects = t.quoteTableName(s)
+ }
+
+ if s, ok := tag.Options["alias"]; ok {
+ t.Alias = s
+ t.SQLAlias = t.quoteIdent(s)
+ }
+}
+
+//nolint
+func (t *Table) newField(f reflect.StructField, index []int) *Field {
+ tag := tagparser.Parse(f.Tag.Get("bun"))
+
+ if f.PkgPath != "" {
+ return nil
+ }
+
+ sqlName := internal.Underscore(f.Name)
+
+ if tag.Name != sqlName && isKnownFieldOption(tag.Name) {
+ internal.Warn.Printf(
+ "%s.%s tag name %q is also an option name; is it a mistake?",
+ t.TypeName, f.Name, tag.Name,
+ )
+ }
+
+ for name := range tag.Options {
+ if !isKnownFieldOption(name) {
+ internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name)
+ }
+ }
+
+ skip := tag.Name == "-"
+ if !skip && tag.Name != "" {
+ sqlName = tag.Name
+ }
+
+ index = append(index, f.Index...)
+ if field := t.fieldWithLock(sqlName); field != nil {
+ if indexEqual(field.Index, index) {
+ return field
+ }
+ t.removeField(field)
+ }
+
+ field := &Field{
+ StructField: f,
+
+ Tag: tag,
+ IndirectType: indirectType(f.Type),
+ Index: index,
+
+ Name: sqlName,
+ GoName: f.Name,
+ SQLName: t.quoteIdent(sqlName),
+ }
+
+ field.NotNull = tag.HasOption("notnull")
+ field.NullZero = tag.HasOption("nullzero")
+ field.AutoIncrement = tag.HasOption("autoincrement")
+ if tag.HasOption("pk") {
+ field.markAsPK()
+ }
+ if tag.HasOption("allowzero") {
+ if tag.HasOption("nullzero") {
+ internal.Warn.Printf(
+ "%s.%s: nullzero and allowzero options are mutually exclusive",
+ t.TypeName, f.Name,
+ )
+ }
+ field.NullZero = false
+ }
+
+ if v, ok := tag.Options["unique"]; ok {
+ // Split the value by comma, this will allow multiple names to be specified.
+ // We can use this to create multiple named unique constraints where a single column
+ // might be included in multiple constraints.
+ for _, uniqueName := range strings.Split(v, ",") {
+ if t.Unique == nil {
+ t.Unique = make(map[string][]*Field)
+ }
+ t.Unique[uniqueName] = append(t.Unique[uniqueName], field)
+ }
+ }
+ if s, ok := tag.Options["default"]; ok {
+ field.SQLDefault = s
+ }
+ if s, ok := field.Tag.Options["type"]; ok {
+ field.UserSQLType = s
+ }
+ field.DiscoveredSQLType = DiscoverSQLType(field.IndirectType)
+ field.Append = t.dialect.FieldAppender(field)
+ field.Scan = FieldScanner(t.dialect, field)
+ field.IsZero = FieldZeroChecker(field)
+
+ if v, ok := tag.Options["alt"]; ok {
+ t.FieldMap[v] = field
+ }
+
+ t.allFields = append(t.allFields, field)
+ if skip {
+ t.skippedFields = append(t.skippedFields, field)
+ t.FieldMap[field.Name] = field
+ return nil
+ }
+
+ if _, ok := tag.Options["soft_delete"]; ok {
+ field.NullZero = true
+ t.SoftDeleteField = field
+ t.UpdateSoftDeleteField = softDeleteFieldUpdater(field)
+ }
+
+ return field
+}
+
+func (t *Table) initInlines() {
+ for _, f := range t.skippedFields {
+ if f.IndirectType.Kind() == reflect.Struct {
+ t.inlineFields(f, nil)
+ }
+ }
+}
+
+//---------------------------------------------------------------------------------------
+
+func (t *Table) initRelations() {
+ for i := 0; i < len(t.Fields); {
+ f := t.Fields[i]
+ if t.tryRelation(f) {
+ t.Fields = removeField(t.Fields, f)
+ t.DataFields = removeField(t.DataFields, f)
+ } else {
+ i++
+ }
+
+ if f.IndirectType.Kind() == reflect.Struct {
+ t.inlineFields(f, nil)
+ }
+ }
+}
+
+func (t *Table) tryRelation(field *Field) bool {
+ if rel, ok := field.Tag.Options["rel"]; ok {
+ t.initRelation(field, rel)
+ return true
+ }
+ if field.Tag.HasOption("m2m") {
+ t.addRelation(t.m2mRelation(field))
+ return true
+ }
+
+ if field.Tag.HasOption("join") {
+ internal.Warn.Printf(
+ `%s.%s option "join" requires a relation type`,
+ t.TypeName, field.GoName,
+ )
+ }
+
+ return false
+}
+
+func (t *Table) initRelation(field *Field, rel string) {
+ switch rel {
+ case "belongs-to":
+ t.addRelation(t.belongsToRelation(field))
+ case "has-one":
+ t.addRelation(t.hasOneRelation(field))
+ case "has-many":
+ t.addRelation(t.hasManyRelation(field))
+ default:
+ panic(fmt.Errorf("bun: unknown relation=%s on field=%s", rel, field.GoName))
+ }
+}
+
+func (t *Table) addRelation(rel *Relation) {
+ if t.Relations == nil {
+ t.Relations = make(map[string]*Relation)
+ }
+ _, ok := t.Relations[rel.Field.GoName]
+ if ok {
+ panic(fmt.Errorf("%s already has %s", t, rel))
+ }
+ t.Relations[rel.Field.GoName] = rel
+}
+
+func (t *Table) belongsToRelation(field *Field) *Relation {
+ joinTable := t.dialect.Tables().Ref(field.IndirectType)
+ if err := joinTable.CheckPKs(); err != nil {
+ panic(err)
+ }
+
+ rel := &Relation{
+ Type: HasOneRelation,
+ Field: field,
+ JoinTable: joinTable,
+ }
+
+ if join, ok := field.Tag.Options["join"]; ok {
+ baseColumns, joinColumns := parseRelationJoin(join)
+ for i, baseColumn := range baseColumns {
+ joinColumn := joinColumns[i]
+
+ if f := t.fieldWithLock(baseColumn); f != nil {
+ rel.BaseFields = append(rel.BaseFields, f)
+ } else {
+ panic(fmt.Errorf(
+ "bun: %s belongs-to %s: %s must have column %s",
+ t.TypeName, field.GoName, t.TypeName, baseColumn,
+ ))
+ }
+
+ if f := joinTable.fieldWithLock(joinColumn); f != nil {
+ rel.JoinFields = append(rel.JoinFields, f)
+ } else {
+ panic(fmt.Errorf(
+ "bun: %s belongs-to %s: %s must have column %s",
+ t.TypeName, field.GoName, t.TypeName, baseColumn,
+ ))
+ }
+ }
+ return rel
+ }
+
+ rel.JoinFields = joinTable.PKs
+ fkPrefix := internal.Underscore(field.GoName) + "_"
+ for _, joinPK := range joinTable.PKs {
+ fkName := fkPrefix + joinPK.Name
+ if fk := t.fieldWithLock(fkName); fk != nil {
+ rel.BaseFields = append(rel.BaseFields, fk)
+ continue
+ }
+
+ if fk := t.fieldWithLock(joinPK.Name); fk != nil {
+ rel.BaseFields = append(rel.BaseFields, fk)
+ continue
+ }
+
+ panic(fmt.Errorf(
+ "bun: %s belongs-to %s: %s must have column %s "+
+ "(to override, use join:base_column=join_column tag on %s field)",
+ t.TypeName, field.GoName, t.TypeName, fkName, field.GoName,
+ ))
+ }
+ return rel
+}
+
+func (t *Table) hasOneRelation(field *Field) *Relation {
+ if err := t.CheckPKs(); err != nil {
+ panic(err)
+ }
+
+ joinTable := t.dialect.Tables().Ref(field.IndirectType)
+ rel := &Relation{
+ Type: BelongsToRelation,
+ Field: field,
+ JoinTable: joinTable,
+ }
+
+ if join, ok := field.Tag.Options["join"]; ok {
+ baseColumns, joinColumns := parseRelationJoin(join)
+ for i, baseColumn := range baseColumns {
+ if f := t.fieldWithLock(baseColumn); f != nil {
+ rel.BaseFields = append(rel.BaseFields, f)
+ } else {
+ panic(fmt.Errorf(
+ "bun: %s has-one %s: %s must have column %s",
+ field.GoName, t.TypeName, joinTable.TypeName, baseColumn,
+ ))
+ }
+
+ joinColumn := joinColumns[i]
+ if f := joinTable.fieldWithLock(joinColumn); f != nil {
+ rel.JoinFields = append(rel.JoinFields, f)
+ } else {
+ panic(fmt.Errorf(
+ "bun: %s has-one %s: %s must have column %s",
+ field.GoName, t.TypeName, joinTable.TypeName, baseColumn,
+ ))
+ }
+ }
+ return rel
+ }
+
+ rel.BaseFields = t.PKs
+ fkPrefix := internal.Underscore(t.ModelName) + "_"
+ for _, pk := range t.PKs {
+ fkName := fkPrefix + pk.Name
+ if f := joinTable.fieldWithLock(fkName); f != nil {
+ rel.JoinFields = append(rel.JoinFields, f)
+ continue
+ }
+
+ if f := joinTable.fieldWithLock(pk.Name); f != nil {
+ rel.JoinFields = append(rel.JoinFields, f)
+ continue
+ }
+
+ panic(fmt.Errorf(
+ "bun: %s has-one %s: %s must have column %s "+
+ "(to override, use join:base_column=join_column tag on %s field)",
+ field.GoName, t.TypeName, joinTable.TypeName, fkName, field.GoName,
+ ))
+ }
+ return rel
+}
+
+func (t *Table) hasManyRelation(field *Field) *Relation {
+ if err := t.CheckPKs(); err != nil {
+ panic(err)
+ }
+ if field.IndirectType.Kind() != reflect.Slice {
+ panic(fmt.Errorf(
+ "bun: %s.%s has-many relation requires slice, got %q",
+ t.TypeName, field.GoName, field.IndirectType.Kind(),
+ ))
+ }
+
+ joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem()))
+ polymorphicValue, isPolymorphic := field.Tag.Options["polymorphic"]
+ rel := &Relation{
+ Type: HasManyRelation,
+ Field: field,
+ JoinTable: joinTable,
+ }
+ var polymorphicColumn string
+
+ if join, ok := field.Tag.Options["join"]; ok {
+ baseColumns, joinColumns := parseRelationJoin(join)
+ for i, baseColumn := range baseColumns {
+ joinColumn := joinColumns[i]
+
+ if isPolymorphic && baseColumn == "type" {
+ polymorphicColumn = joinColumn
+ continue
+ }
+
+ if f := t.fieldWithLock(baseColumn); f != nil {
+ rel.BaseFields = append(rel.BaseFields, f)
+ } else {
+ panic(fmt.Errorf(
+ "bun: %s has-one %s: %s must have column %s",
+ t.TypeName, field.GoName, t.TypeName, baseColumn,
+ ))
+ }
+
+ if f := joinTable.fieldWithLock(joinColumn); f != nil {
+ rel.JoinFields = append(rel.JoinFields, f)
+ } else {
+ panic(fmt.Errorf(
+ "bun: %s has-one %s: %s must have column %s",
+ t.TypeName, field.GoName, t.TypeName, baseColumn,
+ ))
+ }
+ }
+ } else {
+ rel.BaseFields = t.PKs
+ fkPrefix := internal.Underscore(t.ModelName) + "_"
+ if isPolymorphic {
+ polymorphicColumn = fkPrefix + "type"
+ }
+
+ for _, pk := range t.PKs {
+ joinColumn := fkPrefix + pk.Name
+ if fk := joinTable.fieldWithLock(joinColumn); fk != nil {
+ rel.JoinFields = append(rel.JoinFields, fk)
+ continue
+ }
+
+ if fk := joinTable.fieldWithLock(pk.Name); fk != nil {
+ rel.JoinFields = append(rel.JoinFields, fk)
+ continue
+ }
+
+ panic(fmt.Errorf(
+ "bun: %s has-many %s: %s must have column %s "+
+ "(to override, use join:base_column=join_column tag on the field %s)",
+ t.TypeName, field.GoName, joinTable.TypeName, joinColumn, field.GoName,
+ ))
+ }
+ }
+
+ if isPolymorphic {
+ rel.PolymorphicField = joinTable.fieldWithLock(polymorphicColumn)
+ if rel.PolymorphicField == nil {
+ panic(fmt.Errorf(
+ "bun: %s has-many %s: %s must have polymorphic column %s",
+ t.TypeName, field.GoName, joinTable.TypeName, polymorphicColumn,
+ ))
+ }
+
+ if polymorphicValue == "" {
+ polymorphicValue = t.ModelName
+ }
+ rel.PolymorphicValue = polymorphicValue
+ }
+
+ return rel
+}
+
+func (t *Table) m2mRelation(field *Field) *Relation {
+ if field.IndirectType.Kind() != reflect.Slice {
+ panic(fmt.Errorf(
+ "bun: %s.%s m2m relation requires slice, got %q",
+ t.TypeName, field.GoName, field.IndirectType.Kind(),
+ ))
+ }
+ joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem()))
+
+ if err := t.CheckPKs(); err != nil {
+ panic(err)
+ }
+ if err := joinTable.CheckPKs(); err != nil {
+ panic(err)
+ }
+
+ m2mTableName, ok := field.Tag.Options["m2m"]
+ if !ok {
+ panic(fmt.Errorf("bun: %s must have m2m tag option", field.GoName))
+ }
+
+ m2mTable := t.dialect.Tables().ByName(m2mTableName)
+ if m2mTable == nil {
+ panic(fmt.Errorf(
+ "bun: can't find m2m %s table (use db.RegisterModel)",
+ m2mTableName,
+ ))
+ }
+
+ rel := &Relation{
+ Type: ManyToManyRelation,
+ Field: field,
+ JoinTable: joinTable,
+ M2MTable: m2mTable,
+ }
+ var leftColumn, rightColumn string
+
+ if join, ok := field.Tag.Options["join"]; ok {
+ left, right := parseRelationJoin(join)
+ leftColumn = left[0]
+ rightColumn = right[0]
+ } else {
+ leftColumn = t.TypeName
+ rightColumn = joinTable.TypeName
+ }
+
+ leftField := m2mTable.fieldByGoName(leftColumn)
+ if leftField == nil {
+ panic(fmt.Errorf(
+ "bun: %s many-to-many %s: %s must have field %s "+
+ "(to override, use tag join:LeftField=RightField on field %s.%s",
+ t.TypeName, field.GoName, m2mTable.TypeName, leftColumn, t.TypeName, field.GoName,
+ ))
+ }
+
+ rightField := m2mTable.fieldByGoName(rightColumn)
+ if rightField == nil {
+ panic(fmt.Errorf(
+ "bun: %s many-to-many %s: %s must have field %s "+
+ "(to override, use tag join:LeftField=RightField on field %s.%s",
+ t.TypeName, field.GoName, m2mTable.TypeName, rightColumn, t.TypeName, field.GoName,
+ ))
+ }
+
+ leftRel := m2mTable.belongsToRelation(leftField)
+ rel.BaseFields = leftRel.JoinFields
+ rel.M2MBaseFields = leftRel.BaseFields
+
+ rightRel := m2mTable.belongsToRelation(rightField)
+ rel.JoinFields = rightRel.JoinFields
+ rel.M2MJoinFields = rightRel.BaseFields
+
+ return rel
+}
+
+func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) {
+ if path == nil {
+ path = map[reflect.Type]struct{}{
+ t.Type: {},
+ }
+ }
+
+ if _, ok := path[field.IndirectType]; ok {
+ return
+ }
+ path[field.IndirectType] = struct{}{}
+
+ joinTable := t.dialect.Tables().Ref(field.IndirectType)
+ for _, f := range joinTable.allFields {
+ f = f.Clone()
+ f.GoName = field.GoName + "_" + f.GoName
+ f.Name = field.Name + "__" + f.Name
+ f.SQLName = t.quoteIdent(f.Name)
+ f.Index = appendNew(field.Index, f.Index...)
+
+ t.fieldsMapMu.Lock()
+ if _, ok := t.FieldMap[f.Name]; !ok {
+ t.FieldMap[f.Name] = f
+ }
+ t.fieldsMapMu.Unlock()
+
+ if f.IndirectType.Kind() != reflect.Struct {
+ continue
+ }
+
+ if _, ok := path[f.IndirectType]; !ok {
+ t.inlineFields(f, path)
+ }
+ }
+}
+
+//------------------------------------------------------------------------------
+
+func (t *Table) Dialect() Dialect { return t.dialect }
+
+//------------------------------------------------------------------------------
+
+func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) }
+func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) }
+
+//------------------------------------------------------------------------------
+
+func (t *Table) AppendNamedArg(
+ fmter Formatter, b []byte, name string, strct reflect.Value,
+) ([]byte, bool) {
+ if field, ok := t.FieldMap[name]; ok {
+ return fmter.appendArg(b, field.Value(strct).Interface()), true
+ }
+ return b, false
+}
+
+func (t *Table) quoteTableName(s string) Safe {
+ // Don't quote if table name contains placeholder (?) or parentheses.
+ if strings.IndexByte(s, '?') >= 0 ||
+ strings.IndexByte(s, '(') >= 0 ||
+ strings.IndexByte(s, ')') >= 0 {
+ return Safe(s)
+ }
+ return t.quoteIdent(s)
+}
+
+func (t *Table) quoteIdent(s string) Safe {
+ return Safe(NewFormatter(t.dialect).AppendIdent(nil, s))
+}
+
+func appendNew(dst []int, src ...int) []int {
+ cp := make([]int, len(dst)+len(src))
+ copy(cp, dst)
+ copy(cp[len(dst):], src)
+ return cp
+}
+
+func isKnownTableOption(name string) bool {
+ switch name {
+ case "alias", "select":
+ return true
+ }
+ return false
+}
+
+func isKnownFieldOption(name string) bool {
+ switch name {
+ case "alias",
+ "type",
+ "array",
+ "hstore",
+ "composite",
+ "json_use_number",
+ "msgpack",
+ "notnull",
+ "nullzero",
+ "allowzero",
+ "default",
+ "unique",
+ "soft_delete",
+
+ "pk",
+ "autoincrement",
+ "rel",
+ "join",
+ "m2m",
+ "polymorphic":
+ return true
+ }
+ return false
+}
+
+func removeField(fields []*Field, field *Field) []*Field {
+ for i, f := range fields {
+ if f == field {
+ return append(fields[:i], fields[i+1:]...)
+ }
+ }
+ return fields
+}
+
+func parseRelationJoin(join string) ([]string, []string) {
+ ss := strings.Split(join, ",")
+ baseColumns := make([]string, len(ss))
+ joinColumns := make([]string, len(ss))
+ for i, s := range ss {
+ ss := strings.Split(strings.TrimSpace(s), "=")
+ if len(ss) != 2 {
+ panic(fmt.Errorf("can't parse relation join: %q", join))
+ }
+ baseColumns[i] = ss[0]
+ joinColumns[i] = ss[1]
+ }
+ return baseColumns, joinColumns
+}
+
+//------------------------------------------------------------------------------
+
+func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error {
+ typ := field.StructField.Type
+
+ switch typ {
+ case timeType:
+ return func(fv reflect.Value) error {
+ ptr := fv.Addr().Interface().(*time.Time)
+ *ptr = time.Now()
+ return nil
+ }
+ case nullTimeType:
+ return func(fv reflect.Value) error {
+ ptr := fv.Addr().Interface().(*sql.NullTime)
+ *ptr = sql.NullTime{Time: time.Now()}
+ return nil
+ }
+ case nullIntType:
+ return func(fv reflect.Value) error {
+ ptr := fv.Addr().Interface().(*sql.NullInt64)
+ *ptr = sql.NullInt64{Int64: time.Now().UnixNano()}
+ return nil
+ }
+ }
+
+ switch field.IndirectType.Kind() {
+ case reflect.Int64:
+ return func(fv reflect.Value) error {
+ ptr := fv.Addr().Interface().(*int64)
+ *ptr = time.Now().UnixNano()
+ return nil
+ }
+ case reflect.Ptr:
+ typ = typ.Elem()
+ default:
+ return softDeleteFieldUpdaterFallback(field)
+ }
+
+ switch typ { //nolint:gocritic
+ case timeType:
+ return func(fv reflect.Value) error {
+ now := time.Now()
+ fv.Set(reflect.ValueOf(&now))
+ return nil
+ }
+ }
+
+ switch typ.Kind() { //nolint:gocritic
+ case reflect.Int64:
+ return func(fv reflect.Value) error {
+ utime := time.Now().UnixNano()
+ fv.Set(reflect.ValueOf(&utime))
+ return nil
+ }
+ }
+
+ return softDeleteFieldUpdaterFallback(field)
+}
+
+func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value) error {
+ return func(fv reflect.Value) error {
+ return field.ScanWithCheck(fv, time.Now())
+ }
+}
diff --git a/vendor/github.com/uptrace/bun/schema/tables.go b/vendor/github.com/uptrace/bun/schema/tables.go
new file mode 100644
index 000000000..d82d08f59
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/tables.go
@@ -0,0 +1,148 @@
+package schema
+
+import (
+ "fmt"
+ "reflect"
+ "sync"
+)
+
+type tableInProgress struct {
+ table *Table
+
+ init1Once sync.Once
+ init2Once sync.Once
+}
+
+func newTableInProgress(table *Table) *tableInProgress {
+ return &tableInProgress{
+ table: table,
+ }
+}
+
+func (inp *tableInProgress) init1() bool {
+ var inited bool
+ inp.init1Once.Do(func() {
+ inp.table.init1()
+ inited = true
+ })
+ return inited
+}
+
+func (inp *tableInProgress) init2() bool {
+ var inited bool
+ inp.init2Once.Do(func() {
+ inp.table.init2()
+ inited = true
+ })
+ return inited
+}
+
+type Tables struct {
+ dialect Dialect
+ tables sync.Map
+
+ mu sync.RWMutex
+ inProgress map[reflect.Type]*tableInProgress
+}
+
+func NewTables(dialect Dialect) *Tables {
+ return &Tables{
+ dialect: dialect,
+ inProgress: make(map[reflect.Type]*tableInProgress),
+ }
+}
+
+func (t *Tables) Register(models ...interface{}) {
+ for _, model := range models {
+ _ = t.Get(reflect.TypeOf(model).Elem())
+ }
+}
+
+func (t *Tables) Get(typ reflect.Type) *Table {
+ return t.table(typ, false)
+}
+
+func (t *Tables) Ref(typ reflect.Type) *Table {
+ return t.table(typ, true)
+}
+
+func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table {
+ if typ.Kind() != reflect.Struct {
+ panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct))
+ }
+
+ if v, ok := t.tables.Load(typ); ok {
+ return v.(*Table)
+ }
+
+ t.mu.Lock()
+
+ if v, ok := t.tables.Load(typ); ok {
+ t.mu.Unlock()
+ return v.(*Table)
+ }
+
+ var table *Table
+
+ inProgress := t.inProgress[typ]
+ if inProgress == nil {
+ table = newTable(t.dialect, typ)
+ inProgress = newTableInProgress(table)
+ t.inProgress[typ] = inProgress
+ } else {
+ table = inProgress.table
+ }
+
+ t.mu.Unlock()
+
+ inProgress.init1()
+ if allowInProgress {
+ return table
+ }
+
+ if inProgress.init2() {
+ t.mu.Lock()
+ delete(t.inProgress, typ)
+ t.tables.Store(typ, table)
+ t.mu.Unlock()
+ }
+
+ t.dialect.OnTable(table)
+
+ for _, field := range table.FieldMap {
+ if field.UserSQLType == "" {
+ field.UserSQLType = field.DiscoveredSQLType
+ }
+ if field.CreateTableSQLType == "" {
+ field.CreateTableSQLType = field.UserSQLType
+ }
+ }
+
+ return table
+}
+
+func (t *Tables) ByModel(name string) *Table {
+ var found *Table
+ t.tables.Range(func(key, value interface{}) bool {
+ t := value.(*Table)
+ if t.TypeName == name {
+ found = t
+ return false
+ }
+ return true
+ })
+ return found
+}
+
+func (t *Tables) ByName(name string) *Table {
+ var found *Table
+ t.tables.Range(func(key, value interface{}) bool {
+ t := value.(*Table)
+ if t.Name == name {
+ found = t
+ return false
+ }
+ return true
+ })
+ return found
+}
diff --git a/vendor/github.com/uptrace/bun/schema/util.go b/vendor/github.com/uptrace/bun/schema/util.go
new file mode 100644
index 000000000..6d474e4cc
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/util.go
@@ -0,0 +1,53 @@
+package schema
+
+import "reflect"
+
+func indirectType(t reflect.Type) reflect.Type {
+ if t.Kind() == reflect.Ptr {
+ t = t.Elem()
+ }
+ return t
+}
+
+func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, ok bool) {
+ if len(index) == 1 {
+ return v.Field(index[0]), true
+ }
+
+ for i, idx := range index {
+ if i > 0 {
+ if v.Kind() == reflect.Ptr {
+ if v.IsNil() {
+ return v, false
+ }
+ v = v.Elem()
+ }
+ }
+ v = v.Field(idx)
+ }
+ return v, true
+}
+
+func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value {
+ if len(index) == 1 {
+ return v.Field(index[0])
+ }
+
+ for i, idx := range index {
+ if i > 0 {
+ v = indirectNil(v)
+ }
+ v = v.Field(idx)
+ }
+ return v
+}
+
+func indirectNil(v reflect.Value) reflect.Value {
+ if v.Kind() == reflect.Ptr {
+ if v.IsNil() {
+ v.Set(reflect.New(v.Type().Elem()))
+ }
+ v = v.Elem()
+ }
+ return v
+}
diff --git a/vendor/github.com/uptrace/bun/schema/zerochecker.go b/vendor/github.com/uptrace/bun/schema/zerochecker.go
new file mode 100644
index 000000000..95efeee6b
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/schema/zerochecker.go
@@ -0,0 +1,126 @@
+package schema
+
+import (
+ "database/sql/driver"
+ "reflect"
+)
+
+var isZeroerType = reflect.TypeOf((*isZeroer)(nil)).Elem()
+
+type isZeroer interface {
+ IsZero() bool
+}
+
+type IsZeroerFunc func(reflect.Value) bool
+
+func FieldZeroChecker(field *Field) IsZeroerFunc {
+ return zeroChecker(field.IndirectType)
+}
+
+func zeroChecker(typ reflect.Type) IsZeroerFunc {
+ if typ.Implements(isZeroerType) {
+ return isZeroInterface
+ }
+
+ kind := typ.Kind()
+
+ if kind != reflect.Ptr {
+ ptr := reflect.PtrTo(typ)
+ if ptr.Implements(isZeroerType) {
+ return addrChecker(isZeroInterface)
+ }
+ }
+
+ switch kind {
+ case reflect.Array:
+ if typ.Elem().Kind() == reflect.Uint8 {
+ return isZeroBytes
+ }
+ return isZeroLen
+ case reflect.String:
+ return isZeroLen
+ case reflect.Bool:
+ return isZeroBool
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return isZeroInt
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ return isZeroUint
+ case reflect.Float32, reflect.Float64:
+ return isZeroFloat
+ case reflect.Interface, reflect.Ptr, reflect.Slice, reflect.Map:
+ return isNil
+ }
+
+ if typ.Implements(driverValuerType) {
+ return isZeroDriverValue
+ }
+
+ return notZero
+}
+
+func addrChecker(fn IsZeroerFunc) IsZeroerFunc {
+ return func(v reflect.Value) bool {
+ if !v.CanAddr() {
+ return false
+ }
+ return fn(v.Addr())
+ }
+}
+
+func isZeroInterface(v reflect.Value) bool {
+ if v.Kind() == reflect.Ptr && v.IsNil() {
+ return true
+ }
+ return v.Interface().(isZeroer).IsZero()
+}
+
+func isZeroDriverValue(v reflect.Value) bool {
+ if v.Kind() == reflect.Ptr {
+ return v.IsNil()
+ }
+
+ valuer := v.Interface().(driver.Valuer)
+ value, err := valuer.Value()
+ if err != nil {
+ return false
+ }
+ return value == nil
+}
+
+func isZeroLen(v reflect.Value) bool {
+ return v.Len() == 0
+}
+
+func isNil(v reflect.Value) bool {
+ return v.IsNil()
+}
+
+func isZeroBool(v reflect.Value) bool {
+ return !v.Bool()
+}
+
+func isZeroInt(v reflect.Value) bool {
+ return v.Int() == 0
+}
+
+func isZeroUint(v reflect.Value) bool {
+ return v.Uint() == 0
+}
+
+func isZeroFloat(v reflect.Value) bool {
+ return v.Float() == 0
+}
+
+func isZeroBytes(v reflect.Value) bool {
+ b := v.Slice(0, v.Len()).Bytes()
+ for _, c := range b {
+ if c != 0 {
+ return false
+ }
+ }
+ return true
+}
+
+func notZero(v reflect.Value) bool {
+ return false
+}