summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/schema
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema')
-rw-r--r--vendor/github.com/uptrace/bun/schema/append_value.go22
-rw-r--r--vendor/github.com/uptrace/bun/schema/formatter.go8
-rw-r--r--vendor/github.com/uptrace/bun/schema/reflect.go (renamed from vendor/github.com/uptrace/bun/schema/util.go)19
-rw-r--r--vendor/github.com/uptrace/bun/schema/scan.go121
-rw-r--r--vendor/github.com/uptrace/bun/schema/sqlfmt.go2
-rw-r--r--vendor/github.com/uptrace/bun/schema/sqltype.go49
-rw-r--r--vendor/github.com/uptrace/bun/schema/table.go104
-rw-r--r--vendor/github.com/uptrace/bun/schema/tables.go1
8 files changed, 191 insertions, 135 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/append_value.go b/vendor/github.com/uptrace/bun/schema/append_value.go
index 0c4677069..948ff86af 100644
--- a/vendor/github.com/uptrace/bun/schema/append_value.go
+++ b/vendor/github.com/uptrace/bun/schema/append_value.go
@@ -2,7 +2,6 @@ package schema
import (
"database/sql/driver"
- "encoding/json"
"fmt"
"net"
"reflect"
@@ -14,16 +13,6 @@ import (
"github.com/uptrace/bun/internal"
)
-var (
- timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
- ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
- ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
- jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
-
- driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
- queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem()
-)
-
type (
AppenderFunc func(fmter Formatter, b []byte, v reflect.Value) []byte
CustomAppender func(typ reflect.Type) AppenderFunc
@@ -60,6 +49,8 @@ var appenders = []AppenderFunc{
func Appender(typ reflect.Type, custom CustomAppender) AppenderFunc {
switch typ {
+ case bytesType:
+ return appendBytesValue
case timeType:
return appendTimeValue
case ipType:
@@ -93,7 +84,9 @@ func Appender(typ reflect.Type, custom CustomAppender) AppenderFunc {
case reflect.Interface:
return ifaceAppenderFunc(typ, custom)
case reflect.Ptr:
- return ptrAppenderFunc(typ, custom)
+ if fn := Appender(typ.Elem(), custom); fn != nil {
+ return PtrAppender(fn)
+ }
case reflect.Slice:
if typ.Elem().Kind() == reflect.Uint8 {
return appendBytesValue
@@ -123,13 +116,12 @@ func ifaceAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc)
}
}
-func ptrAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc {
- appender := Appender(typ.Elem(), custom)
+func PtrAppender(fn AppenderFunc) AppenderFunc {
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
if v.IsNil() {
return dialect.AppendNull(b)
}
- return appender(fmter, b, v.Elem())
+ return fn(fmter, b, v.Elem())
}
}
diff --git a/vendor/github.com/uptrace/bun/schema/formatter.go b/vendor/github.com/uptrace/bun/schema/formatter.go
index 7b26fbaca..45a246307 100644
--- a/vendor/github.com/uptrace/bun/schema/formatter.go
+++ b/vendor/github.com/uptrace/bun/schema/formatter.go
@@ -89,10 +89,10 @@ func (f Formatter) AppendQuery(dst []byte, query string, args ...interface{}) []
func (f Formatter) append(dst []byte, p *parser.Parser, args []interface{}) []byte {
var namedArgs NamedArgAppender
if len(args) == 1 {
- var ok bool
- namedArgs, ok = args[0].(NamedArgAppender)
- if !ok {
- namedArgs, _ = newStructArgs(f, args[0])
+ if v, ok := args[0].(NamedArgAppender); ok {
+ namedArgs = v
+ } else if v, ok := newStructArgs(f, args[0]); ok {
+ namedArgs = v
}
}
diff --git a/vendor/github.com/uptrace/bun/schema/util.go b/vendor/github.com/uptrace/bun/schema/reflect.go
index 6d474e4cc..5b20b1964 100644
--- a/vendor/github.com/uptrace/bun/schema/util.go
+++ b/vendor/github.com/uptrace/bun/schema/reflect.go
@@ -1,6 +1,23 @@
package schema
-import "reflect"
+import (
+ "database/sql/driver"
+ "encoding/json"
+ "net"
+ "reflect"
+ "time"
+)
+
+var (
+ bytesType = reflect.TypeOf((*[]byte)(nil)).Elem()
+ timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
+ ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
+ ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
+ jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
+
+ driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
+ queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem()
+)
func indirectType(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
diff --git a/vendor/github.com/uptrace/bun/schema/scan.go b/vendor/github.com/uptrace/bun/schema/scan.go
index 0e66a860f..85ba62a01 100644
--- a/vendor/github.com/uptrace/bun/schema/scan.go
+++ b/vendor/github.com/uptrace/bun/schema/scan.go
@@ -7,10 +7,12 @@ import (
"net"
"reflect"
"strconv"
+ "strings"
"time"
"github.com/vmihailenco/msgpack/v5"
+ "github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/extra/bunjson"
"github.com/uptrace/bun/internal"
)
@@ -19,32 +21,35 @@ var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
type ScannerFunc func(dest reflect.Value, src interface{}) error
-var scanners = []ScannerFunc{
- reflect.Bool: scanBool,
- reflect.Int: scanInt64,
- reflect.Int8: scanInt64,
- reflect.Int16: scanInt64,
- reflect.Int32: scanInt64,
- reflect.Int64: scanInt64,
- reflect.Uint: scanUint64,
- reflect.Uint8: scanUint64,
- reflect.Uint16: scanUint64,
- reflect.Uint32: scanUint64,
- reflect.Uint64: scanUint64,
- reflect.Uintptr: scanUint64,
- reflect.Float32: scanFloat64,
- reflect.Float64: scanFloat64,
- reflect.Complex64: nil,
- reflect.Complex128: nil,
- reflect.Array: nil,
- reflect.Chan: nil,
- reflect.Func: nil,
- reflect.Map: scanJSON,
- reflect.Ptr: nil,
- reflect.Slice: scanJSON,
- reflect.String: scanString,
- reflect.Struct: scanJSON,
- reflect.UnsafePointer: nil,
+var scanners []ScannerFunc
+
+func init() {
+ scanners = []ScannerFunc{
+ reflect.Bool: scanBool,
+ reflect.Int: scanInt64,
+ reflect.Int8: scanInt64,
+ reflect.Int16: scanInt64,
+ reflect.Int32: scanInt64,
+ reflect.Int64: scanInt64,
+ reflect.Uint: scanUint64,
+ reflect.Uint8: scanUint64,
+ reflect.Uint16: scanUint64,
+ reflect.Uint32: scanUint64,
+ reflect.Uint64: scanUint64,
+ reflect.Uintptr: scanUint64,
+ reflect.Float32: scanFloat64,
+ reflect.Float64: scanFloat64,
+ reflect.Complex64: nil,
+ reflect.Complex128: nil,
+ reflect.Array: nil,
+ reflect.Interface: scanInterface,
+ reflect.Map: scanJSON,
+ reflect.Ptr: nil,
+ reflect.Slice: scanJSON,
+ reflect.String: scanString,
+ reflect.Struct: scanJSON,
+ reflect.UnsafePointer: nil,
+ }
}
func FieldScanner(dialect Dialect, field *Field) ScannerFunc {
@@ -54,6 +59,12 @@ func FieldScanner(dialect Dialect, field *Field) ScannerFunc {
if field.Tag.HasOption("json_use_number") {
return scanJSONUseNumber
}
+ if field.StructField.Type.Kind() == reflect.Interface {
+ switch strings.ToUpper(field.UserSQLType) {
+ case sqltype.JSON, sqltype.JSONB:
+ return scanJSONIntoInterface
+ }
+ }
return dialect.Scanner(field.StructField.Type)
}
@@ -62,7 +73,7 @@ func Scanner(typ reflect.Type) ScannerFunc {
if kind == reflect.Ptr {
if fn := Scanner(typ.Elem()); fn != nil {
- return ptrScanner(fn)
+ return PtrScanner(fn)
}
}
@@ -84,6 +95,8 @@ func Scanner(typ reflect.Type) ScannerFunc {
return scanIP
case ipNetType:
return scanIPNet
+ case bytesType:
+ return scanBytes
case jsonRawMessageType:
return scanJSONRawMessage
}
@@ -196,6 +209,21 @@ func scanString(dest reflect.Value, src interface{}) error {
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
+func scanBytes(dest reflect.Value, src interface{}) error {
+ switch src := src.(type) {
+ case nil:
+ dest.SetBytes(nil)
+ return nil
+ case string:
+ dest.SetBytes([]byte(src))
+ return nil
+ case []byte:
+ dest.SetBytes(src)
+ return nil
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
func scanTime(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
@@ -352,7 +380,7 @@ func toBytes(src interface{}) ([]byte, error) {
}
}
-func ptrScanner(fn ScannerFunc) ScannerFunc {
+func PtrScanner(fn ScannerFunc) ScannerFunc {
return func(dest reflect.Value, src interface{}) error {
if src == nil {
if !dest.CanAddr() {
@@ -383,6 +411,43 @@ func scanNull(dest reflect.Value) error {
return nil
}
+func scanJSONIntoInterface(dest reflect.Value, src interface{}) error {
+ if dest.IsNil() {
+ if src == nil {
+ return nil
+ }
+
+ b, err := toBytes(src)
+ if err != nil {
+ return err
+ }
+
+ return bunjson.Unmarshal(b, dest.Addr().Interface())
+ }
+
+ dest = dest.Elem()
+ if fn := Scanner(dest.Type()); fn != nil {
+ return fn(dest, src)
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
+func scanInterface(dest reflect.Value, src interface{}) error {
+ if dest.IsNil() {
+ if src == nil {
+ return nil
+ }
+ dest.Set(reflect.ValueOf(src))
+ return nil
+ }
+
+ dest = dest.Elem()
+ if fn := Scanner(dest.Type()); fn != nil {
+ return fn(dest, src)
+ }
+ return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
+}
+
func nilable(kind reflect.Kind) bool {
switch kind {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
diff --git a/vendor/github.com/uptrace/bun/schema/sqlfmt.go b/vendor/github.com/uptrace/bun/schema/sqlfmt.go
index 7b538cd0c..bbdb0a01f 100644
--- a/vendor/github.com/uptrace/bun/schema/sqlfmt.go
+++ b/vendor/github.com/uptrace/bun/schema/sqlfmt.go
@@ -40,7 +40,7 @@ type QueryWithArgs struct {
var _ QueryAppender = QueryWithArgs{}
func SafeQuery(query string, args []interface{}) QueryWithArgs {
- if query != "" && args == nil {
+ if args == nil {
args = make([]interface{}, 0)
}
return QueryWithArgs{Query: query, Args: args}
diff --git a/vendor/github.com/uptrace/bun/schema/sqltype.go b/vendor/github.com/uptrace/bun/schema/sqltype.go
index 560f695c2..23322a1e1 100644
--- a/vendor/github.com/uptrace/bun/schema/sqltype.go
+++ b/vendor/github.com/uptrace/bun/schema/sqltype.go
@@ -23,32 +23,29 @@ var (
)
var sqlTypes = []string{
- reflect.Bool: sqltype.Boolean,
- reflect.Int: sqltype.BigInt,
- reflect.Int8: sqltype.SmallInt,
- reflect.Int16: sqltype.SmallInt,
- reflect.Int32: sqltype.Integer,
- reflect.Int64: sqltype.BigInt,
- reflect.Uint: sqltype.BigInt,
- reflect.Uint8: sqltype.SmallInt,
- reflect.Uint16: sqltype.SmallInt,
- reflect.Uint32: sqltype.Integer,
- reflect.Uint64: sqltype.BigInt,
- reflect.Uintptr: sqltype.BigInt,
- reflect.Float32: sqltype.Real,
- reflect.Float64: sqltype.DoublePrecision,
- reflect.Complex64: "",
- reflect.Complex128: "",
- reflect.Array: "",
- reflect.Chan: "",
- reflect.Func: "",
- reflect.Interface: "",
- reflect.Map: sqltype.VarChar,
- reflect.Ptr: "",
- reflect.Slice: sqltype.VarChar,
- reflect.String: sqltype.VarChar,
- reflect.Struct: sqltype.VarChar,
- reflect.UnsafePointer: "",
+ reflect.Bool: sqltype.Boolean,
+ reflect.Int: sqltype.BigInt,
+ reflect.Int8: sqltype.SmallInt,
+ reflect.Int16: sqltype.SmallInt,
+ reflect.Int32: sqltype.Integer,
+ reflect.Int64: sqltype.BigInt,
+ reflect.Uint: sqltype.BigInt,
+ reflect.Uint8: sqltype.SmallInt,
+ reflect.Uint16: sqltype.SmallInt,
+ reflect.Uint32: sqltype.Integer,
+ reflect.Uint64: sqltype.BigInt,
+ reflect.Uintptr: sqltype.BigInt,
+ reflect.Float32: sqltype.Real,
+ reflect.Float64: sqltype.DoublePrecision,
+ reflect.Complex64: "",
+ reflect.Complex128: "",
+ reflect.Array: "",
+ reflect.Interface: "",
+ reflect.Map: sqltype.VarChar,
+ reflect.Ptr: "",
+ reflect.Slice: sqltype.VarChar,
+ reflect.String: sqltype.VarChar,
+ reflect.Struct: sqltype.VarChar,
}
func DiscoverSQLType(typ reflect.Type) string {
diff --git a/vendor/github.com/uptrace/bun/schema/table.go b/vendor/github.com/uptrace/bun/schema/table.go
index eca18b781..7498a2bc8 100644
--- a/vendor/github.com/uptrace/bun/schema/table.go
+++ b/vendor/github.com/uptrace/bun/schema/table.go
@@ -60,10 +60,9 @@ type Table struct {
Unique map[string][]*Field
SoftDeleteField *Field
- UpdateSoftDeleteField func(fv reflect.Value) error
+ UpdateSoftDeleteField func(fv reflect.Value, tm time.Time) error
- allFields []*Field // read only
- skippedFields []*Field
+ allFields []*Field // read only
flags internal.Flag
}
@@ -104,9 +103,7 @@ func (t *Table) init1() {
}
func (t *Table) init2() {
- t.initInlines()
t.initRelations()
- t.skippedFields = nil
}
func (t *Table) setName(name string) {
@@ -207,15 +204,20 @@ func (t *Table) initFields() {
func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
+ unexported := f.PkgPath != ""
- // Make a copy so slice is not shared between fields.
+ if unexported && !f.Anonymous { // unexported
+ continue
+ }
+ if f.Tag.Get("bun") == "-" {
+ continue
+ }
+
+ // Make a copy so the slice is not shared between fields.
index := make([]int, len(baseIndex))
copy(index, baseIndex)
if f.Anonymous {
- if f.Tag.Get("bun") == "-" {
- continue
- }
if f.Name == "BaseModel" && f.Type == baseModelType {
if len(index) == 0 {
t.processBaseModelField(f)
@@ -243,8 +245,7 @@ func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
continue
}
- field := t.newField(f, index)
- if field != nil {
+ if field := t.newField(f, index); field != nil {
t.addField(field)
}
}
@@ -284,11 +285,10 @@ func (t *Table) processBaseModelField(f reflect.StructField) {
func (t *Table) newField(f reflect.StructField, index []int) *Field {
tag := tagparser.Parse(f.Tag.Get("bun"))
- if f.PkgPath != "" {
- return nil
- }
-
sqlName := internal.Underscore(f.Name)
+ if tag.Name != "" {
+ sqlName = tag.Name
+ }
if tag.Name != sqlName && isKnownFieldOption(tag.Name) {
internal.Warn.Printf(
@@ -303,11 +303,6 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field {
}
}
- skip := tag.Name == "-"
- if !skip && tag.Name != "" {
- sqlName = tag.Name
- }
-
index = append(index, f.Index...)
if field := t.fieldWithLock(sqlName); field != nil {
if indexEqual(field.Index, index) {
@@ -371,9 +366,11 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field {
}
t.allFields = append(t.allFields, field)
- if skip {
- t.skippedFields = append(t.skippedFields, field)
+ if tag.HasOption("scanonly") {
t.FieldMap[field.Name] = field
+ if field.IndirectType.Kind() == reflect.Struct {
+ t.inlineFields(field, nil)
+ }
return nil
}
@@ -386,14 +383,6 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field {
return field
}
-func (t *Table) initInlines() {
- for _, f := range t.skippedFields {
- if f.IndirectType.Kind() == reflect.Struct {
- t.inlineFields(f, nil)
- }
- }
-}
-
//---------------------------------------------------------------------------------------
func (t *Table) initRelations() {
@@ -745,17 +734,15 @@ func (t *Table) m2mRelation(field *Field) *Relation {
return rel
}
-func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) {
- if path == nil {
- path = map[reflect.Type]struct{}{
- t.Type: {},
- }
+func (t *Table) inlineFields(field *Field, seen map[reflect.Type]struct{}) {
+ if seen == nil {
+ seen = map[reflect.Type]struct{}{t.Type: {}}
}
- if _, ok := path[field.IndirectType]; ok {
+ if _, ok := seen[field.IndirectType]; ok {
return
}
- path[field.IndirectType] = struct{}{}
+ seen[field.IndirectType] = struct{}{}
joinTable := t.dialect.Tables().Ref(field.IndirectType)
for _, f := range joinTable.allFields {
@@ -775,18 +762,15 @@ func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) {
continue
}
- if _, ok := path[f.IndirectType]; !ok {
- t.inlineFields(f, path)
+ if _, ok := seen[f.IndirectType]; !ok {
+ t.inlineFields(f, seen)
}
}
}
//------------------------------------------------------------------------------
-func (t *Table) Dialect() Dialect { return t.dialect }
-
-//------------------------------------------------------------------------------
-
+func (t *Table) Dialect() Dialect { return t.dialect }
func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) }
func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) }
@@ -845,6 +829,7 @@ func isKnownFieldOption(name string) bool {
"default",
"unique",
"soft_delete",
+ "scanonly",
"pk",
"autoincrement",
@@ -883,35 +868,35 @@ func parseRelationJoin(join string) ([]string, []string) {
//------------------------------------------------------------------------------
-func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error {
+func softDeleteFieldUpdater(field *Field) func(fv reflect.Value, tm time.Time) error {
typ := field.StructField.Type
switch typ {
case timeType:
- return func(fv reflect.Value) error {
+ return func(fv reflect.Value, tm time.Time) error {
ptr := fv.Addr().Interface().(*time.Time)
- *ptr = time.Now()
+ *ptr = tm
return nil
}
case nullTimeType:
- return func(fv reflect.Value) error {
+ return func(fv reflect.Value, tm time.Time) error {
ptr := fv.Addr().Interface().(*sql.NullTime)
- *ptr = sql.NullTime{Time: time.Now()}
+ *ptr = sql.NullTime{Time: tm}
return nil
}
case nullIntType:
- return func(fv reflect.Value) error {
+ return func(fv reflect.Value, tm time.Time) error {
ptr := fv.Addr().Interface().(*sql.NullInt64)
- *ptr = sql.NullInt64{Int64: time.Now().UnixNano()}
+ *ptr = sql.NullInt64{Int64: tm.UnixNano()}
return nil
}
}
switch field.IndirectType.Kind() {
case reflect.Int64:
- return func(fv reflect.Value) error {
+ return func(fv reflect.Value, tm time.Time) error {
ptr := fv.Addr().Interface().(*int64)
- *ptr = time.Now().UnixNano()
+ *ptr = tm.UnixNano()
return nil
}
case reflect.Ptr:
@@ -922,17 +907,16 @@ func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error {
switch typ { //nolint:gocritic
case timeType:
- return func(fv reflect.Value) error {
- now := time.Now()
- fv.Set(reflect.ValueOf(&now))
+ return func(fv reflect.Value, tm time.Time) error {
+ fv.Set(reflect.ValueOf(&tm))
return nil
}
}
switch typ.Kind() { //nolint:gocritic
case reflect.Int64:
- return func(fv reflect.Value) error {
- utime := time.Now().UnixNano()
+ return func(fv reflect.Value, tm time.Time) error {
+ utime := tm.UnixNano()
fv.Set(reflect.ValueOf(&utime))
return nil
}
@@ -941,8 +925,8 @@ func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error {
return softDeleteFieldUpdaterFallback(field)
}
-func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value) error {
- return func(fv reflect.Value) error {
- return field.ScanWithCheck(fv, time.Now())
+func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value, tm time.Time) error {
+ return func(fv reflect.Value, tm time.Time) error {
+ return field.ScanWithCheck(fv, tm)
}
}
diff --git a/vendor/github.com/uptrace/bun/schema/tables.go b/vendor/github.com/uptrace/bun/schema/tables.go
index d82d08f59..4be856b34 100644
--- a/vendor/github.com/uptrace/bun/schema/tables.go
+++ b/vendor/github.com/uptrace/bun/schema/tables.go
@@ -67,6 +67,7 @@ func (t *Tables) Ref(typ reflect.Type) *Table {
}
func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table {
+ typ = indirectType(typ)
if typ.Kind() != reflect.Struct {
panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct))
}