diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema/table.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/schema/table.go | 48 |
1 files changed, 27 insertions, 21 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/table.go b/vendor/github.com/uptrace/bun/schema/table.go index 823d4c69f..88b8d8e25 100644 --- a/vendor/github.com/uptrace/bun/schema/table.go +++ b/vendor/github.com/uptrace/bun/schema/table.go @@ -203,7 +203,7 @@ func (t *Table) fieldByGoName(name string) *Field { func (t *Table) initFields() { t.Fields = make([]*Field, 0, t.Type.NumField()) t.FieldMap = make(map[string]*Field, t.Type.NumField()) - t.addFields(t.Type, nil) + t.addFields(t.Type, "", nil) if len(t.PKs) == 0 { for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} { @@ -230,7 +230,7 @@ func (t *Table) initFields() { } } -func (t *Table) addFields(typ reflect.Type, baseIndex []int) { +func (t *Table) addFields(typ reflect.Type, prefix string, index []int) { for i := 0; i < typ.NumField(); i++ { f := typ.Field(i) unexported := f.PkgPath != "" @@ -242,10 +242,6 @@ func (t *Table) addFields(typ reflect.Type, baseIndex []int) { continue } - // Make a copy so the slice is not shared between fields. - index := make([]int, len(baseIndex)) - copy(index, baseIndex) - if f.Anonymous { if f.Name == "BaseModel" && f.Type == baseModelType { if len(index) == 0 { @@ -258,7 +254,7 @@ func (t *Table) addFields(typ reflect.Type, baseIndex []int) { if fieldType.Kind() != reflect.Struct { continue } - t.addFields(fieldType, append(index, f.Index...)) + t.addFields(fieldType, "", withIndex(index, f.Index)) tag := tagparser.Parse(f.Tag.Get("bun")) if _, inherit := tag.Options["inherit"]; inherit { @@ -274,7 +270,7 @@ func (t *Table) addFields(typ reflect.Type, baseIndex []int) { continue } - if field := t.newField(f, index); field != nil { + if field := t.newField(f, prefix, index); field != nil { t.addField(field) } } @@ -315,10 +311,20 @@ func (t *Table) processBaseModelField(f reflect.StructField) { } //nolint -func (t *Table) newField(f reflect.StructField, index []int) *Field { - sqlName := internal.Underscore(f.Name) +func (t *Table) newField(f reflect.StructField, prefix string, index []int) *Field { tag := tagparser.Parse(f.Tag.Get("bun")) + if prefix, ok := tag.Option("embed"); ok { + fieldType := indirectType(f.Type) + if fieldType.Kind() != reflect.Struct { + panic(fmt.Errorf("bun: embed %s.%s: got %s, wanted reflect.Struct", + t.TypeName, f.Name, fieldType.Kind())) + } + t.addFields(fieldType, prefix, withIndex(index, f.Index)) + return nil + } + + sqlName := internal.Underscore(f.Name) if tag.Name != "" && tag.Name != sqlName { if isKnownFieldOption(tag.Name) { internal.Warn.Printf( @@ -328,10 +334,10 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field { } sqlName = tag.Name } - if s, ok := tag.Option("column"); ok { sqlName = s } + sqlName = prefix + sqlName for name := range tag.Options { if !isKnownFieldOption(name) { @@ -339,7 +345,7 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field { } } - index = append(index, f.Index...) + index = withIndex(index, f.Index) if field := t.fieldWithLock(sqlName); field != nil { if indexEqual(field.Index, index) { return field @@ -795,7 +801,7 @@ func (t *Table) inlineFields(field *Field, seen map[reflect.Type]struct{}) { f.GoName = field.GoName + "_" + f.GoName f.Name = field.Name + "__" + f.Name f.SQLName = t.quoteIdent(f.Name) - f.Index = appendNew(field.Index, f.Index...) + f.Index = withIndex(field.Index, f.Index) t.fieldsMapMu.Lock() if _, ok := t.FieldMap[f.Name]; !ok { @@ -834,7 +840,7 @@ func (t *Table) AppendNamedArg( fmter Formatter, b []byte, name string, strct reflect.Value, ) ([]byte, bool) { if field, ok := t.FieldMap[name]; ok { - return fmter.appendArg(b, field.Value(strct).Interface()), true + return field.AppendValue(fmter, b, strct), true } return b, false } @@ -853,13 +859,6 @@ func (t *Table) quoteIdent(s string) Safe { return Safe(NewFormatter(t.dialect).AppendIdent(nil, s)) } -func appendNew(dst []int, src ...int) []int { - cp := make([]int, len(dst)+len(src)) - copy(cp, dst) - copy(cp[len(dst):], src) - return cp -} - func isKnownTableOption(name string) bool { switch name { case "table", "alias", "select": @@ -991,3 +990,10 @@ func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value, tm time return field.ScanWithCheck(fv, tm) } } + +func withIndex(a, b []int) []int { + dest := make([]int, 0, len(a)+len(b)) + dest = append(dest, a...) + dest = append(dest, b...) + return dest +} |