diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_table_create.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/query_table_create.go | 42 |
1 files changed, 32 insertions, 10 deletions
diff --git a/vendor/github.com/uptrace/bun/query_table_create.go b/vendor/github.com/uptrace/bun/query_table_create.go index 002250bc1..518dbfd1c 100644 --- a/vendor/github.com/uptrace/bun/query_table_create.go +++ b/vendor/github.com/uptrace/bun/query_table_create.go @@ -3,8 +3,10 @@ package bun import ( "context" "database/sql" + "fmt" "sort" "strconv" + "strings" "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/dialect/sqltype" @@ -17,7 +19,12 @@ type CreateTableQuery struct { temp bool ifNotExists bool - varchar int + + // varchar changes the default length for VARCHAR columns. + // Because some dialects require that length is always specified for VARCHAR type, + // we will use the exact user-defined type if length is set explicitly, as in `bun:",type:varchar(5)"`, + // but assume the new default length when it's omitted, e.g. `bun:",type:varchar"`. + varchar int fks []schema.QueryWithArgs partitionBy schema.QueryWithArgs @@ -32,6 +39,7 @@ func NewCreateTableQuery(db *DB) *CreateTableQuery { db: db, conn: db.DB, }, + varchar: db.Dialect().DefaultVarcharLen(), } return q } @@ -46,6 +54,11 @@ func (q *CreateTableQuery) Model(model interface{}) *CreateTableQuery { return q } +func (q *CreateTableQuery) Err(err error) *CreateTableQuery { + q.setErr(err) + return q +} + // ------------------------------------------------------------------------------ func (q *CreateTableQuery) Table(tables ...string) *CreateTableQuery { @@ -82,7 +95,12 @@ func (q *CreateTableQuery) IfNotExists() *CreateTableQuery { return q } +// Varchar sets default length for VARCHAR columns. func (q *CreateTableQuery) Varchar(n int) *CreateTableQuery { + if n <= 0 { + q.setErr(fmt.Errorf("bun: illegal VARCHAR length: %d", n)) + return q + } q.varchar = n return q } @@ -120,7 +138,7 @@ func (q *CreateTableQuery) WithForeignKeys() *CreateTableQuery { return q } -//------------------------------------------------------------------------------ +// ------------------------------------------------------------------------------ func (q *CreateTableQuery) Operation() string { return "CREATE TABLE" @@ -221,19 +239,23 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by } func (q *CreateTableQuery) appendSQLType(b []byte, field *schema.Field) []byte { - if field.CreateTableSQLType != field.DiscoveredSQLType { + // Most of the time these two will match, but for the cases where DiscoveredSQLType is dialect-specific, + // e.g. pgdialect would change sqltype.SmallInt to pgTypeSmallSerial for columns that have `bun:",autoincrement"` + if !strings.EqualFold(field.CreateTableSQLType, field.DiscoveredSQLType) { return append(b, field.CreateTableSQLType...) } - if q.varchar > 0 && - field.CreateTableSQLType == sqltype.VarChar { - b = append(b, "varchar("...) - b = strconv.AppendInt(b, int64(q.varchar), 10) - b = append(b, ")"...) - return b + // For all common SQL types except VARCHAR, both UserDefinedSQLType and DiscoveredSQLType specify the correct type, + // and we needn't modify it. For VARCHAR columns, we will stop to check if a valid length has been set in .Varchar(int). + if !strings.EqualFold(field.CreateTableSQLType, sqltype.VarChar) || q.varchar <= 0 { + return append(b, field.CreateTableSQLType...) } - return append(b, field.CreateTableSQLType...) + b = append(b, sqltype.VarChar...) + b = append(b, "("...) + b = strconv.AppendInt(b, int64(q.varchar), 10) + b = append(b, ")"...) + return b } func (q *CreateTableQuery) appendUniqueConstraints(fmter schema.Formatter, b []byte) []byte { |