diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema/table.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/schema/table.go | 86 |
1 files changed, 38 insertions, 48 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/table.go b/vendor/github.com/uptrace/bun/schema/table.go index 0a23156a2..c8e71e38f 100644 --- a/vendor/github.com/uptrace/bun/schema/table.go +++ b/vendor/github.com/uptrace/bun/schema/table.go @@ -74,16 +74,7 @@ type structField struct { Table *Table } -func newTable( - dialect Dialect, typ reflect.Type, seen map[reflect.Type]*Table, canAddr bool, -) *Table { - if table, ok := seen[typ]; ok { - return table - } - - table := new(Table) - seen[typ] = table - +func (table *Table) init(dialect Dialect, typ reflect.Type, canAddr bool) { table.dialect = dialect table.Type = typ table.ZeroValue = reflect.New(table.Type).Elem() @@ -97,7 +88,7 @@ func newTable( table.Fields = make([]*Field, 0, typ.NumField()) table.FieldMap = make(map[string]*Field, typ.NumField()) - table.processFields(typ, seen, canAddr) + table.processFields(typ, canAddr) hooks := []struct { typ reflect.Type @@ -109,28 +100,15 @@ func newTable( {afterScanRowHookType, afterScanRowHookFlag}, } - typ = reflect.PtrTo(table.Type) + typ = reflect.PointerTo(table.Type) for _, hook := range hooks { if typ.Implements(hook.typ) { table.flags = table.flags.Set(hook.flag) } } - - return table -} - -func (t *Table) init() { - for _, field := range t.relFields { - t.processRelation(field) - } - t.relFields = nil } -func (t *Table) processFields( - typ reflect.Type, - seen map[reflect.Type]*Table, - canAddr bool, -) { +func (t *Table) processFields(typ reflect.Type, canAddr bool) { type embeddedField struct { prefix string index []int @@ -172,7 +150,7 @@ func (t *Table) processFields( continue } - subtable := newTable(t.dialect, sfType, seen, canAddr) + subtable := t.dialect.Tables().InProgress(sfType) for _, subfield := range subtable.allFields { embedded = append(embedded, embeddedField{ @@ -206,7 +184,7 @@ func (t *Table) processFields( t.TypeName, sf.Name, fieldType.Kind())) } - subtable := newTable(t.dialect, fieldType, seen, canAddr) + subtable := t.dialect.Tables().InProgress(fieldType) for _, subfield := range subtable.allFields { embedded = append(embedded, embeddedField{ prefix: prefix, @@ -229,7 +207,7 @@ func (t *Table) processFields( } t.StructMap[field.Name] = &structField{ Index: field.Index, - Table: newTable(t.dialect, field.IndirectType, seen, canAddr), + Table: t.dialect.Tables().InProgress(field.IndirectType), } } } @@ -423,6 +401,10 @@ func (t *Table) newField(sf reflect.StructField, tag tagparser.Tag) *Field { sqlName = tag.Name } + if s, ok := tag.Option("column"); ok { + sqlName = s + } + for name := range tag.Options { if !isKnownFieldOption(name) { internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, sf.Name, name) @@ -490,6 +472,13 @@ func (t *Table) newField(sf reflect.StructField, tag tagparser.Tag) *Field { //--------------------------------------------------------------------------------------- +func (t *Table) initRelations() { + for _, field := range t.relFields { + t.processRelation(field) + } + t.relFields = nil +} + func (t *Table) processRelation(field *Field) { if rel, ok := field.Tag.Option("rel"); ok { t.initRelation(field, rel) @@ -577,7 +566,7 @@ func (t *Table) belongsToRelation(field *Field) *Relation { joinColumn := joinColumns[i] if f := t.FieldMap[baseColumn]; f != nil { - rel.BaseFields = append(rel.BaseFields, f) + rel.BasePKs = append(rel.BasePKs, f) } else { panic(fmt.Errorf( "bun: %s belongs-to %s: %s must have column %s", @@ -586,7 +575,7 @@ func (t *Table) belongsToRelation(field *Field) *Relation { } if f := joinTable.FieldMap[joinColumn]; f != nil { - rel.JoinFields = append(rel.JoinFields, f) + rel.JoinPKs = append(rel.JoinPKs, f) } else { panic(fmt.Errorf( "bun: %s belongs-to %s: %s must have column %s", @@ -597,17 +586,17 @@ func (t *Table) belongsToRelation(field *Field) *Relation { return rel } - rel.JoinFields = joinTable.PKs + rel.JoinPKs = joinTable.PKs fkPrefix := internal.Underscore(field.GoName) + "_" for _, joinPK := range joinTable.PKs { fkName := fkPrefix + joinPK.Name if fk := t.FieldMap[fkName]; fk != nil { - rel.BaseFields = append(rel.BaseFields, fk) + rel.BasePKs = append(rel.BasePKs, fk) continue } if fk := t.FieldMap[joinPK.Name]; fk != nil { - rel.BaseFields = append(rel.BaseFields, fk) + rel.BasePKs = append(rel.BasePKs, fk) continue } @@ -640,7 +629,7 @@ func (t *Table) hasOneRelation(field *Field) *Relation { baseColumns, joinColumns := parseRelationJoin(join) for i, baseColumn := range baseColumns { if f := t.FieldMap[baseColumn]; f != nil { - rel.BaseFields = append(rel.BaseFields, f) + rel.BasePKs = append(rel.BasePKs, f) } else { panic(fmt.Errorf( "bun: %s has-one %s: %s must have column %s", @@ -650,7 +639,7 @@ func (t *Table) hasOneRelation(field *Field) *Relation { joinColumn := joinColumns[i] if f := joinTable.FieldMap[joinColumn]; f != nil { - rel.JoinFields = append(rel.JoinFields, f) + rel.JoinPKs = append(rel.JoinPKs, f) } else { panic(fmt.Errorf( "bun: %s has-one %s: %s must have column %s", @@ -661,17 +650,17 @@ func (t *Table) hasOneRelation(field *Field) *Relation { return rel } - rel.BaseFields = t.PKs + rel.BasePKs = t.PKs fkPrefix := internal.Underscore(t.ModelName) + "_" for _, pk := range t.PKs { fkName := fkPrefix + pk.Name if f := joinTable.FieldMap[fkName]; f != nil { - rel.JoinFields = append(rel.JoinFields, f) + rel.JoinPKs = append(rel.JoinPKs, f) continue } if f := joinTable.FieldMap[pk.Name]; f != nil { - rel.JoinFields = append(rel.JoinFields, f) + rel.JoinPKs = append(rel.JoinPKs, f) continue } @@ -720,7 +709,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation { } if f := t.FieldMap[baseColumn]; f != nil { - rel.BaseFields = append(rel.BaseFields, f) + rel.BasePKs = append(rel.BasePKs, f) } else { panic(fmt.Errorf( "bun: %s has-many %s: %s must have column %s", @@ -729,7 +718,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation { } if f := joinTable.FieldMap[joinColumn]; f != nil { - rel.JoinFields = append(rel.JoinFields, f) + rel.JoinPKs = append(rel.JoinPKs, f) } else { panic(fmt.Errorf( "bun: %s has-many %s: %s must have column %s", @@ -738,7 +727,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation { } } } else { - rel.BaseFields = t.PKs + rel.BasePKs = t.PKs fkPrefix := internal.Underscore(t.ModelName) + "_" if isPolymorphic { polymorphicColumn = fkPrefix + "type" @@ -747,12 +736,12 @@ func (t *Table) hasManyRelation(field *Field) *Relation { for _, pk := range t.PKs { joinColumn := fkPrefix + pk.Name if fk := joinTable.FieldMap[joinColumn]; fk != nil { - rel.JoinFields = append(rel.JoinFields, fk) + rel.JoinPKs = append(rel.JoinPKs, fk) continue } if fk := joinTable.FieldMap[pk.Name]; fk != nil { - rel.JoinFields = append(rel.JoinFields, fk) + rel.JoinPKs = append(rel.JoinPKs, fk) continue } @@ -852,12 +841,12 @@ func (t *Table) m2mRelation(field *Field) *Relation { } leftRel := m2mTable.belongsToRelation(leftField) - rel.BaseFields = leftRel.JoinFields - rel.M2MBaseFields = leftRel.BaseFields + rel.BasePKs = leftRel.JoinPKs + rel.M2MBasePKs = leftRel.BasePKs rightRel := m2mTable.belongsToRelation(rightField) - rel.JoinFields = rightRel.JoinFields - rel.M2MJoinFields = rightRel.BaseFields + rel.JoinPKs = rightRel.JoinPKs + rel.M2MJoinPKs = rightRel.BasePKs return rel } @@ -918,6 +907,7 @@ func isKnownFieldOption(name string) bool { "array", "hstore", "composite", + "multirange", "json_use_number", "msgpack", "notnull", |