summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/schema
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2022-04-24 12:26:22 +0200
committerLibravatar GitHub <noreply@github.com>2022-04-24 12:26:22 +0200
commit88979b35d462516e1765524d70a41c0d26eec911 (patch)
treefd37cb19317217e226ee7717824f24031f53b031 /vendor/github.com/uptrace/bun/schema
parentRevert "[chore] Tidy up federating db locks a tiny bit (#472)" (#479) (diff)
downloadgotosocial-88979b35d462516e1765524d70a41c0d26eec911.tar.xz
[chore] Update bun and sqlite dependencies (#478)
* update bun + sqlite versions * step bun to v1.1.3
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema')
-rw-r--r--vendor/github.com/uptrace/bun/schema/append.go52
-rw-r--r--vendor/github.com/uptrace/bun/schema/append_value.go38
-rw-r--r--vendor/github.com/uptrace/bun/schema/field.go36
-rw-r--r--vendor/github.com/uptrace/bun/schema/reflect.go4
-rw-r--r--vendor/github.com/uptrace/bun/schema/scan.go69
-rw-r--r--vendor/github.com/uptrace/bun/schema/sqlfmt.go2
-rw-r--r--vendor/github.com/uptrace/bun/schema/sqltype.go5
-rw-r--r--vendor/github.com/uptrace/bun/schema/table.go81
8 files changed, 201 insertions, 86 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/append.go b/vendor/github.com/uptrace/bun/schema/append.go
index d19f40d50..04538c036 100644
--- a/vendor/github.com/uptrace/bun/schema/append.go
+++ b/vendor/github.com/uptrace/bun/schema/append.go
@@ -1,6 +1,7 @@
package schema
import (
+ "fmt"
"reflect"
"strconv"
"time"
@@ -47,3 +48,54 @@ func Append(fmter Formatter, b []byte, v interface{}) []byte {
return appender(fmter, b, vv)
}
}
+
+//------------------------------------------------------------------------------
+
+func In(slice interface{}) QueryAppender {
+ v := reflect.ValueOf(slice)
+ if v.Kind() != reflect.Slice {
+ return &inValues{
+ err: fmt.Errorf("bun: In(non-slice %T)", slice),
+ }
+ }
+ return &inValues{
+ slice: v,
+ }
+}
+
+type inValues struct {
+ slice reflect.Value
+ err error
+}
+
+var _ QueryAppender = (*inValues)(nil)
+
+func (in *inValues) AppendQuery(fmter Formatter, b []byte) (_ []byte, err error) {
+ if in.err != nil {
+ return nil, in.err
+ }
+ return appendIn(fmter, b, in.slice), nil
+}
+
+func appendIn(fmter Formatter, b []byte, slice reflect.Value) []byte {
+ sliceLen := slice.Len()
+ for i := 0; i < sliceLen; i++ {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+
+ elem := slice.Index(i)
+ if elem.Kind() == reflect.Interface {
+ elem = elem.Elem()
+ }
+
+ if elem.Kind() == reflect.Slice && elem.Type() != bytesType {
+ b = append(b, '(')
+ b = appendIn(fmter, b, elem)
+ b = append(b, ')')
+ } else {
+ b = fmter.AppendValue(b, elem)
+ }
+ }
+ return b
+}
diff --git a/vendor/github.com/uptrace/bun/schema/append_value.go b/vendor/github.com/uptrace/bun/schema/append_value.go
index 5697e35e3..7e9c451db 100644
--- a/vendor/github.com/uptrace/bun/schema/append_value.go
+++ b/vendor/github.com/uptrace/bun/schema/append_value.go
@@ -58,12 +58,24 @@ func FieldAppender(dialect Dialect, field *Field) AppenderFunc {
return appendMsgpack
}
+ fieldType := field.StructField.Type
+
switch strings.ToUpper(field.UserSQLType) {
case sqltype.JSON, sqltype.JSONB:
+ if fieldType.Implements(driverValuerType) {
+ return appendDriverValue
+ }
+
+ if fieldType.Kind() != reflect.Ptr {
+ if reflect.PtrTo(fieldType).Implements(driverValuerType) {
+ return addrAppender(appendDriverValue)
+ }
+ }
+
return AppendJSONValue
}
- return Appender(dialect, field.StructField.Type)
+ return Appender(dialect, fieldType)
}
func Appender(dialect Dialect, typ reflect.Type) AppenderFunc {
@@ -85,6 +97,8 @@ func appender(dialect Dialect, typ reflect.Type) AppenderFunc {
return appendBytesValue
case timeType:
return appendTimeValue
+ case timePtrType:
+ return PtrAppender(appendTimeValue)
case ipType:
return appendIPValue
case ipNetType:
@@ -93,15 +107,21 @@ func appender(dialect Dialect, typ reflect.Type) AppenderFunc {
return appendJSONRawMessageValue
}
+ kind := typ.Kind()
+
if typ.Implements(queryAppenderType) {
+ if kind == reflect.Ptr {
+ return nilAwareAppender(appendQueryAppenderValue)
+ }
return appendQueryAppenderValue
}
if typ.Implements(driverValuerType) {
+ if kind == reflect.Ptr {
+ return nilAwareAppender(appendDriverValue)
+ }
return appendDriverValue
}
- kind := typ.Kind()
-
if kind != reflect.Ptr {
ptr := reflect.PtrTo(typ)
if ptr.Implements(queryAppenderType) {
@@ -116,6 +136,9 @@ func appender(dialect Dialect, typ reflect.Type) AppenderFunc {
case reflect.Interface:
return ifaceAppenderFunc
case reflect.Ptr:
+ if typ.Implements(jsonMarshalerType) {
+ return nilAwareAppender(AppendJSONValue)
+ }
if fn := Appender(dialect, typ.Elem()); fn != nil {
return PtrAppender(fn)
}
@@ -141,6 +164,15 @@ func ifaceAppenderFunc(fmter Formatter, b []byte, v reflect.Value) []byte {
return appender(fmter, b, elem)
}
+func nilAwareAppender(fn AppenderFunc) AppenderFunc {
+ return func(fmter Formatter, b []byte, v reflect.Value) []byte {
+ if v.IsNil() {
+ return dialect.AppendNull(b)
+ }
+ return fn(fmter, b, v)
+ }
+}
+
func PtrAppender(fn AppenderFunc) AppenderFunc {
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
if v.IsNil() {
diff --git a/vendor/github.com/uptrace/bun/schema/field.go b/vendor/github.com/uptrace/bun/schema/field.go
index 59990b924..ac6359da4 100644
--- a/vendor/github.com/uptrace/bun/schema/field.go
+++ b/vendor/github.com/uptrace/bun/schema/field.go
@@ -10,6 +10,7 @@ import (
type Field struct {
StructField reflect.StructField
+ IsPtr bool
Tag tagparser.Tag
IndirectType reflect.Type
@@ -51,15 +52,36 @@ func (f *Field) Value(strct reflect.Value) reflect.Value {
return fieldByIndexAlloc(strct, f.Index)
}
+func (f *Field) HasNilValue(v reflect.Value) bool {
+ if len(f.Index) == 1 {
+ return v.Field(f.Index[0]).IsNil()
+ }
+
+ for _, index := range f.Index {
+ if v.Kind() == reflect.Ptr {
+ if v.IsNil() {
+ return true
+ }
+ v = v.Elem()
+ }
+ v = v.Field(index)
+ }
+ return v.IsNil()
+}
+
func (f *Field) HasZeroValue(v reflect.Value) bool {
- for _, idx := range f.Index {
+ if len(f.Index) == 1 {
+ return f.IsZero(v.Field(f.Index[0]))
+ }
+
+ for _, index := range f.Index {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return true
}
v = v.Elem()
}
- v = v.Field(idx)
+ v = v.Field(index)
}
return f.IsZero(v)
}
@@ -70,7 +92,7 @@ func (f *Field) AppendValue(fmter Formatter, b []byte, strct reflect.Value) []by
return dialect.AppendNull(b)
}
- if f.NullZero && f.IsZero(fv) {
+ if (f.IsPtr && fv.IsNil()) || (f.NullZero && f.IsZero(fv)) {
return dialect.AppendNull(b)
}
if f.Append == nil {
@@ -98,14 +120,6 @@ func (f *Field) ScanValue(strct reflect.Value, src interface{}) error {
return f.ScanWithCheck(fv, src)
}
-func (f *Field) markAsPK() {
- f.IsPK = true
- f.NotNull = true
- if !f.Tag.HasOption("allowzero") {
- f.NullZero = true
- }
-}
-
func indexEqual(ind1, ind2 []int) bool {
if len(ind1) != len(ind2) {
return false
diff --git a/vendor/github.com/uptrace/bun/schema/reflect.go b/vendor/github.com/uptrace/bun/schema/reflect.go
index 5b20b1964..f13826a6c 100644
--- a/vendor/github.com/uptrace/bun/schema/reflect.go
+++ b/vendor/github.com/uptrace/bun/schema/reflect.go
@@ -10,13 +10,15 @@ import (
var (
bytesType = reflect.TypeOf((*[]byte)(nil)).Elem()
- timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
+ timePtrType = reflect.TypeOf((*time.Time)(nil))
+ timeType = timePtrType.Elem()
ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem()
+ jsonMarshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
)
func indirectType(t reflect.Type) reflect.Type {
diff --git a/vendor/github.com/uptrace/bun/schema/scan.go b/vendor/github.com/uptrace/bun/schema/scan.go
index 30abcfc35..069b14e44 100644
--- a/vendor/github.com/uptrace/bun/schema/scan.go
+++ b/vendor/github.com/uptrace/bun/schema/scan.go
@@ -94,6 +94,8 @@ func scanner(typ reflect.Type) ScannerFunc {
}
switch typ {
+ case bytesType:
+ return scanBytes
case timeType:
return scanTime
case ipType:
@@ -134,12 +136,22 @@ func scanBool(dest reflect.Value, src interface{}) error {
dest.SetBool(src != 0)
return nil
case []byte:
- if len(src) == 1 {
- dest.SetBool(src[0] != '0')
- return nil
+ f, err := strconv.ParseBool(internal.String(src))
+ if err != nil {
+ return err
+ }
+ dest.SetBool(f)
+ return nil
+ case string:
+ f, err := strconv.ParseBool(src)
+ if err != nil {
+ return err
}
+ dest.SetBool(f)
+ return nil
+ default:
+ return scanError(dest.Type(), src)
}
- return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanInt64(dest reflect.Value, src interface{}) error {
@@ -167,8 +179,9 @@ func scanInt64(dest reflect.Value, src interface{}) error {
}
dest.SetInt(n)
return nil
+ default:
+ return scanError(dest.Type(), src)
}
- return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanUint64(dest reflect.Value, src interface{}) error {
@@ -189,8 +202,16 @@ func scanUint64(dest reflect.Value, src interface{}) error {
}
dest.SetUint(n)
return nil
+ case string:
+ n, err := strconv.ParseUint(src, 10, 64)
+ if err != nil {
+ return err
+ }
+ dest.SetUint(n)
+ return nil
+ default:
+ return scanError(dest.Type(), src)
}
- return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanFloat64(dest reflect.Value, src interface{}) error {
@@ -208,8 +229,16 @@ func scanFloat64(dest reflect.Value, src interface{}) error {
}
dest.SetFloat(f)
return nil
+ case string:
+ f, err := strconv.ParseFloat(src, 64)
+ if err != nil {
+ return err
+ }
+ dest.SetFloat(f)
+ return nil
+ default:
+ return scanError(dest.Type(), src)
}
- return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanString(dest reflect.Value, src interface{}) error {
@@ -226,8 +255,18 @@ func scanString(dest reflect.Value, src interface{}) error {
case time.Time:
dest.SetString(src.Format(time.RFC3339Nano))
return nil
+ case int64:
+ dest.SetString(strconv.FormatInt(src, 10))
+ return nil
+ case uint64:
+ dest.SetString(strconv.FormatUint(src, 10))
+ return nil
+ case float64:
+ dest.SetString(strconv.FormatFloat(src, 'G', -1, 64))
+ return nil
+ default:
+ return scanError(dest.Type(), src)
}
- return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanBytes(dest reflect.Value, src interface{}) error {
@@ -244,8 +283,9 @@ func scanBytes(dest reflect.Value, src interface{}) error {
dest.SetBytes(clone)
return nil
+ default:
+ return scanError(dest.Type(), src)
}
- return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanTime(dest reflect.Value, src interface{}) error {
@@ -274,8 +314,9 @@ func scanTime(dest reflect.Value, src interface{}) error {
destTime := dest.Addr().Interface().(*time.Time)
*destTime = srcTime
return nil
+ default:
+ return scanError(dest.Type(), src)
}
- return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanScanner(dest reflect.Value, src interface{}) error {
@@ -438,7 +479,7 @@ func scanJSONIntoInterface(dest reflect.Value, src interface{}) error {
if fn := Scanner(dest.Type()); fn != nil {
return fn(dest, src)
}
- return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+ return scanError(dest.Type(), src)
}
func scanInterface(dest reflect.Value, src interface{}) error {
@@ -454,7 +495,7 @@ func scanInterface(dest reflect.Value, src interface{}) error {
if fn := Scanner(dest.Type()); fn != nil {
return fn(dest, src)
}
- return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+ return scanError(dest.Type(), src)
}
func nilable(kind reflect.Kind) bool {
@@ -464,3 +505,7 @@ func nilable(kind reflect.Kind) bool {
}
return false
}
+
+func scanError(dest reflect.Type, src interface{}) error {
+ return fmt.Errorf("bun: can't scan %#v (%T) into %s", src, src, dest.String())
+}
diff --git a/vendor/github.com/uptrace/bun/schema/sqlfmt.go b/vendor/github.com/uptrace/bun/schema/sqlfmt.go
index 93a801c86..a4ed24af6 100644
--- a/vendor/github.com/uptrace/bun/schema/sqlfmt.go
+++ b/vendor/github.com/uptrace/bun/schema/sqlfmt.go
@@ -49,7 +49,7 @@ func SafeQuery(query string, args []interface{}) QueryWithArgs {
if args == nil {
args = make([]interface{}, 0)
} else if len(query) > 0 && strings.IndexByte(query, '?') == -1 {
- internal.Warn.Printf("query %q has args %v, but no placeholders", query, args)
+ internal.Warn.Printf("query %q has %v args, but no placeholders", query, args)
}
return QueryWithArgs{
Query: query,
diff --git a/vendor/github.com/uptrace/bun/schema/sqltype.go b/vendor/github.com/uptrace/bun/schema/sqltype.go
index 90551d6aa..233ba641b 100644
--- a/vendor/github.com/uptrace/bun/schema/sqltype.go
+++ b/vendor/github.com/uptrace/bun/schema/sqltype.go
@@ -4,7 +4,6 @@ import (
"bytes"
"database/sql"
"encoding/json"
- "fmt"
"reflect"
"time"
@@ -60,6 +59,8 @@ func DiscoverSQLType(typ reflect.Type) string {
return sqltype.BigInt
case nullStringType:
return sqltype.VarChar
+ case jsonRawMessageType:
+ return sqltype.JSON
}
switch typ.Kind() {
@@ -135,6 +136,6 @@ func (tm *NullTime) Scan(src interface{}) error {
tm.Time = newtm
return nil
default:
- return fmt.Errorf("bun: can't scan %#v into NullTime", src)
+ return scanError(bunNullTimeType, src)
}
}
diff --git a/vendor/github.com/uptrace/bun/schema/table.go b/vendor/github.com/uptrace/bun/schema/table.go
index 88b8d8e25..1a8393fc7 100644
--- a/vendor/github.com/uptrace/bun/schema/table.go
+++ b/vendor/github.com/uptrace/bun/schema/table.go
@@ -204,30 +204,6 @@ func (t *Table) initFields() {
t.Fields = make([]*Field, 0, t.Type.NumField())
t.FieldMap = make(map[string]*Field, t.Type.NumField())
t.addFields(t.Type, "", nil)
-
- if len(t.PKs) == 0 {
- for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} {
- if field, ok := t.FieldMap[name]; ok {
- field.markAsPK()
- t.PKs = []*Field{field}
- t.DataFields = removeField(t.DataFields, field)
- break
- }
- }
- }
-
- if len(t.PKs) == 1 {
- pk := t.PKs[0]
- if pk.SQLDefault != "" {
- return
- }
-
- switch pk.IndirectType.Kind() {
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
- reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
- pk.AutoIncrement = true
- }
- }
}
func (t *Table) addFields(typ reflect.Type, prefix string, index []int) {
@@ -250,26 +226,27 @@ func (t *Table) addFields(typ reflect.Type, prefix string, index []int) {
continue
}
+ // If field is an embedded struct, add each field of the embedded struct.
fieldType := indirectType(f.Type)
- if fieldType.Kind() != reflect.Struct {
+ if fieldType.Kind() == reflect.Struct {
+ t.addFields(fieldType, "", withIndex(index, f.Index))
+
+ tag := tagparser.Parse(f.Tag.Get("bun"))
+ if tag.HasOption("inherit") || tag.HasOption("extend") {
+ embeddedTable := t.dialect.Tables().Ref(fieldType)
+ t.TypeName = embeddedTable.TypeName
+ t.SQLName = embeddedTable.SQLName
+ t.SQLNameForSelects = embeddedTable.SQLNameForSelects
+ t.Alias = embeddedTable.Alias
+ t.SQLAlias = embeddedTable.SQLAlias
+ t.ModelName = embeddedTable.ModelName
+ }
continue
}
- t.addFields(fieldType, "", withIndex(index, f.Index))
-
- tag := tagparser.Parse(f.Tag.Get("bun"))
- if _, inherit := tag.Options["inherit"]; inherit {
- embeddedTable := t.dialect.Tables().Ref(fieldType)
- t.TypeName = embeddedTable.TypeName
- t.SQLName = embeddedTable.SQLName
- t.SQLNameForSelects = embeddedTable.SQLNameForSelects
- t.Alias = embeddedTable.Alias
- t.SQLAlias = embeddedTable.SQLAlias
- t.ModelName = embeddedTable.ModelName
- }
-
- continue
}
+ // If field is not a struct, add it.
+ // This will also add any embedded non-struct type as a field.
if field := t.newField(f, prefix, index); field != nil {
t.addField(field)
}
@@ -355,6 +332,7 @@ func (t *Table) newField(f reflect.StructField, prefix string, index []int) *Fie
field := &Field{
StructField: f,
+ IsPtr: f.Type.Kind() == reflect.Ptr,
Tag: tag,
IndirectType: indirectType(f.Type),
@@ -367,9 +345,13 @@ func (t *Table) newField(f reflect.StructField, prefix string, index []int) *Fie
field.NotNull = tag.HasOption("notnull")
field.NullZero = tag.HasOption("nullzero")
- field.AutoIncrement = tag.HasOption("autoincrement")
if tag.HasOption("pk") {
- field.markAsPK()
+ field.IsPK = true
+ field.NotNull = true
+ }
+ if tag.HasOption("autoincrement") {
+ field.AutoIncrement = true
+ field.NullZero = true
}
if v, ok := tag.Options["unique"]; ok {
@@ -415,22 +397,10 @@ func (t *Table) newField(f reflect.StructField, prefix string, index []int) *Fie
}
if _, ok := tag.Options["soft_delete"]; ok {
- field.NullZero = true
t.SoftDeleteField = field
t.UpdateSoftDeleteField = softDeleteFieldUpdater(field)
}
- // Check this in the end to undo NullZero.
- if tag.HasOption("allowzero") {
- if tag.HasOption("nullzero") {
- internal.Warn.Printf(
- "%s.%s: nullzero and allowzero options are mutually exclusive",
- t.TypeName, f.Name,
- )
- }
- field.NullZero = false
- }
-
return field
}
@@ -651,7 +621,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation {
rel.BaseFields = append(rel.BaseFields, f)
} else {
panic(fmt.Errorf(
- "bun: %s has-one %s: %s must have column %s",
+ "bun: %s has-many %s: %s must have column %s",
t.TypeName, field.GoName, t.TypeName, baseColumn,
))
}
@@ -660,7 +630,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation {
rel.JoinFields = append(rel.JoinFields, f)
} else {
panic(fmt.Errorf(
- "bun: %s has-one %s: %s must have column %s",
+ "bun: %s has-many %s: %s must have column %s",
t.TypeName, field.GoName, t.TypeName, baseColumn,
))
}
@@ -879,7 +849,6 @@ func isKnownFieldOption(name string) bool {
"msgpack",
"notnull",
"nullzero",
- "allowzero",
"default",
"unique",
"soft_delete",