summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/schema/tables.go
blob: 58c45cbee4a70b79fe1d7cc102dc67d6f35e8d4d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
package schema

import (
	"fmt"
	"reflect"
	"sync"

	"github.com/puzpuzpuz/xsync/v3"
)

// Tables caches the *Table metadata built for Go struct types, for the single
// Dialect it was created with.
type Tables struct {
	// dialect is passed to every Table it builds and its OnTable hook is run
	// on each table before it is published.
	dialect Dialect

	// mu serializes table construction in Get. Finished tables are published
	// into tables, which Get and the lookup methods read without taking mu.
	mu     sync.Mutex
	tables *xsync.MapOf[reflect.Type, *Table]

	// inProgress holds every Table created by InProgress, including ones whose
	// initialization has not finished. It is a plain map with no locking of
	// its own; Get only touches it (via InProgress) while holding mu.
	// NOTE(review): other callers of the exported InProgress presumably must
	// hold mu as well — confirm at call sites outside this file.
	inProgress map[reflect.Type]*Table
}

// NewTables returns an empty table registry bound to dialect, with both the
// finished-table cache and the in-progress map ready for use.
func NewTables(dialect Dialect) *Tables {
	t := new(Tables)
	t.dialect = dialect
	t.tables = xsync.NewMapOf[reflect.Type, *Table]()
	t.inProgress = make(map[reflect.Type]*Table)
	return t
}

// Register eagerly builds and caches the table metadata for each model by
// calling Get on its type.
//
// Models are usually typed nil pointers such as (*User)(nil). When the model's
// type is a pointer, one pointer level is stripped here before calling Get.
// A non-pointer value (e.g. User{}) is handed to Get unchanged instead of
// panicking inside reflect. An untyped nil model panics with a descriptive
// error.
func (t *Tables) Register(models ...interface{}) {
	for _, model := range models {
		typ := reflect.TypeOf(model)
		if typ == nil {
			// reflect.TypeOf(nil) returns a nil Type; calling any method on it
			// would fail with an unhelpful nil-pointer dereference.
			panic(fmt.Errorf("bun: Register got an untyped nil model; pass a typed pointer such as (*Model)(nil)"))
		}
		// Strip one pointer level exactly as the previous unconditional Elem()
		// did for pointer models, without rejecting non-pointer values.
		if typ.Kind() == reflect.Ptr {
			typ = typ.Elem()
		}
		_ = t.Get(typ)
	}
}

// Get returns the table metadata for typ, building and caching it on first
// use. typ is first normalized with indirectType; if the result is not a
// struct type, Get panics.
//
// Finished tables are served from t.tables without locking. On a miss, Get
// takes t.mu and checks the cache again, so a table is built by only one
// goroutine. The table is stored in t.tables only after its relations, the
// dialect's OnTable hook, and the SQL-type defaults have all been applied.
func (t *Tables) Get(typ reflect.Type) *Table {
	typ = indirectType(typ)
	if kind := typ.Kind(); kind != reflect.Struct {
		panic(fmt.Errorf("got %s, wanted %s", kind, reflect.Struct))
	}

	if cached, hit := t.tables.Load(typ); hit {
		return cached
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	// Another goroutine may have finished this table while we waited.
	if cached, hit := t.tables.Load(typ); hit {
		return cached
	}

	// Reuse the table if a relation of an earlier table already started it.
	tbl := t.InProgress(typ)
	tbl.initRelations()
	t.dialect.OnTable(tbl)

	// Fill in SQL types left empty by the user and the dialect:
	// UserSQLType falls back to DiscoveredSQLType, then CreateTableSQLType
	// falls back to the (possibly just filled) UserSQLType.
	for _, f := range tbl.FieldMap {
		if f.UserSQLType == "" {
			f.UserSQLType = f.DiscoveredSQLType
		}
		if f.CreateTableSQLType == "" {
			f.CreateTableSQLType = f.UserSQLType
		}
	}

	t.tables.Store(typ, tbl)
	return tbl
}

// InProgress returns the Table for typ from the in-progress map, creating and
// initializing a new one if none exists yet. The returned table may not have
// had its relations or dialect hook applied; Get does that.
//
// This method does no locking itself; Get calls it while holding t.mu.
func (t *Tables) InProgress(typ reflect.Type) *Table {
	table, exists := t.inProgress[typ]
	if !exists {
		table = new(Table)
		// Record the table before init runs, so that any lookup of the same
		// type made during init finds this instance instead of creating a
		// second one.
		t.inProgress[typ] = table
		table.init(t.dialect, typ, false)
	}
	return table
}

// ByModel returns the finished table whose Go type name (Table.TypeName)
// equals name, or nil if none matches. If several tables share the name,
// which one is returned depends on the map's iteration order.
func (t *Tables) ByModel(name string) *Table {
	var match *Table
	t.tables.Range(func(_ reflect.Type, tbl *Table) bool {
		if tbl.TypeName != name {
			return true // keep scanning
		}
		match = tbl
		return false // first match wins
	})
	return match
}

// ByName returns the finished table whose SQL name (Table.Name) equals name,
// or nil if none matches. If several tables share the name, which one is
// returned depends on the map's iteration order.
func (t *Tables) ByName(name string) *Table {
	var match *Table
	t.tables.Range(func(_ reflect.Type, tbl *Table) bool {
		if tbl.Name != name {
			return true // keep scanning
		}
		match = tbl
		return false // first match wins
	})
	return match
}

// All returns every table that Get has finished building, in no particular
// order. It returns nil when no table has been built yet.
func (t *Tables) All() []*Table {
	var out []*Table
	collect := func(_ reflect.Type, tbl *Table) bool {
		out = append(out, tbl)
		return true
	}
	t.tables.Range(collect)
	return out
}