summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/schema/table.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema/table.go')
-rw-r--r--vendor/github.com/uptrace/bun/schema/table.go104
1 files changed, 44 insertions, 60 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/table.go b/vendor/github.com/uptrace/bun/schema/table.go
index eca18b781..7498a2bc8 100644
--- a/vendor/github.com/uptrace/bun/schema/table.go
+++ b/vendor/github.com/uptrace/bun/schema/table.go
@@ -60,10 +60,9 @@ type Table struct {
Unique map[string][]*Field
SoftDeleteField *Field
- UpdateSoftDeleteField func(fv reflect.Value) error
+ UpdateSoftDeleteField func(fv reflect.Value, tm time.Time) error
- allFields []*Field // read only
- skippedFields []*Field
+ allFields []*Field // read only
flags internal.Flag
}
@@ -104,9 +103,7 @@ func (t *Table) init1() {
}
func (t *Table) init2() {
- t.initInlines()
t.initRelations()
- t.skippedFields = nil
}
func (t *Table) setName(name string) {
@@ -207,15 +204,20 @@ func (t *Table) initFields() {
func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
+ unexported := f.PkgPath != ""
- // Make a copy so slice is not shared between fields.
+ if unexported && !f.Anonymous { // unexported
+ continue
+ }
+ if f.Tag.Get("bun") == "-" {
+ continue
+ }
+
+ // Make a copy so the slice is not shared between fields.
index := make([]int, len(baseIndex))
copy(index, baseIndex)
if f.Anonymous {
- if f.Tag.Get("bun") == "-" {
- continue
- }
if f.Name == "BaseModel" && f.Type == baseModelType {
if len(index) == 0 {
t.processBaseModelField(f)
@@ -243,8 +245,7 @@ func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
continue
}
- field := t.newField(f, index)
- if field != nil {
+ if field := t.newField(f, index); field != nil {
t.addField(field)
}
}
@@ -284,11 +285,10 @@ func (t *Table) processBaseModelField(f reflect.StructField) {
func (t *Table) newField(f reflect.StructField, index []int) *Field {
tag := tagparser.Parse(f.Tag.Get("bun"))
- if f.PkgPath != "" {
- return nil
- }
-
sqlName := internal.Underscore(f.Name)
+ if tag.Name != "" {
+ sqlName = tag.Name
+ }
if tag.Name != sqlName && isKnownFieldOption(tag.Name) {
internal.Warn.Printf(
@@ -303,11 +303,6 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field {
}
}
- skip := tag.Name == "-"
- if !skip && tag.Name != "" {
- sqlName = tag.Name
- }
-
index = append(index, f.Index...)
if field := t.fieldWithLock(sqlName); field != nil {
if indexEqual(field.Index, index) {
@@ -371,9 +366,11 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field {
}
t.allFields = append(t.allFields, field)
- if skip {
- t.skippedFields = append(t.skippedFields, field)
+ if tag.HasOption("scanonly") {
t.FieldMap[field.Name] = field
+ if field.IndirectType.Kind() == reflect.Struct {
+ t.inlineFields(field, nil)
+ }
return nil
}
@@ -386,14 +383,6 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field {
return field
}
-func (t *Table) initInlines() {
- for _, f := range t.skippedFields {
- if f.IndirectType.Kind() == reflect.Struct {
- t.inlineFields(f, nil)
- }
- }
-}
-
//---------------------------------------------------------------------------------------
func (t *Table) initRelations() {
@@ -745,17 +734,15 @@ func (t *Table) m2mRelation(field *Field) *Relation {
return rel
}
-func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) {
- if path == nil {
- path = map[reflect.Type]struct{}{
- t.Type: {},
- }
+func (t *Table) inlineFields(field *Field, seen map[reflect.Type]struct{}) {
+ if seen == nil {
+ seen = map[reflect.Type]struct{}{t.Type: {}}
}
- if _, ok := path[field.IndirectType]; ok {
+ if _, ok := seen[field.IndirectType]; ok {
return
}
- path[field.IndirectType] = struct{}{}
+ seen[field.IndirectType] = struct{}{}
joinTable := t.dialect.Tables().Ref(field.IndirectType)
for _, f := range joinTable.allFields {
@@ -775,18 +762,15 @@ func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) {
continue
}
- if _, ok := path[f.IndirectType]; !ok {
- t.inlineFields(f, path)
+ if _, ok := seen[f.IndirectType]; !ok {
+ t.inlineFields(f, seen)
}
}
}
//------------------------------------------------------------------------------
-func (t *Table) Dialect() Dialect { return t.dialect }
-
-//------------------------------------------------------------------------------
-
+func (t *Table) Dialect() Dialect { return t.dialect }
func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) }
func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) }
@@ -845,6 +829,7 @@ func isKnownFieldOption(name string) bool {
"default",
"unique",
"soft_delete",
+ "scanonly",
"pk",
"autoincrement",
@@ -883,35 +868,35 @@ func parseRelationJoin(join string) ([]string, []string) {
//------------------------------------------------------------------------------
-func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error {
+func softDeleteFieldUpdater(field *Field) func(fv reflect.Value, tm time.Time) error {
typ := field.StructField.Type
switch typ {
case timeType:
- return func(fv reflect.Value) error {
+ return func(fv reflect.Value, tm time.Time) error {
ptr := fv.Addr().Interface().(*time.Time)
- *ptr = time.Now()
+ *ptr = tm
return nil
}
case nullTimeType:
- return func(fv reflect.Value) error {
+ return func(fv reflect.Value, tm time.Time) error {
ptr := fv.Addr().Interface().(*sql.NullTime)
- *ptr = sql.NullTime{Time: time.Now()}
+ *ptr = sql.NullTime{Time: tm}
return nil
}
case nullIntType:
- return func(fv reflect.Value) error {
+ return func(fv reflect.Value, tm time.Time) error {
ptr := fv.Addr().Interface().(*sql.NullInt64)
- *ptr = sql.NullInt64{Int64: time.Now().UnixNano()}
+ *ptr = sql.NullInt64{Int64: tm.UnixNano()}
return nil
}
}
switch field.IndirectType.Kind() {
case reflect.Int64:
- return func(fv reflect.Value) error {
+ return func(fv reflect.Value, tm time.Time) error {
ptr := fv.Addr().Interface().(*int64)
- *ptr = time.Now().UnixNano()
+ *ptr = tm.UnixNano()
return nil
}
case reflect.Ptr:
@@ -922,17 +907,16 @@ func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error {
switch typ { //nolint:gocritic
case timeType:
- return func(fv reflect.Value) error {
- now := time.Now()
- fv.Set(reflect.ValueOf(&now))
+ return func(fv reflect.Value, tm time.Time) error {
+ fv.Set(reflect.ValueOf(&tm))
return nil
}
}
switch typ.Kind() { //nolint:gocritic
case reflect.Int64:
- return func(fv reflect.Value) error {
- utime := time.Now().UnixNano()
+ return func(fv reflect.Value, tm time.Time) error {
+ utime := tm.UnixNano()
fv.Set(reflect.ValueOf(&utime))
return nil
}
@@ -941,8 +925,8 @@ func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error {
return softDeleteFieldUpdaterFallback(field)
}
-func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value) error {
- return func(fv reflect.Value) error {
- return field.ScanWithCheck(fv, time.Now())
+func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value, tm time.Time) error {
+ return func(fv reflect.Value, tm time.Time) error {
+ return field.ScanWithCheck(fv, tm)
}
}