summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/schema/tables.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema/tables.go')
-rw-r--r--vendor/github.com/uptrace/bun/schema/tables.go104
1 files changed, 32 insertions, 72 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/tables.go b/vendor/github.com/uptrace/bun/schema/tables.go
index 19aff8606..985093421 100644
--- a/vendor/github.com/uptrace/bun/schema/tables.go
+++ b/vendor/github.com/uptrace/bun/schema/tables.go
@@ -4,22 +4,24 @@ import (
"fmt"
"reflect"
"sync"
+
+ "github.com/puzpuzpuz/xsync/v3"
)
type Tables struct {
dialect Dialect
- tables sync.Map
- mu sync.RWMutex
- seen map[reflect.Type]*Table
- inProgress map[reflect.Type]*tableInProgress
+ mu sync.Mutex
+ tables *xsync.MapOf[reflect.Type, *Table]
+
+ inProgress map[reflect.Type]*Table
}
func NewTables(dialect Dialect) *Tables {
return &Tables{
dialect: dialect,
- seen: make(map[reflect.Type]*Table),
- inProgress: make(map[reflect.Type]*tableInProgress),
+ tables: xsync.NewMapOf[reflect.Type, *Table](),
+ inProgress: make(map[reflect.Type]*Table),
}
}
@@ -30,58 +32,26 @@ func (t *Tables) Register(models ...interface{}) {
}
func (t *Tables) Get(typ reflect.Type) *Table {
- return t.table(typ, false)
-}
-
-func (t *Tables) InProgress(typ reflect.Type) *Table {
- return t.table(typ, true)
-}
-
-func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table {
typ = indirectType(typ)
if typ.Kind() != reflect.Struct {
panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct))
}
if v, ok := t.tables.Load(typ); ok {
- return v.(*Table)
+ return v
}
t.mu.Lock()
+ defer t.mu.Unlock()
if v, ok := t.tables.Load(typ); ok {
- t.mu.Unlock()
- return v.(*Table)
- }
-
- var table *Table
-
- inProgress := t.inProgress[typ]
- if inProgress == nil {
- table = newTable(t.dialect, typ, t.seen, false)
- inProgress = newTableInProgress(table)
- t.inProgress[typ] = inProgress
- } else {
- table = inProgress.table
+ return v
}
- t.mu.Unlock()
-
- if allowInProgress {
- return table
- }
-
- if !inProgress.init() {
- return table
- }
-
- t.mu.Lock()
- delete(t.inProgress, typ)
- t.tables.Store(typ, table)
- t.mu.Unlock()
+ table := t.InProgress(typ)
+ table.initRelations()
t.dialect.OnTable(table)
-
for _, field := range table.FieldMap {
if field.UserSQLType == "" {
field.UserSQLType = field.DiscoveredSQLType
@@ -91,15 +61,27 @@ func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table {
}
}
+ t.tables.Store(typ, table)
+ return table
+}
+
+func (t *Tables) InProgress(typ reflect.Type) *Table {
+ if table, ok := t.inProgress[typ]; ok {
+ return table
+ }
+
+ table := new(Table)
+ t.inProgress[typ] = table
+ table.init(t.dialect, typ, false)
+
return table
}
func (t *Tables) ByModel(name string) *Table {
var found *Table
- t.tables.Range(func(key, value interface{}) bool {
- t := value.(*Table)
- if t.TypeName == name {
- found = t
+ t.tables.Range(func(typ reflect.Type, table *Table) bool {
+ if table.TypeName == name {
+ found = table
return false
}
return true
@@ -109,34 +91,12 @@ func (t *Tables) ByModel(name string) *Table {
func (t *Tables) ByName(name string) *Table {
var found *Table
- t.tables.Range(func(key, value interface{}) bool {
- t := value.(*Table)
- if t.Name == name {
- found = t
+ t.tables.Range(func(typ reflect.Type, table *Table) bool {
+ if table.Name == name {
+ found = table
return false
}
return true
})
return found
}
-
-type tableInProgress struct {
- table *Table
-
- initOnce sync.Once
-}
-
-func newTableInProgress(table *Table) *tableInProgress {
- return &tableInProgress{
- table: table,
- }
-}
-
-func (inp *tableInProgress) init() bool {
- var inited bool
- inp.initOnce.Do(func() {
- inp.table.init()
- inited = true
- })
- return inited
-}