summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_table_create.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_table_create.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_table_create.go275
1 files changed, 275 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/bun/query_table_create.go b/vendor/github.com/uptrace/bun/query_table_create.go
new file mode 100644
index 000000000..0a4b3567c
--- /dev/null
+++ b/vendor/github.com/uptrace/bun/query_table_create.go
@@ -0,0 +1,275 @@
+package bun
+
+import (
+ "context"
+ "database/sql"
+ "sort"
+ "strconv"
+
+ "github.com/uptrace/bun/dialect/feature"
+ "github.com/uptrace/bun/dialect/sqltype"
+ "github.com/uptrace/bun/internal"
+ "github.com/uptrace/bun/schema"
+)
+
+type CreateTableQuery struct {
+ baseQuery
+
+ temp bool
+ ifNotExists bool
+ varchar int
+
+ fks []schema.QueryWithArgs
+ partitionBy schema.QueryWithArgs
+ tablespace schema.QueryWithArgs
+}
+
+func NewCreateTableQuery(db *DB) *CreateTableQuery {
+ q := &CreateTableQuery{
+ baseQuery: baseQuery{
+ db: db,
+ conn: db.DB,
+ },
+ }
+ return q
+}
+
+func (q *CreateTableQuery) Conn(db IConn) *CreateTableQuery {
+ q.setConn(db)
+ return q
+}
+
+func (q *CreateTableQuery) Model(model interface{}) *CreateTableQuery {
+ q.setTableModel(model)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *CreateTableQuery) Table(tables ...string) *CreateTableQuery {
+ for _, table := range tables {
+ q.addTable(schema.UnsafeIdent(table))
+ }
+ return q
+}
+
+func (q *CreateTableQuery) TableExpr(query string, args ...interface{}) *CreateTableQuery {
+ q.addTable(schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *CreateTableQuery) ModelTableExpr(query string, args ...interface{}) *CreateTableQuery {
+ q.modelTable = schema.SafeQuery(query, args)
+ return q
+}
+
+//------------------------------------------------------------------------------
+
+func (q *CreateTableQuery) Temp() *CreateTableQuery {
+ q.temp = true
+ return q
+}
+
+func (q *CreateTableQuery) IfNotExists() *CreateTableQuery {
+ q.ifNotExists = true
+ return q
+}
+
+func (q *CreateTableQuery) Varchar(n int) *CreateTableQuery {
+ q.varchar = n
+ return q
+}
+
+func (q *CreateTableQuery) ForeignKey(query string, args ...interface{}) *CreateTableQuery {
+ q.fks = append(q.fks, schema.SafeQuery(query, args))
+ return q
+}
+
+func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
+ if q.err != nil {
+ return nil, q.err
+ }
+ if q.table == nil {
+ return nil, errNilModel
+ }
+
+ b = append(b, "CREATE "...)
+ if q.temp {
+ b = append(b, "TEMP "...)
+ }
+ b = append(b, "TABLE "...)
+ if q.ifNotExists {
+ b = append(b, "IF NOT EXISTS "...)
+ }
+ b, err = q.appendFirstTable(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ b = append(b, " ("...)
+
+ for i, field := range q.table.Fields {
+ if i > 0 {
+ b = append(b, ", "...)
+ }
+
+ b = append(b, field.SQLName...)
+ b = append(b, " "...)
+ b = q.appendSQLType(b, field)
+ if field.NotNull {
+ b = append(b, " NOT NULL"...)
+ }
+ if q.db.features.Has(feature.AutoIncrement) && field.AutoIncrement {
+ b = append(b, " AUTO_INCREMENT"...)
+ }
+ if field.SQLDefault != "" {
+ b = append(b, " DEFAULT "...)
+ b = append(b, field.SQLDefault...)
+ }
+ }
+
+ b = q.appendPKConstraint(b, q.table.PKs)
+ b = q.appendUniqueConstraints(fmter, b)
+ b, err = q.appenFKConstraints(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
+ b = append(b, ")"...)
+
+ if !q.partitionBy.IsZero() {
+ b = append(b, " PARTITION BY "...)
+ b, err = q.partitionBy.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if !q.tablespace.IsZero() {
+ b = append(b, " TABLESPACE "...)
+ b, err = q.tablespace.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return b, nil
+}
+
+func (q *CreateTableQuery) appendSQLType(b []byte, field *schema.Field) []byte {
+ if field.CreateTableSQLType != field.DiscoveredSQLType {
+ return append(b, field.CreateTableSQLType...)
+ }
+
+ if q.varchar > 0 &&
+ field.CreateTableSQLType == sqltype.VarChar {
+ b = append(b, "varchar("...)
+ b = strconv.AppendInt(b, int64(q.varchar), 10)
+ b = append(b, ")"...)
+ return b
+ }
+
+ return append(b, field.CreateTableSQLType...)
+}
+
+func (q *CreateTableQuery) appendUniqueConstraints(fmter schema.Formatter, b []byte) []byte {
+ unique := q.table.Unique
+
+ keys := make([]string, 0, len(unique))
+ for key := range unique {
+ keys = append(keys, key)
+ }
+ sort.Strings(keys)
+
+ for _, key := range keys {
+ b = q.appendUniqueConstraint(fmter, b, key, unique[key])
+ }
+
+ return b
+}
+
+func (q *CreateTableQuery) appendUniqueConstraint(
+ fmter schema.Formatter, b []byte, name string, fields []*schema.Field,
+) []byte {
+ if name != "" {
+ b = append(b, ", CONSTRAINT "...)
+ b = fmter.AppendIdent(b, name)
+ } else {
+ b = append(b, ","...)
+ }
+ b = append(b, " UNIQUE ("...)
+ b = appendColumns(b, "", fields)
+ b = append(b, ")"...)
+
+ return b
+}
+
+func (q *CreateTableQuery) appenFKConstraints(
+ fmter schema.Formatter, b []byte,
+) (_ []byte, err error) {
+ for _, fk := range q.fks {
+ b = append(b, ", FOREIGN KEY "...)
+ b, err = fk.AppendQuery(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return b, nil
+}
+
+func (q *CreateTableQuery) appendPKConstraint(b []byte, pks []*schema.Field) []byte {
+ if len(pks) == 0 {
+ return b
+ }
+
+ b = append(b, ", PRIMARY KEY ("...)
+ b = appendColumns(b, "", pks)
+ b = append(b, ")"...)
+ return b
+}
+
+//------------------------------------------------------------------------------
+
+func (q *CreateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
+ if err := q.beforeCreateTableHook(ctx); err != nil {
+ return nil, err
+ }
+
+ queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
+ if err != nil {
+ return nil, err
+ }
+
+ query := internal.String(queryBytes)
+
+ res, err := q.exec(ctx, q, query)
+ if err != nil {
+ return nil, err
+ }
+
+ if q.table != nil {
+ if err := q.afterCreateTableHook(ctx); err != nil {
+ return nil, err
+ }
+ }
+
+ return res, nil
+}
+
+func (q *CreateTableQuery) beforeCreateTableHook(ctx context.Context) error {
+ if hook, ok := q.table.ZeroIface.(BeforeCreateTableHook); ok {
+ if err := hook.BeforeCreateTable(ctx, q); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (q *CreateTableQuery) afterCreateTableHook(ctx context.Context) error {
+ if hook, ok := q.table.ZeroIface.(AfterCreateTableHook); ok {
+ if err := hook.AfterCreateTable(ctx, q); err != nil {
+ return err
+ }
+ }
+ return nil
+}