diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema/table.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/schema/table.go | 104 |
1 files changed, 44 insertions, 60 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/table.go b/vendor/github.com/uptrace/bun/schema/table.go index eca18b781..7498a2bc8 100644 --- a/vendor/github.com/uptrace/bun/schema/table.go +++ b/vendor/github.com/uptrace/bun/schema/table.go @@ -60,10 +60,9 @@ type Table struct { Unique map[string][]*Field SoftDeleteField *Field - UpdateSoftDeleteField func(fv reflect.Value) error + UpdateSoftDeleteField func(fv reflect.Value, tm time.Time) error - allFields []*Field // read only - skippedFields []*Field + allFields []*Field // read only flags internal.Flag } @@ -104,9 +103,7 @@ func (t *Table) init1() { } func (t *Table) init2() { - t.initInlines() t.initRelations() - t.skippedFields = nil } func (t *Table) setName(name string) { @@ -207,15 +204,20 @@ func (t *Table) initFields() { func (t *Table) addFields(typ reflect.Type, baseIndex []int) { for i := 0; i < typ.NumField(); i++ { f := typ.Field(i) + unexported := f.PkgPath != "" - // Make a copy so slice is not shared between fields. + if unexported && !f.Anonymous { // unexported + continue + } + if f.Tag.Get("bun") == "-" { + continue + } + + // Make a copy so the slice is not shared between fields. index := make([]int, len(baseIndex)) copy(index, baseIndex) if f.Anonymous { - if f.Tag.Get("bun") == "-" { - continue - } if f.Name == "BaseModel" && f.Type == baseModelType { if len(index) == 0 { t.processBaseModelField(f) @@ -243,8 +245,7 @@ func (t *Table) addFields(typ reflect.Type, baseIndex []int) { continue } - field := t.newField(f, index) - if field != nil { + if field := t.newField(f, index); field != nil { t.addField(field) } } @@ -284,11 +285,10 @@ func (t *Table) processBaseModelField(f reflect.StructField) { func (t *Table) newField(f reflect.StructField, index []int) *Field { tag := tagparser.Parse(f.Tag.Get("bun")) - if f.PkgPath != "" { - return nil - } - sqlName := internal.Underscore(f.Name) + if tag.Name != "" { + sqlName = tag.Name + } if tag.Name != sqlName && isKnownFieldOption(tag.Name) { internal.Warn.Printf( @@ -303,11 +303,6 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field { } } - skip := tag.Name == "-" - if !skip && tag.Name != "" { - sqlName = tag.Name - } - index = append(index, f.Index...) if field := t.fieldWithLock(sqlName); field != nil { if indexEqual(field.Index, index) { @@ -371,9 +366,11 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field { } t.allFields = append(t.allFields, field) - if skip { - t.skippedFields = append(t.skippedFields, field) + if tag.HasOption("scanonly") { t.FieldMap[field.Name] = field + if field.IndirectType.Kind() == reflect.Struct { + t.inlineFields(field, nil) + } return nil } @@ -386,14 +383,6 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field { return field } -func (t *Table) initInlines() { - for _, f := range t.skippedFields { - if f.IndirectType.Kind() == reflect.Struct { - t.inlineFields(f, nil) - } - } -} - //--------------------------------------------------------------------------------------- func (t *Table) initRelations() { @@ -745,17 +734,15 @@ func (t *Table) m2mRelation(field *Field) *Relation { return rel } -func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) { - if path == nil { - path = map[reflect.Type]struct{}{ - t.Type: {}, - } +func (t *Table) inlineFields(field *Field, seen map[reflect.Type]struct{}) { + if seen == nil { + seen = map[reflect.Type]struct{}{t.Type: {}} } - if _, ok := path[field.IndirectType]; ok { + if _, ok := seen[field.IndirectType]; ok { return } - path[field.IndirectType] = struct{}{} + seen[field.IndirectType] = struct{}{} joinTable := t.dialect.Tables().Ref(field.IndirectType) for _, f := range joinTable.allFields { @@ -775,18 +762,15 @@ func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) { continue } - if _, ok := path[f.IndirectType]; !ok { - t.inlineFields(f, path) + if _, ok := seen[f.IndirectType]; !ok { + t.inlineFields(f, seen) } } } //------------------------------------------------------------------------------ -func (t *Table) Dialect() Dialect { return t.dialect } - -//------------------------------------------------------------------------------ - +func (t *Table) Dialect() Dialect { return t.dialect } func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) } func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) } @@ -845,6 +829,7 @@ func isKnownFieldOption(name string) bool { "default", "unique", "soft_delete", + "scanonly", "pk", "autoincrement", @@ -883,35 +868,35 @@ func parseRelationJoin(join string) ([]string, []string) { //------------------------------------------------------------------------------ -func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error { +func softDeleteFieldUpdater(field *Field) func(fv reflect.Value, tm time.Time) error { typ := field.StructField.Type switch typ { case timeType: - return func(fv reflect.Value) error { + return func(fv reflect.Value, tm time.Time) error { ptr := fv.Addr().Interface().(*time.Time) - *ptr = time.Now() + *ptr = tm return nil } case nullTimeType: - return func(fv reflect.Value) error { + return func(fv reflect.Value, tm time.Time) error { ptr := fv.Addr().Interface().(*sql.NullTime) - *ptr = sql.NullTime{Time: time.Now()} + *ptr = sql.NullTime{Time: tm} return nil } case nullIntType: - return func(fv reflect.Value) error { + return func(fv reflect.Value, tm time.Time) error { ptr := fv.Addr().Interface().(*sql.NullInt64) - *ptr = sql.NullInt64{Int64: time.Now().UnixNano()} + *ptr = sql.NullInt64{Int64: tm.UnixNano()} return nil } } switch field.IndirectType.Kind() { case reflect.Int64: - return func(fv reflect.Value) error { + return func(fv reflect.Value, tm time.Time) error { ptr := fv.Addr().Interface().(*int64) - *ptr = time.Now().UnixNano() + *ptr = tm.UnixNano() return nil } case reflect.Ptr: @@ -922,17 +907,16 @@ func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error { switch typ { //nolint:gocritic case timeType: - return func(fv reflect.Value) error { - now := time.Now() - fv.Set(reflect.ValueOf(&now)) + return func(fv reflect.Value, tm time.Time) error { + fv.Set(reflect.ValueOf(&tm)) return nil } } switch typ.Kind() { //nolint:gocritic case reflect.Int64: - return func(fv reflect.Value) error { - utime := time.Now().UnixNano() + return func(fv reflect.Value, tm time.Time) error { + utime := tm.UnixNano() fv.Set(reflect.ValueOf(&utime)) return nil } @@ -941,8 +925,8 @@ func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error { return softDeleteFieldUpdaterFallback(field) } -func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value) error { - return func(fv reflect.Value) error { - return field.ScanWithCheck(fv, time.Now()) +func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value, tm time.Time) error { + return func(fv reflect.Value, tm time.Time) error { + return field.ScanWithCheck(fv, tm) } } |