diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema/tables.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/schema/tables.go | 104 |
1 files changed, 32 insertions, 72 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/tables.go b/vendor/github.com/uptrace/bun/schema/tables.go index 19aff8606..985093421 100644 --- a/vendor/github.com/uptrace/bun/schema/tables.go +++ b/vendor/github.com/uptrace/bun/schema/tables.go @@ -4,22 +4,24 @@ import ( "fmt" "reflect" "sync" + + "github.com/puzpuzpuz/xsync/v3" ) type Tables struct { dialect Dialect - tables sync.Map - mu sync.RWMutex - seen map[reflect.Type]*Table - inProgress map[reflect.Type]*tableInProgress + mu sync.Mutex + tables *xsync.MapOf[reflect.Type, *Table] + + inProgress map[reflect.Type]*Table } func NewTables(dialect Dialect) *Tables { return &Tables{ dialect: dialect, - seen: make(map[reflect.Type]*Table), - inProgress: make(map[reflect.Type]*tableInProgress), + tables: xsync.NewMapOf[reflect.Type, *Table](), + inProgress: make(map[reflect.Type]*Table), } } @@ -30,58 +32,26 @@ func (t *Tables) Register(models ...interface{}) { } func (t *Tables) Get(typ reflect.Type) *Table { - return t.table(typ, false) -} - -func (t *Tables) InProgress(typ reflect.Type) *Table { - return t.table(typ, true) -} - -func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table { typ = indirectType(typ) if typ.Kind() != reflect.Struct { panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct)) } if v, ok := t.tables.Load(typ); ok { - return v.(*Table) + return v } t.mu.Lock() + defer t.mu.Unlock() if v, ok := t.tables.Load(typ); ok { - t.mu.Unlock() - return v.(*Table) - } - - var table *Table - - inProgress := t.inProgress[typ] - if inProgress == nil { - table = newTable(t.dialect, typ, t.seen, false) - inProgress = newTableInProgress(table) - t.inProgress[typ] = inProgress - } else { - table = inProgress.table + return v } - t.mu.Unlock() - - if allowInProgress { - return table - } - - if !inProgress.init() { - return table - } - - t.mu.Lock() - delete(t.inProgress, typ) - t.tables.Store(typ, table) - t.mu.Unlock() + table := t.InProgress(typ) + table.initRelations() t.dialect.OnTable(table) - for _, field := range table.FieldMap { if field.UserSQLType == "" { field.UserSQLType = field.DiscoveredSQLType @@ -91,15 +61,27 @@ func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table { } } + t.tables.Store(typ, table) + return table +} + +func (t *Tables) InProgress(typ reflect.Type) *Table { + if table, ok := t.inProgress[typ]; ok { + return table + } + + table := new(Table) + t.inProgress[typ] = table + table.init(t.dialect, typ, false) + return table } func (t *Tables) ByModel(name string) *Table { var found *Table - t.tables.Range(func(key, value interface{}) bool { - t := value.(*Table) - if t.TypeName == name { - found = t + t.tables.Range(func(typ reflect.Type, table *Table) bool { + if table.TypeName == name { + found = table return false } return true @@ -109,34 +91,12 @@ func (t *Tables) ByModel(name string) *Table { func (t *Tables) ByName(name string) *Table { var found *Table - t.tables.Range(func(key, value interface{}) bool { - t := value.(*Table) - if t.Name == name { - found = t + t.tables.Range(func(typ reflect.Type, table *Table) bool { + if table.Name == name { + found = table return false } return true }) return found } - -type tableInProgress struct { - table *Table - - initOnce sync.Once -} - -func newTableInProgress(table *Table) *tableInProgress { - return &tableInProgress{ - table: table, - } -} - -func (inp *tableInProgress) init() bool { - var inited bool - inp.initOnce.Do(func() { - inp.table.init() - inited = true - }) - return inited -} |