diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_table_create.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/query_table_create.go | 95 |
1 files changed, 60 insertions, 35 deletions
diff --git a/vendor/github.com/uptrace/bun/query_table_create.go b/vendor/github.com/uptrace/bun/query_table_create.go index 518dbfd1c..3d98da07b 100644 --- a/vendor/github.com/uptrace/bun/query_table_create.go +++ b/vendor/github.com/uptrace/bun/query_table_create.go @@ -1,6 +1,7 @@ package bun import ( + "bytes" "context" "database/sql" "fmt" @@ -19,6 +20,7 @@ type CreateTableQuery struct { temp bool ifNotExists bool + fksFromRel bool // Create foreign keys captured in table's relations. // varchar changes the default length for VARCHAR columns. // Because some dialects require that length is always specified for VARCHAR type, @@ -120,21 +122,9 @@ func (q *CreateTableQuery) TableSpace(tablespace string) *CreateTableQuery { return q } +// WithForeignKeys adds a FOREIGN KEY clause for each of the model's existing relations. func (q *CreateTableQuery) WithForeignKeys() *CreateTableQuery { - for _, relation := range q.tableModel.Table().Relations { - if relation.Type == schema.ManyToManyRelation || - relation.Type == schema.HasManyRelation { - continue - } - - q = q.ForeignKey("(?) REFERENCES ? (?) ? ?", - Safe(appendColumns(nil, "", relation.BaseFields)), - relation.JoinTable.SQLName, - Safe(appendColumns(nil, "", relation.JoinFields)), - Safe(relation.OnUpdate), - Safe(relation.OnDelete), - ) - } + q.fksFromRel = true return q } @@ -157,7 +147,7 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by b = append(b, "TEMP "...) } b = append(b, "TABLE "...) - if q.ifNotExists && fmter.Dialect().Features().Has(feature.TableNotExists) { + if q.ifNotExists && fmter.HasFeature(feature.TableNotExists) { b = append(b, "IF NOT EXISTS "...) } b, err = q.appendFirstTable(fmter, b) @@ -178,19 +168,12 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by if field.NotNull { b = append(b, " NOT NULL"...) } - if field.AutoIncrement { - switch { - case fmter.Dialect().Features().Has(feature.AutoIncrement): - b = append(b, " AUTO_INCREMENT"...) - case fmter.Dialect().Features().Has(feature.Identity): - b = append(b, " IDENTITY"...) - } - } - if field.Identity { - if fmter.Dialect().Features().Has(feature.GeneratedIdentity) { - b = append(b, " GENERATED BY DEFAULT AS IDENTITY"...) - } + + if (field.Identity && fmter.HasFeature(feature.GeneratedIdentity)) || + (field.AutoIncrement && (fmter.HasFeature(feature.AutoIncrement) || fmter.HasFeature(feature.Identity))) { + b = q.db.dialect.AppendSequence(b, q.table, field) } + if field.SQLDefault != "" { b = append(b, " DEFAULT "...) b = append(b, field.SQLDefault...) @@ -210,8 +193,20 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by } } - b = q.appendPKConstraint(b, q.table.PKs) + // In SQLite AUTOINCREMENT is only valid for INTEGER PRIMARY KEY columns, so it might be that + // a primary key constraint has already been created in dialect.AppendSequence() call above. + // See sqldialect.Dialect.AppendSequence() for more details. + if len(q.table.PKs) > 0 && !bytes.Contains(b, []byte("PRIMARY KEY")) { + b = q.appendPKConstraint(b, q.table.PKs) + } b = q.appendUniqueConstraints(fmter, b) + + if q.fksFromRel { + b, err = q.appendFKConstraintsRel(fmter, b) + if err != nil { + return nil, err + } + } b, err = q.appendFKConstraints(fmter, b) if err != nil { return nil, err @@ -295,13 +290,38 @@ func (q *CreateTableQuery) appendUniqueConstraint( return b } +// appendFKConstraintsRel appends a FOREIGN KEY clause for each of the model's existing relations. +func (q *CreateTableQuery) appendFKConstraintsRel(fmter schema.Formatter, b []byte) (_ []byte, err error) { + for _, rel := range q.tableModel.Table().Relations { + if rel.References() { + b, err = q.appendFK(fmter, b, schema.QueryWithArgs{ + Query: "(?) REFERENCES ? (?) ? ?", + Args: []interface{}{ + Safe(appendColumns(nil, "", rel.BaseFields)), + rel.JoinTable.SQLName, + Safe(appendColumns(nil, "", rel.JoinFields)), + Safe(rel.OnUpdate), + Safe(rel.OnDelete), + }, + }) + if err != nil { + return nil, err + } + } + } + return b, nil +} + +func (q *CreateTableQuery) appendFK(fmter schema.Formatter, b []byte, fk schema.QueryWithArgs) (_ []byte, err error) { + b = append(b, ", FOREIGN KEY "...) + return fk.AppendQuery(fmter, b) +} + func (q *CreateTableQuery) appendFKConstraints( fmter schema.Formatter, b []byte, ) (_ []byte, err error) { for _, fk := range q.fks { - b = append(b, ", FOREIGN KEY "...) - b, err = fk.AppendQuery(fmter, b) - if err != nil { + if b, err = q.appendFK(fmter, b, fk); err != nil { return nil, err } } @@ -309,10 +329,6 @@ func (q *CreateTableQuery) appendFKConstraints( } func (q *CreateTableQuery) appendPKConstraint(b []byte, pks []*schema.Field) []byte { - if len(pks) == 0 { - return b - } - b = append(b, ", PRIMARY KEY ("...) b = appendColumns(b, "", pks) b = append(b, ")"...) @@ -364,3 +380,12 @@ func (q *CreateTableQuery) afterCreateTableHook(ctx context.Context) error { } return nil } + +func (q *CreateTableQuery) String() string { + buf, err := q.AppendQuery(q.db.Formatter(), nil) + if err != nil { + panic(err) + } + + return string(buf) +} |