diff options
| author | 2022-04-24 12:26:22 +0200 | |
|---|---|---|
| committer | 2022-04-24 12:26:22 +0200 | |
| commit | 88979b35d462516e1765524d70a41c0d26eec911 (patch) | |
| tree | fd37cb19317217e226ee7717824f24031f53b031 /vendor/github.com/uptrace/bun/schema | |
| parent | Revert "[chore] Tidy up federating db locks a tiny bit (#472)" (#479) (diff) | |
| download | gotosocial-88979b35d462516e1765524d70a41c0d26eec911.tar.xz | |
[chore] Update bun and sqlite dependencies (#478)
* update bun + sqlite versions
* step bun to v1.1.3
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema')
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/append.go | 52 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/append_value.go | 38 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/field.go | 36 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/reflect.go | 4 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/scan.go | 69 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/sqlfmt.go | 2 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/sqltype.go | 5 | ||||
| -rw-r--r-- | vendor/github.com/uptrace/bun/schema/table.go | 81 |
8 files changed, 201 insertions, 86 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/append.go b/vendor/github.com/uptrace/bun/schema/append.go index d19f40d50..04538c036 100644 --- a/vendor/github.com/uptrace/bun/schema/append.go +++ b/vendor/github.com/uptrace/bun/schema/append.go @@ -1,6 +1,7 @@ package schema import ( + "fmt" "reflect" "strconv" "time" @@ -47,3 +48,54 @@ func Append(fmter Formatter, b []byte, v interface{}) []byte { return appender(fmter, b, vv) } } + +//------------------------------------------------------------------------------ + +func In(slice interface{}) QueryAppender { + v := reflect.ValueOf(slice) + if v.Kind() != reflect.Slice { + return &inValues{ + err: fmt.Errorf("bun: In(non-slice %T)", slice), + } + } + return &inValues{ + slice: v, + } +} + +type inValues struct { + slice reflect.Value + err error +} + +var _ QueryAppender = (*inValues)(nil) + +func (in *inValues) AppendQuery(fmter Formatter, b []byte) (_ []byte, err error) { + if in.err != nil { + return nil, in.err + } + return appendIn(fmter, b, in.slice), nil +} + +func appendIn(fmter Formatter, b []byte, slice reflect.Value) []byte { + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, ", "...) + } + + elem := slice.Index(i) + if elem.Kind() == reflect.Interface { + elem = elem.Elem() + } + + if elem.Kind() == reflect.Slice && elem.Type() != bytesType { + b = append(b, '(') + b = appendIn(fmter, b, elem) + b = append(b, ')') + } else { + b = fmter.AppendValue(b, elem) + } + } + return b +} diff --git a/vendor/github.com/uptrace/bun/schema/append_value.go b/vendor/github.com/uptrace/bun/schema/append_value.go index 5697e35e3..7e9c451db 100644 --- a/vendor/github.com/uptrace/bun/schema/append_value.go +++ b/vendor/github.com/uptrace/bun/schema/append_value.go @@ -58,12 +58,24 @@ func FieldAppender(dialect Dialect, field *Field) AppenderFunc { return appendMsgpack } + fieldType := field.StructField.Type + switch strings.ToUpper(field.UserSQLType) { case sqltype.JSON, sqltype.JSONB: + if fieldType.Implements(driverValuerType) { + return appendDriverValue + } + + if fieldType.Kind() != reflect.Ptr { + if reflect.PtrTo(fieldType).Implements(driverValuerType) { + return addrAppender(appendDriverValue) + } + } + return AppendJSONValue } - return Appender(dialect, field.StructField.Type) + return Appender(dialect, fieldType) } func Appender(dialect Dialect, typ reflect.Type) AppenderFunc { @@ -85,6 +97,8 @@ func appender(dialect Dialect, typ reflect.Type) AppenderFunc { return appendBytesValue case timeType: return appendTimeValue + case timePtrType: + return PtrAppender(appendTimeValue) case ipType: return appendIPValue case ipNetType: @@ -93,15 +107,21 @@ func appender(dialect Dialect, typ reflect.Type) AppenderFunc { return appendJSONRawMessageValue } + kind := typ.Kind() + if typ.Implements(queryAppenderType) { + if kind == reflect.Ptr { + return nilAwareAppender(appendQueryAppenderValue) + } return appendQueryAppenderValue } if typ.Implements(driverValuerType) { + if kind == reflect.Ptr { + return nilAwareAppender(appendDriverValue) + } return appendDriverValue } - kind := typ.Kind() - if kind != reflect.Ptr { ptr := reflect.PtrTo(typ) if ptr.Implements(queryAppenderType) { @@ -116,6 +136,9 @@ func appender(dialect Dialect, typ reflect.Type) AppenderFunc { case reflect.Interface: return ifaceAppenderFunc case reflect.Ptr: + if typ.Implements(jsonMarshalerType) { + return nilAwareAppender(AppendJSONValue) + } if fn := Appender(dialect, typ.Elem()); fn != nil { return PtrAppender(fn) } @@ -141,6 +164,15 @@ func ifaceAppenderFunc(fmter Formatter, b []byte, v reflect.Value) []byte { return appender(fmter, b, elem) } +func nilAwareAppender(fn AppenderFunc) AppenderFunc { + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + if v.IsNil() { + return dialect.AppendNull(b) + } + return fn(fmter, b, v) + } +} + func PtrAppender(fn AppenderFunc) AppenderFunc { return func(fmter Formatter, b []byte, v reflect.Value) []byte { if v.IsNil() { diff --git a/vendor/github.com/uptrace/bun/schema/field.go b/vendor/github.com/uptrace/bun/schema/field.go index 59990b924..ac6359da4 100644 --- a/vendor/github.com/uptrace/bun/schema/field.go +++ b/vendor/github.com/uptrace/bun/schema/field.go @@ -10,6 +10,7 @@ import ( type Field struct { StructField reflect.StructField + IsPtr bool Tag tagparser.Tag IndirectType reflect.Type @@ -51,15 +52,36 @@ func (f *Field) Value(strct reflect.Value) reflect.Value { return fieldByIndexAlloc(strct, f.Index) } +func (f *Field) HasNilValue(v reflect.Value) bool { + if len(f.Index) == 1 { + return v.Field(f.Index[0]).IsNil() + } + + for _, index := range f.Index { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return true + } + v = v.Elem() + } + v = v.Field(index) + } + return v.IsNil() +} + func (f *Field) HasZeroValue(v reflect.Value) bool { - for _, idx := range f.Index { + if len(f.Index) == 1 { + return f.IsZero(v.Field(f.Index[0])) + } + + for _, index := range f.Index { if v.Kind() == reflect.Ptr { if v.IsNil() { return true } v = v.Elem() } - v = v.Field(idx) + v = v.Field(index) } return f.IsZero(v) } @@ -70,7 +92,7 @@ func (f *Field) AppendValue(fmter Formatter, b []byte, strct reflect.Value) []by return dialect.AppendNull(b) } - if f.NullZero && f.IsZero(fv) { + if (f.IsPtr && fv.IsNil()) || (f.NullZero && f.IsZero(fv)) { return dialect.AppendNull(b) } if f.Append == nil { @@ -98,14 +120,6 @@ func (f *Field) ScanValue(strct reflect.Value, src interface{}) error { return f.ScanWithCheck(fv, src) } -func (f *Field) markAsPK() { - f.IsPK = true - f.NotNull = true - if !f.Tag.HasOption("allowzero") { - f.NullZero = true - } -} - func indexEqual(ind1, ind2 []int) bool { if len(ind1) != len(ind2) { return false diff --git a/vendor/github.com/uptrace/bun/schema/reflect.go b/vendor/github.com/uptrace/bun/schema/reflect.go index 5b20b1964..f13826a6c 100644 --- a/vendor/github.com/uptrace/bun/schema/reflect.go +++ b/vendor/github.com/uptrace/bun/schema/reflect.go @@ -10,13 +10,15 @@ import ( var ( bytesType = reflect.TypeOf((*[]byte)(nil)).Elem() - timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + timePtrType = reflect.TypeOf((*time.Time)(nil)) + timeType = timePtrType.Elem() ipType = reflect.TypeOf((*net.IP)(nil)).Elem() ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem() + jsonMarshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() ) func indirectType(t reflect.Type) reflect.Type { diff --git a/vendor/github.com/uptrace/bun/schema/scan.go b/vendor/github.com/uptrace/bun/schema/scan.go index 30abcfc35..069b14e44 100644 --- a/vendor/github.com/uptrace/bun/schema/scan.go +++ b/vendor/github.com/uptrace/bun/schema/scan.go @@ -94,6 +94,8 @@ func scanner(typ reflect.Type) ScannerFunc { } switch typ { + case bytesType: + return scanBytes case timeType: return scanTime case ipType: @@ -134,12 +136,22 @@ func scanBool(dest reflect.Value, src interface{}) error { dest.SetBool(src != 0) return nil case []byte: - if len(src) == 1 { - dest.SetBool(src[0] != '0') - return nil + f, err := strconv.ParseBool(internal.String(src)) + if err != nil { + return err + } + dest.SetBool(f) + return nil + case string: + f, err := strconv.ParseBool(src) + if err != nil { + return err } + dest.SetBool(f) + return nil + default: + return scanError(dest.Type(), src) } - return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) } func scanInt64(dest reflect.Value, src interface{}) error { @@ -167,8 +179,9 @@ func scanInt64(dest reflect.Value, src interface{}) error { } dest.SetInt(n) return nil + default: + return scanError(dest.Type(), src) } - return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) } func scanUint64(dest reflect.Value, src interface{}) error { @@ -189,8 +202,16 @@ func scanUint64(dest reflect.Value, src interface{}) error { } dest.SetUint(n) return nil + case string: + n, err := strconv.ParseUint(src, 10, 64) + if err != nil { + return err + } + dest.SetUint(n) + return nil + default: + return scanError(dest.Type(), src) } - return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) } func scanFloat64(dest reflect.Value, src interface{}) error { @@ -208,8 +229,16 @@ func scanFloat64(dest reflect.Value, src interface{}) error { } dest.SetFloat(f) return nil + case string: + f, err := strconv.ParseFloat(src, 64) + if err != nil { + return err + } + dest.SetFloat(f) + return nil + default: + return scanError(dest.Type(), src) } - return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) } func scanString(dest reflect.Value, src interface{}) error { @@ -226,8 +255,18 @@ func scanString(dest reflect.Value, src interface{}) error { case time.Time: dest.SetString(src.Format(time.RFC3339Nano)) return nil + case int64: + dest.SetString(strconv.FormatInt(src, 10)) + return nil + case uint64: + dest.SetString(strconv.FormatUint(src, 10)) + return nil + case float64: + dest.SetString(strconv.FormatFloat(src, 'G', -1, 64)) + return nil + default: + return scanError(dest.Type(), src) } - return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) } func scanBytes(dest reflect.Value, src interface{}) error { @@ -244,8 +283,9 @@ func scanBytes(dest reflect.Value, src interface{}) error { dest.SetBytes(clone) return nil + default: + return scanError(dest.Type(), src) } - return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) } func scanTime(dest reflect.Value, src interface{}) error { @@ -274,8 +314,9 @@ func scanTime(dest reflect.Value, src interface{}) error { destTime := dest.Addr().Interface().(*time.Time) *destTime = srcTime return nil + default: + return scanError(dest.Type(), src) } - return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) } func scanScanner(dest reflect.Value, src interface{}) error { @@ -438,7 +479,7 @@ func scanJSONIntoInterface(dest reflect.Value, src interface{}) error { if fn := Scanner(dest.Type()); fn != nil { return fn(dest, src) } - return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) + return scanError(dest.Type(), src) } func scanInterface(dest reflect.Value, src interface{}) error { @@ -454,7 +495,7 @@ func scanInterface(dest reflect.Value, src interface{}) error { if fn := Scanner(dest.Type()); fn != nil { return fn(dest, src) } - return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) + return scanError(dest.Type(), src) } func nilable(kind reflect.Kind) bool { @@ -464,3 +505,7 @@ func nilable(kind reflect.Kind) bool { } return false } + +func scanError(dest reflect.Type, src interface{}) error { + return fmt.Errorf("bun: can't scan %#v (%T) into %s", src, src, dest.String()) +} diff --git a/vendor/github.com/uptrace/bun/schema/sqlfmt.go b/vendor/github.com/uptrace/bun/schema/sqlfmt.go index 93a801c86..a4ed24af6 100644 --- a/vendor/github.com/uptrace/bun/schema/sqlfmt.go +++ b/vendor/github.com/uptrace/bun/schema/sqlfmt.go @@ -49,7 +49,7 @@ func SafeQuery(query string, args []interface{}) QueryWithArgs { if args == nil { args = make([]interface{}, 0) } else if len(query) > 0 && strings.IndexByte(query, '?') == -1 { - internal.Warn.Printf("query %q has args %v, but no placeholders", query, args) + internal.Warn.Printf("query %q has %v args, but no placeholders", query, args) } return QueryWithArgs{ Query: query, diff --git a/vendor/github.com/uptrace/bun/schema/sqltype.go b/vendor/github.com/uptrace/bun/schema/sqltype.go index 90551d6aa..233ba641b 100644 --- a/vendor/github.com/uptrace/bun/schema/sqltype.go +++ b/vendor/github.com/uptrace/bun/schema/sqltype.go @@ -4,7 +4,6 @@ import ( "bytes" "database/sql" "encoding/json" - "fmt" "reflect" "time" @@ -60,6 +59,8 @@ func DiscoverSQLType(typ reflect.Type) string { return sqltype.BigInt case nullStringType: return sqltype.VarChar + case jsonRawMessageType: + return sqltype.JSON } switch typ.Kind() { @@ -135,6 +136,6 @@ func (tm *NullTime) Scan(src interface{}) error { tm.Time = newtm return nil default: - return fmt.Errorf("bun: can't scan %#v into NullTime", src) + return scanError(bunNullTimeType, src) } } diff --git a/vendor/github.com/uptrace/bun/schema/table.go b/vendor/github.com/uptrace/bun/schema/table.go index 88b8d8e25..1a8393fc7 100644 --- a/vendor/github.com/uptrace/bun/schema/table.go +++ b/vendor/github.com/uptrace/bun/schema/table.go @@ -204,30 +204,6 @@ func (t *Table) initFields() { t.Fields = make([]*Field, 0, t.Type.NumField()) t.FieldMap = make(map[string]*Field, t.Type.NumField()) t.addFields(t.Type, "", nil) - - if len(t.PKs) == 0 { - for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} { - if field, ok := t.FieldMap[name]; ok { - field.markAsPK() - t.PKs = []*Field{field} - t.DataFields = removeField(t.DataFields, field) - break - } - } - } - - if len(t.PKs) == 1 { - pk := t.PKs[0] - if pk.SQLDefault != "" { - return - } - - switch pk.IndirectType.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - pk.AutoIncrement = true - } - } } func (t *Table) addFields(typ reflect.Type, prefix string, index []int) { @@ -250,26 +226,27 @@ func (t *Table) addFields(typ reflect.Type, prefix string, index []int) { continue } + // If field is an embedded struct, add each field of the embedded struct. fieldType := indirectType(f.Type) - if fieldType.Kind() != reflect.Struct { + if fieldType.Kind() == reflect.Struct { + t.addFields(fieldType, "", withIndex(index, f.Index)) + + tag := tagparser.Parse(f.Tag.Get("bun")) + if tag.HasOption("inherit") || tag.HasOption("extend") { + embeddedTable := t.dialect.Tables().Ref(fieldType) + t.TypeName = embeddedTable.TypeName + t.SQLName = embeddedTable.SQLName + t.SQLNameForSelects = embeddedTable.SQLNameForSelects + t.Alias = embeddedTable.Alias + t.SQLAlias = embeddedTable.SQLAlias + t.ModelName = embeddedTable.ModelName + } continue } - t.addFields(fieldType, "", withIndex(index, f.Index)) - - tag := tagparser.Parse(f.Tag.Get("bun")) - if _, inherit := tag.Options["inherit"]; inherit { - embeddedTable := t.dialect.Tables().Ref(fieldType) - t.TypeName = embeddedTable.TypeName - t.SQLName = embeddedTable.SQLName - t.SQLNameForSelects = embeddedTable.SQLNameForSelects - t.Alias = embeddedTable.Alias - t.SQLAlias = embeddedTable.SQLAlias - t.ModelName = embeddedTable.ModelName - } - - continue } + // If field is not a struct, add it. + // This will also add any embedded non-struct type as a field. if field := t.newField(f, prefix, index); field != nil { t.addField(field) } @@ -355,6 +332,7 @@ func (t *Table) newField(f reflect.StructField, prefix string, index []int) *Fie field := &Field{ StructField: f, + IsPtr: f.Type.Kind() == reflect.Ptr, Tag: tag, IndirectType: indirectType(f.Type), @@ -367,9 +345,13 @@ func (t *Table) newField(f reflect.StructField, prefix string, index []int) *Fie field.NotNull = tag.HasOption("notnull") field.NullZero = tag.HasOption("nullzero") - field.AutoIncrement = tag.HasOption("autoincrement") if tag.HasOption("pk") { - field.markAsPK() + field.IsPK = true + field.NotNull = true + } + if tag.HasOption("autoincrement") { + field.AutoIncrement = true + field.NullZero = true } if v, ok := tag.Options["unique"]; ok { @@ -415,22 +397,10 @@ func (t *Table) newField(f reflect.StructField, prefix string, index []int) *Fie } if _, ok := tag.Options["soft_delete"]; ok { - field.NullZero = true t.SoftDeleteField = field t.UpdateSoftDeleteField = softDeleteFieldUpdater(field) } - // Check this in the end to undo NullZero. - if tag.HasOption("allowzero") { - if tag.HasOption("nullzero") { - internal.Warn.Printf( - "%s.%s: nullzero and allowzero options are mutually exclusive", - t.TypeName, f.Name, - ) - } - field.NullZero = false - } - return field } @@ -651,7 +621,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation { rel.BaseFields = append(rel.BaseFields, f) } else { panic(fmt.Errorf( - "bun: %s has-one %s: %s must have column %s", + "bun: %s has-many %s: %s must have column %s", t.TypeName, field.GoName, t.TypeName, baseColumn, )) } @@ -660,7 +630,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation { rel.JoinFields = append(rel.JoinFields, f) } else { panic(fmt.Errorf( - "bun: %s has-one %s: %s must have column %s", + "bun: %s has-many %s: %s must have column %s", t.TypeName, field.GoName, t.TypeName, baseColumn, )) } @@ -879,7 +849,6 @@ func isKnownFieldOption(name string) bool { "msgpack", "notnull", "nullzero", - "allowzero", "default", "unique", "soft_delete", |
