diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema/table.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/schema/table.go | 512 |
1 files changed, 263 insertions, 249 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/table.go b/vendor/github.com/uptrace/bun/schema/table.go index e6986f109..0a23156a2 100644 --- a/vendor/github.com/uptrace/bun/schema/table.go +++ b/vendor/github.com/uptrace/bun/schema/table.go @@ -5,7 +5,6 @@ import ( "fmt" "reflect" "strings" - "sync" "time" "github.com/jinzhu/inflection" @@ -52,12 +51,14 @@ type Table struct { Alias string SQLAlias Safe + allFields []*Field // all fields including scanonly Fields []*Field // PKs + DataFields PKs []*Field DataFields []*Field + relFields []*Field - fieldsMapMu sync.RWMutex - FieldMap map[string]*Field + FieldMap map[string]*Field + StructMap map[string]*structField Relations map[string]*Relation Unique map[string][]*Field @@ -65,23 +66,38 @@ type Table struct { SoftDeleteField *Field UpdateSoftDeleteField func(fv reflect.Value, tm time.Time) error - allFields []*Field // read only - flags internal.Flag } -func newTable(dialect Dialect, typ reflect.Type) *Table { - t := new(Table) - t.dialect = dialect - t.Type = typ - t.ZeroValue = reflect.New(t.Type).Elem() - t.ZeroIface = reflect.New(t.Type).Interface() - t.TypeName = internal.ToExported(t.Type.Name()) - t.ModelName = internal.Underscore(t.Type.Name()) - tableName := tableNameInflector(t.ModelName) - t.setName(tableName) - t.Alias = t.ModelName - t.SQLAlias = t.quoteIdent(t.ModelName) +type structField struct { + Index []int + Table *Table +} + +func newTable( + dialect Dialect, typ reflect.Type, seen map[reflect.Type]*Table, canAddr bool, +) *Table { + if table, ok := seen[typ]; ok { + return table + } + + table := new(Table) + seen[typ] = table + + table.dialect = dialect + table.Type = typ + table.ZeroValue = reflect.New(table.Type).Elem() + table.ZeroIface = reflect.New(table.Type).Interface() + table.TypeName = internal.ToExported(table.Type.Name()) + table.ModelName = internal.Underscore(table.Type.Name()) + tableName := tableNameInflector(table.ModelName) + table.setName(tableName) + table.Alias = table.ModelName + table.SQLAlias = table.quoteIdent(table.ModelName) + + table.Fields = make([]*Field, 0, typ.NumField()) + table.FieldMap = make(map[string]*Field, typ.NumField()) + table.processFields(typ, seen, canAddr) hooks := []struct { typ reflect.Type @@ -89,45 +105,168 @@ func newTable(dialect Dialect, typ reflect.Type) *Table { }{ {beforeAppendModelHookType, beforeAppendModelHookFlag}, - {beforeScanHookType, beforeScanHookFlag}, - {afterScanHookType, afterScanHookFlag}, - {beforeScanRowHookType, beforeScanRowHookFlag}, {afterScanRowHookType, afterScanRowHookFlag}, } - typ = reflect.PtrTo(t.Type) + typ = reflect.PtrTo(table.Type) for _, hook := range hooks { if typ.Implements(hook.typ) { - t.flags = t.flags.Set(hook.flag) + table.flags = table.flags.Set(hook.flag) } } - // Deprecated. - deprecatedHooks := []struct { - typ reflect.Type - flag internal.Flag - msg string - }{ - {beforeScanHookType, beforeScanHookFlag, "rename BeforeScan hook to BeforeScanRow"}, - {afterScanHookType, afterScanHookFlag, "rename AfterScan hook to AfterScanRow"}, + return table +} + +func (t *Table) init() { + for _, field := range t.relFields { + t.processRelation(field) } - for _, hook := range deprecatedHooks { - if typ.Implements(hook.typ) { - internal.Deprecated.Printf("%s: %s", t.TypeName, hook.msg) - t.flags = t.flags.Set(hook.flag) + t.relFields = nil +} + +func (t *Table) processFields( + typ reflect.Type, + seen map[reflect.Type]*Table, + canAddr bool, +) { + type embeddedField struct { + prefix string + index []int + unexported bool + subtable *Table + subfield *Field + } + + names := make(map[string]struct{}) + embedded := make([]embeddedField, 0, 10) + + for i, n := 0, typ.NumField(); i < n; i++ { + sf := typ.Field(i) + unexported := sf.PkgPath != "" + + tagstr := sf.Tag.Get("bun") + if tagstr == "-" { + names[sf.Name] = struct{}{} + continue + } + tag := tagparser.Parse(tagstr) + + if unexported && !sf.Anonymous { // unexported + continue + } + + if sf.Anonymous { + if sf.Name == "BaseModel" && sf.Type == baseModelType { + t.processBaseModelField(sf) + continue + } + + sfType := sf.Type + if sfType.Kind() == reflect.Ptr { + sfType = sfType.Elem() + } + + if sfType.Kind() != reflect.Struct { // ignore unexported non-struct types + continue + } + + subtable := newTable(t.dialect, sfType, seen, canAddr) + + for _, subfield := range subtable.allFields { + embedded = append(embedded, embeddedField{ + index: sf.Index, + unexported: unexported, + subtable: subtable, + subfield: subfield, + }) + } + + if tagstr != "" { + tag := tagparser.Parse(tagstr) + if tag.HasOption("inherit") || tag.HasOption("extend") { + t.Name = subtable.Name + t.TypeName = subtable.TypeName + t.SQLName = subtable.SQLName + t.SQLNameForSelects = subtable.SQLNameForSelects + t.Alias = subtable.Alias + t.SQLAlias = subtable.SQLAlias + t.ModelName = subtable.ModelName + } + } + + continue + } + + if prefix, ok := tag.Option("embed"); ok { + fieldType := indirectType(sf.Type) + if fieldType.Kind() != reflect.Struct { + panic(fmt.Errorf("bun: embed %s.%s: got %s, wanted reflect.Struct", + t.TypeName, sf.Name, fieldType.Kind())) + } + + subtable := newTable(t.dialect, fieldType, seen, canAddr) + for _, subfield := range subtable.allFields { + embedded = append(embedded, embeddedField{ + prefix: prefix, + index: sf.Index, + unexported: unexported, + subtable: subtable, + subfield: subfield, + }) + } + continue + } + + field := t.newField(sf, tag) + t.addField(field) + names[field.Name] = struct{}{} + + if field.IndirectType.Kind() == reflect.Struct { + if t.StructMap == nil { + t.StructMap = make(map[string]*structField) + } + t.StructMap[field.Name] = &structField{ + Index: field.Index, + Table: newTable(t.dialect, field.IndirectType, seen, canAddr), + } } } - return t -} + // Only unambiguous embedded fields must be serialized. + ambiguousNames := make(map[string]int) + ambiguousTags := make(map[string]int) -func (t *Table) init1() { - t.initFields() -} + // Embedded types can never override a field that was already present at + // the top-level. + for name := range names { + ambiguousNames[name]++ + ambiguousTags[name]++ + } + + for _, f := range embedded { + ambiguousNames[f.prefix+f.subfield.Name]++ + if !f.subfield.Tag.IsZero() { + ambiguousTags[f.prefix+f.subfield.Name]++ + } + } -func (t *Table) init2() { - t.initRelations() + for _, embfield := range embedded { + subfield := embfield.subfield.Clone() + + if ambiguousNames[subfield.Name] > 1 && + !(!subfield.Tag.IsZero() && ambiguousTags[subfield.Name] == 1) { + continue // ambiguous embedded field + } + + subfield.Index = makeIndex(embfield.index, subfield.Index) + if embfield.prefix != "" { + subfield.Name = embfield.prefix + subfield.Name + subfield.SQLName = t.quoteIdent(subfield.Name) + } + t.addField(subfield) + } } func (t *Table) setName(name string) { @@ -152,30 +291,67 @@ func (t *Table) CheckPKs() error { } func (t *Table) addField(field *Field) { + t.allFields = append(t.allFields, field) + + if field.Tag.HasOption("rel") || field.Tag.HasOption("m2m") { + t.relFields = append(t.relFields, field) + return + } + + if field.Tag.HasOption("join") { + internal.Warn.Printf( + `%s.%s "join" option must come together with "rel" option`, + t.TypeName, field.GoName, + ) + } + + t.FieldMap[field.Name] = field + if altName, ok := field.Tag.Option("alt"); ok { + t.FieldMap[altName] = field + } + + if field.Tag.HasOption("scanonly") { + return + } + + if _, ok := field.Tag.Options["soft_delete"]; ok { + t.SoftDeleteField = field + t.UpdateSoftDeleteField = softDeleteFieldUpdater(field) + } + t.Fields = append(t.Fields, field) if field.IsPK { t.PKs = append(t.PKs, field) } else { t.DataFields = append(t.DataFields, field) } - t.FieldMap[field.Name] = field } -func (t *Table) removeField(field *Field) { - t.Fields = removeField(t.Fields, field) - if field.IsPK { - t.PKs = removeField(t.PKs, field) - } else { - t.DataFields = removeField(t.DataFields, field) +func (t *Table) LookupField(name string) *Field { + if field, ok := t.FieldMap[name]; ok { + return field } - delete(t.FieldMap, field.Name) -} -func (t *Table) fieldWithLock(name string) *Field { - t.fieldsMapMu.RLock() - field := t.FieldMap[name] - t.fieldsMapMu.RUnlock() - return field + table := t + var index []int + for { + structName, columnName, ok := strings.Cut(name, "__") + if !ok { + field, ok := table.FieldMap[name] + if !ok { + return nil + } + return field.WithIndex(index) + } + name = columnName + + strct := table.StructMap[structName] + if strct == nil { + return nil + } + table = strct.Table + index = append(index, strct.Index...) + } } func (t *Table) HasField(name string) bool { @@ -200,59 +376,6 @@ func (t *Table) fieldByGoName(name string) *Field { return nil } -func (t *Table) initFields() { - t.Fields = make([]*Field, 0, t.Type.NumField()) - t.FieldMap = make(map[string]*Field, t.Type.NumField()) - t.addFields(t.Type, "", nil) -} - -func (t *Table) addFields(typ reflect.Type, prefix string, index []int) { - for i := 0; i < typ.NumField(); i++ { - f := typ.Field(i) - unexported := f.PkgPath != "" - - if unexported && !f.Anonymous { // unexported - continue - } - if f.Tag.Get("bun") == "-" { - continue - } - - if f.Anonymous { - if f.Name == "BaseModel" && f.Type == baseModelType { - if len(index) == 0 { - t.processBaseModelField(f) - } - continue - } - - // If field is an embedded struct, add each field of the embedded struct. - fieldType := indirectType(f.Type) - if fieldType.Kind() == reflect.Struct { - t.addFields(fieldType, "", withIndex(index, f.Index)) - - tag := tagparser.Parse(f.Tag.Get("bun")) - if tag.HasOption("inherit") || tag.HasOption("extend") { - embeddedTable := t.dialect.Tables().Ref(fieldType) - t.TypeName = embeddedTable.TypeName - t.SQLName = embeddedTable.SQLName - t.SQLNameForSelects = embeddedTable.SQLNameForSelects - t.Alias = embeddedTable.Alias - t.SQLAlias = embeddedTable.SQLAlias - t.ModelName = embeddedTable.ModelName - } - continue - } - } - - // If field is not a struct, add it. - // This will also add any embedded non-struct type as a field. - if field := t.newField(f, prefix, index); field != nil { - t.addField(field) - } - } -} - func (t *Table) processBaseModelField(f reflect.StructField) { tag := tagparser.Parse(f.Tag.Get("bun")) @@ -288,58 +411,34 @@ func (t *Table) processBaseModelField(f reflect.StructField) { } // nolint -func (t *Table) newField(f reflect.StructField, prefix string, index []int) *Field { - tag := tagparser.Parse(f.Tag.Get("bun")) - - if nextPrefix, ok := tag.Option("embed"); ok { - fieldType := indirectType(f.Type) - if fieldType.Kind() != reflect.Struct { - panic(fmt.Errorf("bun: embed %s.%s: got %s, wanted reflect.Struct", - t.TypeName, f.Name, fieldType.Kind())) - } - t.addFields(fieldType, prefix+nextPrefix, withIndex(index, f.Index)) - return nil - } - - sqlName := internal.Underscore(f.Name) +func (t *Table) newField(sf reflect.StructField, tag tagparser.Tag) *Field { + sqlName := internal.Underscore(sf.Name) if tag.Name != "" && tag.Name != sqlName { if isKnownFieldOption(tag.Name) { internal.Warn.Printf( "%s.%s tag name %q is also an option name, is it a mistake? Try column:%s.", - t.TypeName, f.Name, tag.Name, tag.Name, + t.TypeName, sf.Name, tag.Name, tag.Name, ) } sqlName = tag.Name } - if s, ok := tag.Option("column"); ok { - sqlName = s - } - sqlName = prefix + sqlName for name := range tag.Options { if !isKnownFieldOption(name) { - internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) - } - } - - index = withIndex(index, f.Index) - if field := t.fieldWithLock(sqlName); field != nil { - if indexEqual(field.Index, index) { - return field + internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, sf.Name, name) } - t.removeField(field) } field := &Field{ - StructField: f, - IsPtr: f.Type.Kind() == reflect.Ptr, + StructField: sf, + IsPtr: sf.Type.Kind() == reflect.Ptr, Tag: tag, - IndirectType: indirectType(f.Type), - Index: index, + IndirectType: indirectType(sf.Type), + Index: sf.Index, Name: sqlName, - GoName: f.Name, + GoName: sf.Name, SQLName: t.quoteIdent(sqlName), } @@ -386,63 +485,21 @@ func (t *Table) newField(f reflect.StructField, prefix string, index []int) *Fie field.Scan = FieldScanner(t.dialect, field) field.IsZero = zeroChecker(field.StructField.Type) - if v, ok := tag.Option("alt"); ok { - t.FieldMap[v] = field - } - - t.allFields = append(t.allFields, field) - if tag.HasOption("scanonly") { - t.FieldMap[field.Name] = field - if field.IndirectType.Kind() == reflect.Struct { - t.inlineFields(field, nil) - } - return nil - } - - if _, ok := tag.Options["soft_delete"]; ok { - t.SoftDeleteField = field - t.UpdateSoftDeleteField = softDeleteFieldUpdater(field) - } - return field } //--------------------------------------------------------------------------------------- -func (t *Table) initRelations() { - for i := 0; i < len(t.Fields); { - f := t.Fields[i] - if t.tryRelation(f) { - t.Fields = removeField(t.Fields, f) - t.DataFields = removeField(t.DataFields, f) - } else { - i++ - } - - if f.IndirectType.Kind() == reflect.Struct { - t.inlineFields(f, nil) - } - } -} - -func (t *Table) tryRelation(field *Field) bool { +func (t *Table) processRelation(field *Field) { if rel, ok := field.Tag.Option("rel"); ok { t.initRelation(field, rel) - return true + return } if field.Tag.HasOption("m2m") { t.addRelation(t.m2mRelation(field)) - return true - } - - if field.Tag.HasOption("join") { - internal.Warn.Printf( - `%s.%s "join" option must come together with "rel" option`, - t.TypeName, field.GoName, - ) + return } - - return false + panic("not reached") } func (t *Table) initRelation(field *Field, rel string) { @@ -470,7 +527,7 @@ func (t *Table) addRelation(rel *Relation) { } func (t *Table) belongsToRelation(field *Field) *Relation { - joinTable := t.dialect.Tables().Ref(field.IndirectType) + joinTable := t.dialect.Tables().InProgress(field.IndirectType) if err := joinTable.CheckPKs(); err != nil { panic(err) } @@ -519,7 +576,7 @@ func (t *Table) belongsToRelation(field *Field) *Relation { for i, baseColumn := range baseColumns { joinColumn := joinColumns[i] - if f := t.fieldWithLock(baseColumn); f != nil { + if f := t.FieldMap[baseColumn]; f != nil { rel.BaseFields = append(rel.BaseFields, f) } else { panic(fmt.Errorf( @@ -528,7 +585,7 @@ func (t *Table) belongsToRelation(field *Field) *Relation { )) } - if f := joinTable.fieldWithLock(joinColumn); f != nil { + if f := joinTable.FieldMap[joinColumn]; f != nil { rel.JoinFields = append(rel.JoinFields, f) } else { panic(fmt.Errorf( @@ -544,12 +601,12 @@ func (t *Table) belongsToRelation(field *Field) *Relation { fkPrefix := internal.Underscore(field.GoName) + "_" for _, joinPK := range joinTable.PKs { fkName := fkPrefix + joinPK.Name - if fk := t.fieldWithLock(fkName); fk != nil { + if fk := t.FieldMap[fkName]; fk != nil { rel.BaseFields = append(rel.BaseFields, fk) continue } - if fk := t.fieldWithLock(joinPK.Name); fk != nil { + if fk := t.FieldMap[joinPK.Name]; fk != nil { rel.BaseFields = append(rel.BaseFields, fk) continue } @@ -568,7 +625,7 @@ func (t *Table) hasOneRelation(field *Field) *Relation { panic(err) } - joinTable := t.dialect.Tables().Ref(field.IndirectType) + joinTable := t.dialect.Tables().InProgress(field.IndirectType) rel := &Relation{ Type: HasOneRelation, Field: field, @@ -582,7 +639,7 @@ func (t *Table) hasOneRelation(field *Field) *Relation { if join, ok := field.Tag.Options["join"]; ok { baseColumns, joinColumns := parseRelationJoin(join) for i, baseColumn := range baseColumns { - if f := t.fieldWithLock(baseColumn); f != nil { + if f := t.FieldMap[baseColumn]; f != nil { rel.BaseFields = append(rel.BaseFields, f) } else { panic(fmt.Errorf( @@ -592,7 +649,7 @@ func (t *Table) hasOneRelation(field *Field) *Relation { } joinColumn := joinColumns[i] - if f := joinTable.fieldWithLock(joinColumn); f != nil { + if f := joinTable.FieldMap[joinColumn]; f != nil { rel.JoinFields = append(rel.JoinFields, f) } else { panic(fmt.Errorf( @@ -608,12 +665,12 @@ func (t *Table) hasOneRelation(field *Field) *Relation { fkPrefix := internal.Underscore(t.ModelName) + "_" for _, pk := range t.PKs { fkName := fkPrefix + pk.Name - if f := joinTable.fieldWithLock(fkName); f != nil { + if f := joinTable.FieldMap[fkName]; f != nil { rel.JoinFields = append(rel.JoinFields, f) continue } - if f := joinTable.fieldWithLock(pk.Name); f != nil { + if f := joinTable.FieldMap[pk.Name]; f != nil { rel.JoinFields = append(rel.JoinFields, f) continue } @@ -638,7 +695,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation { )) } - joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem())) + joinTable := t.dialect.Tables().InProgress(indirectType(field.IndirectType.Elem())) polymorphicValue, isPolymorphic := field.Tag.Option("polymorphic") rel := &Relation{ Type: HasManyRelation, @@ -662,7 +719,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation { continue } - if f := t.fieldWithLock(baseColumn); f != nil { + if f := t.FieldMap[baseColumn]; f != nil { rel.BaseFields = append(rel.BaseFields, f) } else { panic(fmt.Errorf( @@ -671,7 +728,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation { )) } - if f := joinTable.fieldWithLock(joinColumn); f != nil { + if f := joinTable.FieldMap[joinColumn]; f != nil { rel.JoinFields = append(rel.JoinFields, f) } else { panic(fmt.Errorf( @@ -689,12 +746,12 @@ func (t *Table) hasManyRelation(field *Field) *Relation { for _, pk := range t.PKs { joinColumn := fkPrefix + pk.Name - if fk := joinTable.fieldWithLock(joinColumn); fk != nil { + if fk := joinTable.FieldMap[joinColumn]; fk != nil { rel.JoinFields = append(rel.JoinFields, fk) continue } - if fk := joinTable.fieldWithLock(pk.Name); fk != nil { + if fk := joinTable.FieldMap[pk.Name]; fk != nil { rel.JoinFields = append(rel.JoinFields, fk) continue } @@ -708,7 +765,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation { } if isPolymorphic { - rel.PolymorphicField = joinTable.fieldWithLock(polymorphicColumn) + rel.PolymorphicField = joinTable.FieldMap[polymorphicColumn] if rel.PolymorphicField == nil { panic(fmt.Errorf( "bun: %s has-many %s: %s must have polymorphic column %s", @@ -732,7 +789,7 @@ func (t *Table) m2mRelation(field *Field) *Relation { t.TypeName, field.GoName, field.IndirectType.Kind(), )) } - joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem())) + joinTable := t.dialect.Tables().InProgress(indirectType(field.IndirectType.Elem())) if err := t.CheckPKs(); err != nil { panic(err) @@ -805,40 +862,6 @@ func (t *Table) m2mRelation(field *Field) *Relation { return rel } -func (t *Table) inlineFields(field *Field, seen map[reflect.Type]struct{}) { - if seen == nil { - seen = map[reflect.Type]struct{}{t.Type: {}} - } - - if _, ok := seen[field.IndirectType]; ok { - return - } - seen[field.IndirectType] = struct{}{} - - joinTable := t.dialect.Tables().Ref(field.IndirectType) - for _, f := range joinTable.allFields { - f = f.Clone() - f.GoName = field.GoName + "_" + f.GoName - f.Name = field.Name + "__" + f.Name - f.SQLName = t.quoteIdent(f.Name) - f.Index = withIndex(field.Index, f.Index) - - t.fieldsMapMu.Lock() - if _, ok := t.FieldMap[f.Name]; !ok { - t.FieldMap[f.Name] = f - } - t.fieldsMapMu.Unlock() - - if f.IndirectType.Kind() != reflect.Struct { - continue - } - - if _, ok := seen[f.IndirectType]; !ok { - t.inlineFields(f, seen) - } - } -} - //------------------------------------------------------------------------------ func (t *Table) Dialect() Dialect { return t.dialect } @@ -890,7 +913,7 @@ func isKnownTableOption(name string) bool { func isKnownFieldOption(name string) bool { switch name { case "column", - "alias", + "alt", "type", "array", "hstore", @@ -931,15 +954,6 @@ func isKnownFKRule(name string) bool { return false } -func removeField(fields []*Field, field *Field) []*Field { - for i, f := range fields { - if f == field { - return append(fields[:i], fields[i+1:]...) - } - } - return fields -} - func parseRelationJoin(join []string) ([]string, []string) { var ss []string if len(join) == 1 { @@ -1026,7 +1040,7 @@ func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value, tm time } } -func withIndex(a, b []int) []int { +func makeIndex(a, b []int) []int { dest := make([]int, 0, len(a)+len(b)) dest = append(dest, a...) dest = append(dest, b...) |