diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/schema/tables.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/schema/tables.go | 148 |
1 files changed, 148 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/bun/schema/tables.go b/vendor/github.com/uptrace/bun/schema/tables.go new file mode 100644 index 000000000..d82d08f59 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/tables.go @@ -0,0 +1,148 @@ +package schema + +import ( + "fmt" + "reflect" + "sync" +) + +type tableInProgress struct { + table *Table + + init1Once sync.Once + init2Once sync.Once +} + +func newTableInProgress(table *Table) *tableInProgress { + return &tableInProgress{ + table: table, + } +} + +func (inp *tableInProgress) init1() bool { + var inited bool + inp.init1Once.Do(func() { + inp.table.init1() + inited = true + }) + return inited +} + +func (inp *tableInProgress) init2() bool { + var inited bool + inp.init2Once.Do(func() { + inp.table.init2() + inited = true + }) + return inited +} + +type Tables struct { + dialect Dialect + tables sync.Map + + mu sync.RWMutex + inProgress map[reflect.Type]*tableInProgress +} + +func NewTables(dialect Dialect) *Tables { + return &Tables{ + dialect: dialect, + inProgress: make(map[reflect.Type]*tableInProgress), + } +} + +func (t *Tables) Register(models ...interface{}) { + for _, model := range models { + _ = t.Get(reflect.TypeOf(model).Elem()) + } +} + +func (t *Tables) Get(typ reflect.Type) *Table { + return t.table(typ, false) +} + +func (t *Tables) Ref(typ reflect.Type) *Table { + return t.table(typ, true) +} + +func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table { + if typ.Kind() != reflect.Struct { + panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct)) + } + + if v, ok := t.tables.Load(typ); ok { + return v.(*Table) + } + + t.mu.Lock() + + if v, ok := t.tables.Load(typ); ok { + t.mu.Unlock() + return v.(*Table) + } + + var table *Table + + inProgress := t.inProgress[typ] + if inProgress == nil { + table = newTable(t.dialect, typ) + inProgress = newTableInProgress(table) + t.inProgress[typ] = inProgress + } else { + table = inProgress.table + } + + t.mu.Unlock() + + inProgress.init1() + if allowInProgress { + return table + } + + if inProgress.init2() { + t.mu.Lock() + delete(t.inProgress, typ) + t.tables.Store(typ, table) + t.mu.Unlock() + } + + t.dialect.OnTable(table) + + for _, field := range table.FieldMap { + if field.UserSQLType == "" { + field.UserSQLType = field.DiscoveredSQLType + } + if field.CreateTableSQLType == "" { + field.CreateTableSQLType = field.UserSQLType + } + } + + return table +} + +func (t *Tables) ByModel(name string) *Table { + var found *Table + t.tables.Range(func(key, value interface{}) bool { + t := value.(*Table) + if t.TypeName == name { + found = t + return false + } + return true + }) + return found +} + +func (t *Tables) ByName(name string) *Table { + var found *Table + t.tables.Range(func(key, value interface{}) bool { + t := value.(*Table) + if t.Name == name { + found = t + return false + } + return true + }) + return found +} |