diff options
Diffstat (limited to 'vendor/github.com/uptrace')
80 files changed, 12141 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/bun/.gitignore b/vendor/github.com/uptrace/bun/.gitignore new file mode 100644 index 000000000..6f7763c71 --- /dev/null +++ b/vendor/github.com/uptrace/bun/.gitignore @@ -0,0 +1,3 @@ +*.s3db +*.prof +*.test diff --git a/vendor/github.com/uptrace/bun/.prettierrc.yaml b/vendor/github.com/uptrace/bun/.prettierrc.yaml new file mode 100644 index 000000000..decea5634 --- /dev/null +++ b/vendor/github.com/uptrace/bun/.prettierrc.yaml @@ -0,0 +1,6 @@ +trailingComma: all +tabWidth: 2 +semi: false +singleQuote: true +proseWrap: always +printWidth: 100 diff --git a/vendor/github.com/uptrace/bun/CHANGELOG.md b/vendor/github.com/uptrace/bun/CHANGELOG.md new file mode 100644 index 000000000..01bf6ba31 --- /dev/null +++ b/vendor/github.com/uptrace/bun/CHANGELOG.md @@ -0,0 +1,99 @@ +# Changelog + +## v0.4.1 - Aug 18 2021 + +- Fixed migrate package to properly rollback migrations. +- Added `allowzero` tag option that undoes `nullzero` option. + +## v0.4.0 - Aug 11 2021 + +- Changed `WhereGroup` function to accept `*SelectQuery`. +- Fixed query hooks for count queries. + +## v0.3.4 - Jul 19 2021 + +- Renamed `migrate.CreateGo` to `CreateGoMigration`. +- Added `migrate.WithPackageName` to customize the Go package name in generated migrations. +- Renamed `migrate.CreateSQL` to `CreateSQLMigrations` and changed `CreateSQLMigrations` to create + both up and down migration files. + +## v0.3.1 - Jul 12 2021 + +- Renamed `alias` field struct tag to `alt` so it is not confused with column alias. +- Reworked migrate package API. See + [migrate](https://github.com/uptrace/bun/tree/master/example/migrate) example for details. + +## v0.3.0 - Jul 09 2021 + +- Changed migrate package to return structured data instead of logging the progress. See + [migrate](https://github.com/uptrace/bun/tree/master/example/migrate) example for details. + +## v0.2.14 - Jul 01 2021 + +- Added [sqliteshim](https://pkg.go.dev/github.com/uptrace/bun/driver/sqliteshim) by + [Ivan Trubach](https://github.com/tie). +- Added support for MySQL 5.7 in addition to MySQL 8. + +## v0.2.12 - Jun 29 2021 + +- Fixed scanners for net.IP and net.IPNet. + +## v0.2.10 - Jun 29 2021 + +- Fixed pgdriver to format passed query args. + +## v0.2.9 - Jun 27 2021 + +- Added support for prepared statements in pgdriver. + +## v0.2.7 - Jun 26 2021 + +- Added `UpdateQuery.Bulk` helper to generate bulk-update queries. + + Before: + + ```go + models := []Model{ + {42, "hello"}, + {43, "world"}, + } + return db.NewUpdate(). + With("_data", db.NewValues(&models)). + Model(&models). + Table("_data"). + Set("model.str = _data.str"). + Where("model.id = _data.id") + ``` + + Now: + + ```go + db.NewUpdate(). + Model(&models). + Bulk() + ``` + +## v0.2.5 - Jun 25 2021 + +- Changed time.Time to always append zero time as `NULL`. +- Added `db.RunInTx` helper. + +## v0.2.4 - Jun 21 2021 + +- Added SSL support to pgdriver. + +## v0.2.3 - Jun 20 2021 + +- Replaced `ForceDelete(ctx)` with `ForceDelete().Exec(ctx)` for soft deletes. + +## v0.2.1 - Jun 17 2021 + +- Renamed `DBI` to `IConn`. `IConn` is a common interface for `*sql.DB`, `*sql.Conn`, and `*sql.Tx`. +- Added `IDB`. `IDB` is a common interface for `*bun.DB`, `bun.Conn`, and `bun.Tx`. + +## v0.2.0 - Jun 16 2021 + +- Changed [model hooks](https://bun.uptrace.dev/guide/hooks.html#model-hooks). See + [model-hooks](example/model-hooks) example. +- Renamed `has-one` to `belongs-to`. Renamed `belongs-to` to `has-one`. Previously Bun used + incorrect names for these relations. diff --git a/vendor/github.com/uptrace/bun/LICENSE b/vendor/github.com/uptrace/bun/LICENSE new file mode 100644 index 000000000..7ec81810c --- /dev/null +++ b/vendor/github.com/uptrace/bun/LICENSE @@ -0,0 +1,24 @@ +Copyright (c) 2021 Vladimir Mihailenco. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/uptrace/bun/Makefile b/vendor/github.com/uptrace/bun/Makefile new file mode 100644 index 000000000..54744c617 --- /dev/null +++ b/vendor/github.com/uptrace/bun/Makefile @@ -0,0 +1,21 @@ +ALL_GO_MOD_DIRS := $(shell find . -type f -name 'go.mod' -exec dirname {} \; | sort) + +test: + set -e; for dir in $(ALL_GO_MOD_DIRS); do \ + echo "go test in $${dir}"; \ + (cd "$${dir}" && \ + go test ./... && \ + go vet); \ + done + +go_mod_tidy: + set -e; for dir in $(ALL_GO_MOD_DIRS); do \ + echo "go mod tidy in $${dir}"; \ + (cd "$${dir}" && \ + go get -d ./... && \ + go mod tidy); \ + done + +fmt: + gofmt -w -s ./ + goimports -w -local github.com/uptrace/bun ./ diff --git a/vendor/github.com/uptrace/bun/README.md b/vendor/github.com/uptrace/bun/README.md new file mode 100644 index 000000000..e7cc77a60 --- /dev/null +++ b/vendor/github.com/uptrace/bun/README.md @@ -0,0 +1,267 @@ +<p align="center"> + <a href="https://uptrace.dev/?utm_source=gh-redis&utm_campaign=gh-redis-banner1"> + <img src="https://raw.githubusercontent.com/uptrace/roadmap/master/banner1.png" alt="All-in-one tool to optimize performance and monitor errors & logs"> + </a> +</p> + +# Simple and performant SQL database client + +[](https://github.com/uptrace/bun/actions) +[](https://pkg.go.dev/github.com/uptrace/bun) +[](https://bun.uptrace.dev/) +[](https://discord.gg/rWtp5Aj) + +Main features are: + +- Works with [PostgreSQL](https://bun.uptrace.dev/guide/drivers.html#postgresql), + [MySQL](https://bun.uptrace.dev/guide/drivers.html#mysql), + [SQLite](https://bun.uptrace.dev/guide/drivers.html#sqlite). +- [Selecting](/example/basic/) into a map, struct, slice of maps/structs/vars. +- [Bulk inserts](https://bun.uptrace.dev/guide/queries.html#insert). +- [Bulk updates](https://bun.uptrace.dev/guide/queries.html#update) using common table expressions. +- [Bulk deletes](https://bun.uptrace.dev/guide/queries.html#delete). +- [Fixtures](https://bun.uptrace.dev/guide/fixtures.html). +- [Migrations](https://bun.uptrace.dev/guide/migrations.html). +- [Soft deletes](https://bun.uptrace.dev/guide/soft-deletes.html). + +Resources: + +- [Examples](https://github.com/uptrace/bun/tree/master/example) +- [Documentation](https://bun.uptrace.dev/) +- [Reference](https://pkg.go.dev/github.com/uptrace/bun) +- [Starter kit](https://github.com/go-bun/bun-starter-kit) +- [RealWorld app](https://github.com/go-bun/bun-realworld-app) + +<details> + <summary>github.com/frederikhors/orm-benchmark results</summary> + +``` + 4000 times - Insert + raw_stmt: 0.38s 94280 ns/op 718 B/op 14 allocs/op + raw: 0.39s 96719 ns/op 718 B/op 13 allocs/op + beego_orm: 0.48s 118994 ns/op 2411 B/op 56 allocs/op + bun: 0.57s 142285 ns/op 918 B/op 12 allocs/op + pg: 0.58s 145496 ns/op 1235 B/op 12 allocs/op + gorm: 0.70s 175294 ns/op 6665 B/op 88 allocs/op + xorm: 0.76s 189533 ns/op 3032 B/op 94 allocs/op + + 4000 times - MultiInsert 100 row + raw: 4.59s 1147385 ns/op 135155 B/op 916 allocs/op + raw_stmt: 4.59s 1148137 ns/op 131076 B/op 916 allocs/op + beego_orm: 5.50s 1375637 ns/op 179962 B/op 2747 allocs/op + bun: 6.18s 1544648 ns/op 4265 B/op 214 allocs/op + pg: 7.01s 1753495 ns/op 5039 B/op 114 allocs/op + gorm: 9.52s 2379219 ns/op 293956 B/op 3729 allocs/op + xorm: 11.66s 2915478 ns/op 286140 B/op 7422 allocs/op + + 4000 times - Update + raw_stmt: 0.26s 65781 ns/op 773 B/op 14 allocs/op + raw: 0.31s 77209 ns/op 757 B/op 13 allocs/op + beego_orm: 0.43s 107064 ns/op 1802 B/op 47 allocs/op + bun: 0.56s 139839 ns/op 589 B/op 4 allocs/op + pg: 0.60s 149608 ns/op 896 B/op 11 allocs/op + gorm: 0.74s 185970 ns/op 6604 B/op 81 allocs/op + xorm: 0.81s 203240 ns/op 2994 B/op 119 allocs/op + + 4000 times - Read + raw: 0.33s 81671 ns/op 2081 B/op 49 allocs/op + raw_stmt: 0.34s 85847 ns/op 2112 B/op 50 allocs/op + beego_orm: 0.38s 94777 ns/op 2106 B/op 75 allocs/op + pg: 0.42s 106148 ns/op 1526 B/op 22 allocs/op + bun: 0.43s 106904 ns/op 1319 B/op 18 allocs/op + gorm: 0.65s 162221 ns/op 5240 B/op 108 allocs/op + xorm: 1.13s 281738 ns/op 8326 B/op 237 allocs/op + + 4000 times - MultiRead limit 100 + raw: 1.52s 380351 ns/op 38356 B/op 1037 allocs/op + raw_stmt: 1.54s 385541 ns/op 38388 B/op 1038 allocs/op + pg: 1.86s 465468 ns/op 24045 B/op 631 allocs/op + bun: 2.58s 645354 ns/op 30009 B/op 1122 allocs/op + beego_orm: 2.93s 732028 ns/op 55280 B/op 3077 allocs/op + gorm: 4.97s 1241831 ns/op 71628 B/op 3877 allocs/op + xorm: doesn't work +``` + +</details> + +## Installation + +```go +go get github.com/uptrace/bun +``` + +You also need to install a database/sql driver and the corresponding Bun +[dialect](https://bun.uptrace.dev/guide/drivers.html). + +## Quickstart + +First you need to create a `sql.DB`. Here we are using the +[sqliteshim](https://pkg.go.dev/github.com/uptrace/bun/driver/sqliteshim) driver which choses +between [modernc.org/sqlite](https://modernc.org/sqlite/) and +[mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) depending on your platform. + +```go +import "github.com/uptrace/bun/driver/sqliteshim" + +sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared") +if err != nil { + panic(err) +} +``` + +And then create a `bun.DB` on top of it using the corresponding SQLite dialect: + +```go +import ( + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" +) + +db := bun.NewDB(sqldb, sqlitedialect.New()) +``` + +Now you are ready to issue some queries: + +```go +type User struct { + ID int64 + Name string +} + +user := new(User) +err := db.NewSelect(). + Model(user). + Where("name != ?", ""). + OrderExpr("id ASC"). + Limit(1). + Scan(ctx) +``` + +The code above is equivalent to: + +```go +query := "SELECT id, name FROM users AS user WHERE name != '' ORDER BY id ASC LIMIT 1" + +rows, err := sqldb.QueryContext(ctx, query) +if err != nil { + panic(err) +} + +if !rows.Next() { + panic(sql.ErrNoRows) +} + +user := new(User) +if err := db.ScanRow(ctx, rows, user); err != nil { + panic(err) +} + +if err := rows.Err(); err != nil { + panic(err) +} +``` + +## Basic example + +To provide initial data for our [example](/example/basic/), we will use Bun +[fixtures](https://bun.uptrace.dev/guide/fixtures.html): + +```go +import "github.com/uptrace/bun/dbfixture" + +// Register models for the fixture. +db.RegisterModel((*User)(nil), (*Story)(nil)) + +// WithRecreateTables tells Bun to drop existing tables and create new ones. +fixture := dbfixture.New(db, dbfixture.WithRecreateTables()) + +// Load fixture.yaml which contains data for User and Story models. +if err := fixture.Load(ctx, os.DirFS("."), "fixture.yaml"); err != nil { + panic(err) +} +``` + +The `fixture.yaml` looks like this: + +```yaml +- model: User + rows: + - _id: admin + name: admin + emails: ['admin1@admin', 'admin2@admin'] + - _id: root + name: root + emails: ['root1@root', 'root2@root'] + +- model: Story + rows: + - title: Cool story + author_id: '{{ $.User.admin.ID }}' +``` + +To select all users: + +```go +users := make([]User, 0) +if err := db.NewSelect().Model(&users).OrderExpr("id ASC").Scan(ctx); err != nil { + panic(err) +} +``` + +To select a single user by id: + +```go +user1 := new(User) +if err := db.NewSelect().Model(user1).Where("id = ?", 1).Scan(ctx); err != nil { + panic(err) +} +``` + +To select a story and the associated author in a single query: + +```go +story := new(Story) +if err := db.NewSelect(). + Model(story). + Relation("Author"). + Limit(1). + Scan(ctx); err != nil { + panic(err) +} +``` + +To select a user into a map: + +```go +m := make(map[string]interface{}) +if err := db.NewSelect(). + Model((*User)(nil)). + Limit(1). + Scan(ctx, &m); err != nil { + panic(err) +} +``` + +To select all users scanning each column into a separate slice: + +```go +var ids []int64 +var names []string +if err := db.NewSelect(). + ColumnExpr("id, name"). + Model((*User)(nil)). + OrderExpr("id ASC"). + Scan(ctx, &ids, &names); err != nil { + panic(err) +} +``` + +For more details, please consult [docs](https://bun.uptrace.dev/) and check [examples](example). + +## Contributors + +Thanks to all the people who already contributed! + +<a href="https://github.com/uptrace/bun/graphs/contributors"> + <img src="https://contributors-img.web.app/image?repo=uptrace/bun" /> +</a> diff --git a/vendor/github.com/uptrace/bun/RELEASING.md b/vendor/github.com/uptrace/bun/RELEASING.md new file mode 100644 index 000000000..9e50c1063 --- /dev/null +++ b/vendor/github.com/uptrace/bun/RELEASING.md @@ -0,0 +1,21 @@ +# Releasing + +1. Run `release.sh` script which updates versions in go.mod files and pushes a new branch to GitHub: + +```shell +./scripts/release.sh -t v1.0.0 +``` + +2. Open a pull request and wait for the build to finish. + +3. Merge the pull request and run `tag.sh` to create tags for packages: + +```shell +./scripts/tag.sh -t v1.0.0 +``` + +4. Push the tags: + +```shell +git push origin --tags +``` diff --git a/vendor/github.com/uptrace/bun/bun.go b/vendor/github.com/uptrace/bun/bun.go new file mode 100644 index 000000000..92ebe691a --- /dev/null +++ b/vendor/github.com/uptrace/bun/bun.go @@ -0,0 +1,122 @@ +package bun + +import ( + "context" + "fmt" + "reflect" + + "github.com/uptrace/bun/schema" +) + +type ( + Safe = schema.Safe + Ident = schema.Ident +) + +type NullTime = schema.NullTime + +type BaseModel = schema.BaseModel + +type ( + BeforeScanHook = schema.BeforeScanHook + AfterScanHook = schema.AfterScanHook +) + +type BeforeSelectHook interface { + BeforeSelect(ctx context.Context, query *SelectQuery) error +} + +type AfterSelectHook interface { + AfterSelect(ctx context.Context, query *SelectQuery) error +} + +type BeforeInsertHook interface { + BeforeInsert(ctx context.Context, query *InsertQuery) error +} + +type AfterInsertHook interface { + AfterInsert(ctx context.Context, query *InsertQuery) error +} + +type BeforeUpdateHook interface { + BeforeUpdate(ctx context.Context, query *UpdateQuery) error +} + +type AfterUpdateHook interface { + AfterUpdate(ctx context.Context, query *UpdateQuery) error +} + +type BeforeDeleteHook interface { + BeforeDelete(ctx context.Context, query *DeleteQuery) error +} + +type AfterDeleteHook interface { + AfterDelete(ctx context.Context, query *DeleteQuery) error +} + +type BeforeCreateTableHook interface { + BeforeCreateTable(ctx context.Context, query *CreateTableQuery) error +} + +type AfterCreateTableHook interface { + AfterCreateTable(ctx context.Context, query *CreateTableQuery) error +} + +type BeforeDropTableHook interface { + BeforeDropTable(ctx context.Context, query *DropTableQuery) error +} + +type AfterDropTableHook interface { + AfterDropTable(ctx context.Context, query *DropTableQuery) error +} + +//------------------------------------------------------------------------------ + +type InValues struct { + slice reflect.Value + err error +} + +var _ schema.QueryAppender = InValues{} + +func In(slice interface{}) InValues { + v := reflect.ValueOf(slice) + if v.Kind() != reflect.Slice { + return InValues{ + err: fmt.Errorf("bun: In(non-slice %T)", slice), + } + } + return InValues{ + slice: v, + } +} + +func (in InValues) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if in.err != nil { + return nil, in.err + } + return appendIn(fmter, b, in.slice), nil +} + +func appendIn(fmter schema.Formatter, b []byte, slice reflect.Value) []byte { + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, ", "...) + } + + elem := slice.Index(i) + if elem.Kind() == reflect.Interface { + elem = elem.Elem() + } + + if elem.Kind() == reflect.Slice { + b = append(b, '(') + b = appendIn(fmter, b, elem) + b = append(b, ')') + } else { + b = fmter.AppendValue(b, elem) + } + } + return b +} diff --git a/vendor/github.com/uptrace/bun/db.go b/vendor/github.com/uptrace/bun/db.go new file mode 100644 index 000000000..d08adefb5 --- /dev/null +++ b/vendor/github.com/uptrace/bun/db.go @@ -0,0 +1,502 @@ +package bun + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + "sync/atomic" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +const ( + discardUnknownColumns internal.Flag = 1 << iota +) + +type DBStats struct { + Queries uint64 + Errors uint64 +} + +type DBOption func(db *DB) + +func WithDiscardUnknownColumns() DBOption { + return func(db *DB) { + db.flags = db.flags.Set(discardUnknownColumns) + } +} + +type DB struct { + *sql.DB + dialect schema.Dialect + features feature.Feature + + queryHooks []QueryHook + + fmter schema.Formatter + flags internal.Flag + + stats DBStats +} + +func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB { + dialect.Init(sqldb) + + db := &DB{ + DB: sqldb, + dialect: dialect, + features: dialect.Features(), + fmter: schema.NewFormatter(dialect), + } + + for _, opt := range opts { + opt(db) + } + + return db +} + +func (db *DB) String() string { + var b strings.Builder + b.WriteString("DB<dialect=") + b.WriteString(db.dialect.Name().String()) + b.WriteString(">") + return b.String() +} + +func (db *DB) DBStats() DBStats { + return DBStats{ + Queries: atomic.LoadUint64(&db.stats.Queries), + Errors: atomic.LoadUint64(&db.stats.Errors), + } +} + +func (db *DB) NewValues(model interface{}) *ValuesQuery { + return NewValuesQuery(db, model) +} + +func (db *DB) NewSelect() *SelectQuery { + return NewSelectQuery(db) +} + +func (db *DB) NewInsert() *InsertQuery { + return NewInsertQuery(db) +} + +func (db *DB) NewUpdate() *UpdateQuery { + return NewUpdateQuery(db) +} + +func (db *DB) NewDelete() *DeleteQuery { + return NewDeleteQuery(db) +} + +func (db *DB) NewCreateTable() *CreateTableQuery { + return NewCreateTableQuery(db) +} + +func (db *DB) NewDropTable() *DropTableQuery { + return NewDropTableQuery(db) +} + +func (db *DB) NewCreateIndex() *CreateIndexQuery { + return NewCreateIndexQuery(db) +} + +func (db *DB) NewDropIndex() *DropIndexQuery { + return NewDropIndexQuery(db) +} + +func (db *DB) NewTruncateTable() *TruncateTableQuery { + return NewTruncateTableQuery(db) +} + +func (db *DB) NewAddColumn() *AddColumnQuery { + return NewAddColumnQuery(db) +} + +func (db *DB) NewDropColumn() *DropColumnQuery { + return NewDropColumnQuery(db) +} + +func (db *DB) ResetModel(ctx context.Context, models ...interface{}) error { + for _, model := range models { + if _, err := db.NewDropTable().Model(model).IfExists().Exec(ctx); err != nil { + return err + } + if _, err := db.NewCreateTable().Model(model).Exec(ctx); err != nil { + return err + } + } + return nil +} + +func (db *DB) Dialect() schema.Dialect { + return db.dialect +} + +func (db *DB) ScanRows(ctx context.Context, rows *sql.Rows, dest ...interface{}) error { + model, err := newModel(db, dest) + if err != nil { + return err + } + + _, err = model.ScanRows(ctx, rows) + return err +} + +func (db *DB) ScanRow(ctx context.Context, rows *sql.Rows, dest ...interface{}) error { + model, err := newModel(db, dest) + if err != nil { + return err + } + + rs, ok := model.(rowScanner) + if !ok { + return fmt.Errorf("bun: %T does not support ScanRow", model) + } + + return rs.ScanRow(ctx, rows) +} + +func (db *DB) AddQueryHook(hook QueryHook) { + db.queryHooks = append(db.queryHooks, hook) +} + +func (db *DB) Table(typ reflect.Type) *schema.Table { + return db.dialect.Tables().Get(typ) +} + +func (db *DB) RegisterModel(models ...interface{}) { + db.dialect.Tables().Register(models...) +} + +func (db *DB) clone() *DB { + clone := *db + + l := len(clone.queryHooks) + clone.queryHooks = clone.queryHooks[:l:l] + + return &clone +} + +func (db *DB) WithNamedArg(name string, value interface{}) *DB { + clone := db.clone() + clone.fmter = clone.fmter.WithNamedArg(name, value) + return clone +} + +func (db *DB) Formatter() schema.Formatter { + return db.fmter +} + +//------------------------------------------------------------------------------ + +func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) { + return db.ExecContext(context.Background(), query, args...) +} + +func (db *DB) ExecContext( + ctx context.Context, query string, args ...interface{}, +) (sql.Result, error) { + ctx, event := db.beforeQuery(ctx, nil, query, args) + res, err := db.DB.ExecContext(ctx, db.format(query, args)) + db.afterQuery(ctx, event, res, err) + return res, err +} + +func (db *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { + return db.QueryContext(context.Background(), query, args...) +} + +func (db *DB) QueryContext( + ctx context.Context, query string, args ...interface{}, +) (*sql.Rows, error) { + ctx, event := db.beforeQuery(ctx, nil, query, args) + rows, err := db.DB.QueryContext(ctx, db.format(query, args)) + db.afterQuery(ctx, event, nil, err) + return rows, err +} + +func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row { + return db.QueryRowContext(context.Background(), query, args...) +} + +func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + ctx, event := db.beforeQuery(ctx, nil, query, args) + row := db.DB.QueryRowContext(ctx, db.format(query, args)) + db.afterQuery(ctx, event, nil, row.Err()) + return row +} + +func (db *DB) format(query string, args []interface{}) string { + return db.fmter.FormatQuery(query, args...) +} + +//------------------------------------------------------------------------------ + +type Conn struct { + db *DB + *sql.Conn +} + +func (db *DB) Conn(ctx context.Context) (Conn, error) { + conn, err := db.DB.Conn(ctx) + if err != nil { + return Conn{}, err + } + return Conn{ + db: db, + Conn: conn, + }, nil +} + +func (c Conn) ExecContext( + ctx context.Context, query string, args ...interface{}, +) (sql.Result, error) { + ctx, event := c.db.beforeQuery(ctx, nil, query, args) + res, err := c.Conn.ExecContext(ctx, c.db.format(query, args)) + c.db.afterQuery(ctx, event, res, err) + return res, err +} + +func (c Conn) QueryContext( + ctx context.Context, query string, args ...interface{}, +) (*sql.Rows, error) { + ctx, event := c.db.beforeQuery(ctx, nil, query, args) + rows, err := c.Conn.QueryContext(ctx, c.db.format(query, args)) + c.db.afterQuery(ctx, event, nil, err) + return rows, err +} + +func (c Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + ctx, event := c.db.beforeQuery(ctx, nil, query, args) + row := c.Conn.QueryRowContext(ctx, c.db.format(query, args)) + c.db.afterQuery(ctx, event, nil, row.Err()) + return row +} + +func (c Conn) NewValues(model interface{}) *ValuesQuery { + return NewValuesQuery(c.db, model).Conn(c) +} + +func (c Conn) NewSelect() *SelectQuery { + return NewSelectQuery(c.db).Conn(c) +} + +func (c Conn) NewInsert() *InsertQuery { + return NewInsertQuery(c.db).Conn(c) +} + +func (c Conn) NewUpdate() *UpdateQuery { + return NewUpdateQuery(c.db).Conn(c) +} + +func (c Conn) NewDelete() *DeleteQuery { + return NewDeleteQuery(c.db).Conn(c) +} + +func (c Conn) NewCreateTable() *CreateTableQuery { + return NewCreateTableQuery(c.db).Conn(c) +} + +func (c Conn) NewDropTable() *DropTableQuery { + return NewDropTableQuery(c.db).Conn(c) +} + +func (c Conn) NewCreateIndex() *CreateIndexQuery { + return NewCreateIndexQuery(c.db).Conn(c) +} + +func (c Conn) NewDropIndex() *DropIndexQuery { + return NewDropIndexQuery(c.db).Conn(c) +} + +func (c Conn) NewTruncateTable() *TruncateTableQuery { + return NewTruncateTableQuery(c.db).Conn(c) +} + +func (c Conn) NewAddColumn() *AddColumnQuery { + return NewAddColumnQuery(c.db).Conn(c) +} + +func (c Conn) NewDropColumn() *DropColumnQuery { + return NewDropColumnQuery(c.db).Conn(c) +} + +//------------------------------------------------------------------------------ + +type Stmt struct { + *sql.Stmt +} + +func (db *DB) Prepare(query string) (Stmt, error) { + return db.PrepareContext(context.Background(), query) +} + +func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) { + stmt, err := db.DB.PrepareContext(ctx, query) + if err != nil { + return Stmt{}, err + } + return Stmt{Stmt: stmt}, nil +} + +//------------------------------------------------------------------------------ + +type Tx struct { + db *DB + *sql.Tx +} + +// RunInTx runs the function in a transaction. If the function returns an error, +// the transaction is rolled back. Otherwise, the transaction is committed. +func (db *DB) RunInTx( + ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx Tx) error, +) error { + tx, err := db.BeginTx(ctx, opts) + if err != nil { + return err + } + defer tx.Rollback() //nolint:errcheck + + if err := fn(ctx, tx); err != nil { + return err + } + return tx.Commit() +} + +func (db *DB) Begin() (Tx, error) { + return db.BeginTx(context.Background(), nil) +} + +func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { + tx, err := db.DB.BeginTx(ctx, opts) + if err != nil { + return Tx{}, err + } + return Tx{ + db: db, + Tx: tx, + }, nil +} + +func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) { + return tx.ExecContext(context.TODO(), query, args...) +} + +func (tx Tx) ExecContext( + ctx context.Context, query string, args ...interface{}, +) (sql.Result, error) { + ctx, event := tx.db.beforeQuery(ctx, nil, query, args) + res, err := tx.Tx.ExecContext(ctx, tx.db.format(query, args)) + tx.db.afterQuery(ctx, event, res, err) + return res, err +} + +func (tx Tx) Query(query string, args ...interface{}) (*sql.Rows, error) { + return tx.QueryContext(context.TODO(), query, args...) +} + +func (tx Tx) QueryContext( + ctx context.Context, query string, args ...interface{}, +) (*sql.Rows, error) { + ctx, event := tx.db.beforeQuery(ctx, nil, query, args) + rows, err := tx.Tx.QueryContext(ctx, tx.db.format(query, args)) + tx.db.afterQuery(ctx, event, nil, err) + return rows, err +} + +func (tx Tx) QueryRow(query string, args ...interface{}) *sql.Row { + return tx.QueryRowContext(context.TODO(), query, args...) +} + +func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + ctx, event := tx.db.beforeQuery(ctx, nil, query, args) + row := tx.Tx.QueryRowContext(ctx, tx.db.format(query, args)) + tx.db.afterQuery(ctx, event, nil, row.Err()) + return row +} + +//------------------------------------------------------------------------------ + +func (tx Tx) NewValues(model interface{}) *ValuesQuery { + return NewValuesQuery(tx.db, model).Conn(tx) +} + +func (tx Tx) NewSelect() *SelectQuery { + return NewSelectQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewInsert() *InsertQuery { + return NewInsertQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewUpdate() *UpdateQuery { + return NewUpdateQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewDelete() *DeleteQuery { + return NewDeleteQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewCreateTable() *CreateTableQuery { + return NewCreateTableQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewDropTable() *DropTableQuery { + return NewDropTableQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewCreateIndex() *CreateIndexQuery { + return NewCreateIndexQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewDropIndex() *DropIndexQuery { + return NewDropIndexQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewTruncateTable() *TruncateTableQuery { + return NewTruncateTableQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewAddColumn() *AddColumnQuery { + return NewAddColumnQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewDropColumn() *DropColumnQuery { + return NewDropColumnQuery(tx.db).Conn(tx) +} + +//------------------------------------------------------------------------------0 + +func (db *DB) makeQueryBytes() []byte { + // TODO: make this configurable? + return make([]byte, 0, 4096) +} + +//------------------------------------------------------------------------------ + +type result struct { + r sql.Result + n int +} + +func (r result) RowsAffected() (int64, error) { + if r.r != nil { + return r.r.RowsAffected() + } + return int64(r.n), nil +} + +func (r result) LastInsertId() (int64, error) { + if r.r != nil { + return r.r.LastInsertId() + } + return 0, errors.New("LastInsertId is not available") +} diff --git a/vendor/github.com/uptrace/bun/dialect/append.go b/vendor/github.com/uptrace/bun/dialect/append.go new file mode 100644 index 000000000..7040c5155 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/append.go @@ -0,0 +1,178 @@ +package dialect + +import ( + "encoding/hex" + "math" + "strconv" + "time" + "unicode/utf8" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/internal/parser" +) + +func AppendError(b []byte, err error) []byte { + b = append(b, "?!("...) + b = append(b, err.Error()...) + b = append(b, ')') + return b +} + +func AppendNull(b []byte) []byte { + return append(b, "NULL"...) +} + +func AppendBool(b []byte, v bool) []byte { + if v { + return append(b, "TRUE"...) + } + return append(b, "FALSE"...) +} + +func AppendFloat32(b []byte, v float32) []byte { + return appendFloat(b, float64(v), 32) +} + +func AppendFloat64(b []byte, v float64) []byte { + return appendFloat(b, v, 64) +} + +func appendFloat(b []byte, v float64, bitSize int) []byte { + switch { + case math.IsNaN(v): + return append(b, "'NaN'"...) + case math.IsInf(v, 1): + return append(b, "'Infinity'"...) + case math.IsInf(v, -1): + return append(b, "'-Infinity'"...) + default: + return strconv.AppendFloat(b, v, 'f', -1, bitSize) + } +} + +func AppendString(b []byte, s string) []byte { + b = append(b, '\'') + for _, r := range s { + if r == '\000' { + continue + } + + if r == '\'' { + b = append(b, '\'', '\'') + continue + } + + if r < utf8.RuneSelf { + b = append(b, byte(r)) + continue + } + + l := len(b) + if cap(b)-l < utf8.UTFMax { + b = append(b, make([]byte, utf8.UTFMax)...) + } + n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) + b = b[:l+n] + } + b = append(b, '\'') + return b +} + +func AppendBytes(b []byte, bytes []byte) []byte { + if bytes == nil { + return AppendNull(b) + } + + b = append(b, `'\x`...) + + s := len(b) + b = append(b, make([]byte, hex.EncodedLen(len(bytes)))...) + hex.Encode(b[s:], bytes) + + b = append(b, '\'') + + return b +} + +func AppendTime(b []byte, tm time.Time) []byte { + if tm.IsZero() { + return AppendNull(b) + } + b = append(b, '\'') + b = tm.UTC().AppendFormat(b, "2006-01-02 15:04:05.999999-07:00") + b = append(b, '\'') + return b +} + +func AppendJSON(b, jsonb []byte) []byte { + b = append(b, '\'') + + p := parser.New(jsonb) + for p.Valid() { + c := p.Read() + switch c { + case '"': + b = append(b, '"') + case '\'': + b = append(b, "''"...) + case '\000': + continue + case '\\': + if p.SkipBytes([]byte("u0000")) { + b = append(b, `\\u0000`...) + } else { + b = append(b, '\\') + if p.Valid() { + b = append(b, p.Read()) + } + } + default: + b = append(b, c) + } + } + + b = append(b, '\'') + + return b +} + +//------------------------------------------------------------------------------ + +func AppendIdent(b []byte, field string, quote byte) []byte { + return appendIdent(b, internal.Bytes(field), quote) +} + +func appendIdent(b, src []byte, quote byte) []byte { + var quoted bool +loop: + for _, c := range src { + switch c { + case '*': + if !quoted { + b = append(b, '*') + continue loop + } + case '.': + if quoted { + b = append(b, quote) + quoted = false + } + b = append(b, '.') + continue loop + } + + if !quoted { + b = append(b, quote) + quoted = true + } + if c == quote { + b = append(b, quote, quote) + } else { + b = append(b, c) + } + } + if quoted { + b = append(b, quote) + } + return b +} diff --git a/vendor/github.com/uptrace/bun/dialect/dialect.go b/vendor/github.com/uptrace/bun/dialect/dialect.go new file mode 100644 index 000000000..9ff8b2461 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/dialect.go @@ -0,0 +1,26 @@ +package dialect + +type Name int + +func (n Name) String() string { + switch n { + case PG: + return "pg" + case SQLite: + return "sqlite" + case MySQL5: + return "mysql5" + case MySQL8: + return "mysql8" + default: + return "invalid" + } +} + +const ( + Invalid Name = iota + PG + SQLite + MySQL5 + MySQL8 +) diff --git a/vendor/github.com/uptrace/bun/dialect/feature/feature.go b/vendor/github.com/uptrace/bun/dialect/feature/feature.go new file mode 100644 index 000000000..ff8f1d625 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/feature/feature.go @@ -0,0 +1,22 @@ +package feature + +import "github.com/uptrace/bun/internal" + +type Feature = internal.Flag + +const DefaultFeatures = Returning | TableCascade + +const ( + Returning Feature = 1 << iota + DefaultPlaceholder + DoubleColonCast + ValuesRow + UpdateMultiTable + InsertTableAlias + DeleteTableAlias + AutoIncrement + TableCascade + TableIdentity + TableTruncate + OnDuplicateKey +) diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/LICENSE b/vendor/github.com/uptrace/bun/dialect/pgdialect/LICENSE new file mode 100644 index 000000000..7ec81810c --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/LICENSE @@ -0,0 +1,24 @@ +Copyright (c) 2021 Vladimir Mihailenco. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go new file mode 100644 index 000000000..475621197 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go @@ -0,0 +1,303 @@ +package pgdialect + +import ( + "database/sql/driver" + "fmt" + "reflect" + "strconv" + "time" + "unicode/utf8" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/schema" +) + +var ( + driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + + stringType = reflect.TypeOf((*string)(nil)).Elem() + sliceStringType = reflect.TypeOf([]string(nil)) + + intType = reflect.TypeOf((*int)(nil)).Elem() + sliceIntType = reflect.TypeOf([]int(nil)) + + int64Type = reflect.TypeOf((*int64)(nil)).Elem() + sliceInt64Type = reflect.TypeOf([]int64(nil)) + + float64Type = reflect.TypeOf((*float64)(nil)).Elem() + sliceFloat64Type = reflect.TypeOf([]float64(nil)) +) + +func customAppender(typ reflect.Type) schema.AppenderFunc { + switch typ.Kind() { + case reflect.Uint32: + return appendUint32ValueAsInt + case reflect.Uint, reflect.Uint64: + return appendUint64ValueAsInt + } + return nil +} + +func appendUint32ValueAsInt(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return strconv.AppendInt(b, int64(int32(v.Uint())), 10) +} + +func appendUint64ValueAsInt(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return strconv.AppendInt(b, int64(v.Uint()), 10) +} + +//------------------------------------------------------------------------------ + +func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte { + switch v := v.(type) { + case int64: + return strconv.AppendInt(b, v, 10) + case float64: + return dialect.AppendFloat64(b, v) + case bool: + return dialect.AppendBool(b, v) + case []byte: + return dialect.AppendBytes(b, v) + case string: + return arrayAppendString(b, v) + case time.Time: + return dialect.AppendTime(b, v) + default: + err := fmt.Errorf("pgdialect: can't append %T", v) + return dialect.AppendError(b, err) + } +} + +func arrayElemAppender(typ reflect.Type) schema.AppenderFunc { + if typ.Kind() == reflect.String { + return arrayAppendStringValue + } + if typ.Implements(driverValuerType) { + return arrayAppendDriverValue + } + return schema.Appender(typ, customAppender) +} + +func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return arrayAppendString(b, v.String()) +} + +func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + iface, err := v.Interface().(driver.Valuer).Value() + if err != nil { + return dialect.AppendError(b, err) + } + return arrayAppend(fmter, b, iface) +} + +//------------------------------------------------------------------------------ + +func arrayAppender(typ reflect.Type) schema.AppenderFunc { + kind := typ.Kind() + if kind == reflect.Ptr { + typ = typ.Elem() + kind = typ.Kind() + } + + switch kind { + case reflect.Slice, reflect.Array: + // ok: + default: + return nil + } + + elemType := typ.Elem() + + if kind == reflect.Slice { + switch elemType { + case stringType: + return appendStringSliceValue + case intType: + return appendIntSliceValue + case int64Type: + return appendInt64SliceValue + case float64Type: + return appendFloat64SliceValue + } + } + + appendElem := arrayElemAppender(elemType) + if appendElem == nil { + panic(fmt.Errorf("pgdialect: %s is not supported", typ)) + } + + return func(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + kind := v.Kind() + switch kind { + case reflect.Ptr, reflect.Slice: + if v.IsNil() { + return dialect.AppendNull(b) + } + } + + if kind == reflect.Ptr { + v = v.Elem() + } + + b = append(b, '\'') + + b = append(b, '{') + for i := 0; i < v.Len(); i++ { + elem := v.Index(i) + b = appendElem(fmter, b, elem) + b = append(b, ',') + } + if v.Len() > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b + } +} + +func appendStringSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ss := v.Convert(sliceStringType).Interface().([]string) + return appendStringSlice(b, ss) +} + +func appendStringSlice(b []byte, ss []string) []byte { + if ss == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, s := range ss { + b = arrayAppendString(b, s) + b = append(b, ',') + } + if len(ss) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendIntSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ints := v.Convert(sliceIntType).Interface().([]int) + return appendIntSlice(b, ints) +} + +func appendIntSlice(b []byte, ints []int) []byte { + if ints == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range ints { + b = strconv.AppendInt(b, int64(n), 10) + b = append(b, ',') + } + if len(ints) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendInt64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ints := v.Convert(sliceInt64Type).Interface().([]int64) + return appendInt64Slice(b, ints) +} + +func appendInt64Slice(b []byte, ints []int64) []byte { + if ints == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range ints { + b = strconv.AppendInt(b, n, 10) + b = append(b, ',') + } + if len(ints) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendFloat64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + floats := v.Convert(sliceFloat64Type).Interface().([]float64) + return appendFloat64Slice(b, floats) +} + +func appendFloat64Slice(b []byte, floats []float64) []byte { + if floats == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range floats { + b = dialect.AppendFloat64(b, n) + b = append(b, ',') + } + if len(floats) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +//------------------------------------------------------------------------------ + +func arrayAppendString(b []byte, s string) []byte { + b = append(b, '"') + for _, r := range s { + switch r { + case 0: + // ignore + case '\'': + b = append(b, "'''"...) + case '"': + b = append(b, '\\', '"') + case '\\': + b = append(b, '\\', '\\') + default: + if r < utf8.RuneSelf { + b = append(b, byte(r)) + break + } + l := len(b) + if cap(b)-l < utf8.UTFMax { + b = append(b, make([]byte, utf8.UTFMax)...) + } + n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) + b = b[:l+n] + } + } + b = append(b, '"') + return b +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go new file mode 100644 index 000000000..57f5a4384 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go @@ -0,0 +1,65 @@ +package pgdialect + +import ( + "database/sql" + "fmt" + "reflect" + + "github.com/uptrace/bun/schema" +) + +type ArrayValue struct { + v reflect.Value + + append schema.AppenderFunc + scan schema.ScannerFunc +} + +// Array accepts a slice and returns a wrapper for working with PostgreSQL +// array data type. +// +// For struct fields you can use array tag: +// +// Emails []string `bun:",array"` +func Array(vi interface{}) *ArrayValue { + v := reflect.ValueOf(vi) + if !v.IsValid() { + panic(fmt.Errorf("bun: Array(nil)")) + } + + return &ArrayValue{ + v: v, + + append: arrayAppender(v.Type()), + scan: arrayScanner(v.Type()), + } +} + +var ( + _ schema.QueryAppender = (*ArrayValue)(nil) + _ sql.Scanner = (*ArrayValue)(nil) +) + +func (a *ArrayValue) AppendQuery(fmter schema.Formatter, b []byte) ([]byte, error) { + if a.append == nil { + panic(fmt.Errorf("bun: Array(unsupported %s)", a.v.Type())) + } + return a.append(fmter, b, a.v), nil +} + +func (a *ArrayValue) Scan(src interface{}) error { + if a.scan == nil { + return fmt.Errorf("bun: Array(unsupported %s)", a.v.Type()) + } + if a.v.Kind() != reflect.Ptr { + return fmt.Errorf("bun: Array(non-pointer %s)", a.v.Type()) + } + return a.scan(a.v, src) +} + +func (a *ArrayValue) Value() interface{} { + if a.v.IsValid() { + return a.v.Interface() + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/array_parser.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_parser.go new file mode 100644 index 000000000..1c927fca0 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_parser.go @@ -0,0 +1,146 @@ +package pgdialect + +import ( + "bytes" + "fmt" + "io" +) + +type arrayParser struct { + b []byte + i int + + buf []byte + err error +} + +func newArrayParser(b []byte) *arrayParser { + p := &arrayParser{ + b: b, + i: 1, + } + if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' { + p.err = fmt.Errorf("bun: can't parse array: %q", b) + } + return p +} + +func (p *arrayParser) NextElem() ([]byte, error) { + if p.err != nil { + return nil, p.err + } + + c, err := p.readByte() + if err != nil { + return nil, err + } + + switch c { + case '}': + return nil, io.EOF + case '"': + b, err := p.readSubstring() + if err != nil { + return nil, err + } + + if p.peek() == ',' { + p.skipNext() + } + + return b, nil + default: + b := p.readSimple() + if bytes.Equal(b, []byte("NULL")) { + b = nil + } + + if p.peek() == ',' { + p.skipNext() + } + + return b, nil + } +} + +func (p *arrayParser) readSimple() []byte { + p.unreadByte() + + if i := bytes.IndexByte(p.b[p.i:], ','); i >= 0 { + b := p.b[p.i : p.i+i] + p.i += i + return b + } + + b := p.b[p.i : len(p.b)-1] + p.i = len(p.b) - 1 + return b +} + +func (p *arrayParser) readSubstring() ([]byte, error) { + c, err := p.readByte() + if err != nil { + return nil, err + } + + p.buf = p.buf[:0] + for { + if c == '"' { + break + } + + next, err := p.readByte() + if err != nil { + return nil, err + } + + if c == '\\' { + switch next { + case '\\', '"': + p.buf = append(p.buf, next) + + c, err = p.readByte() + if err != nil { + return nil, err + } + default: + p.buf = append(p.buf, '\\') + c = next + } + continue + } + + p.buf = append(p.buf, c) + c = next + } + + return p.buf, nil +} + +func (p *arrayParser) valid() bool { + return p.i < len(p.b) +} + +func (p *arrayParser) readByte() (byte, error) { + if p.valid() { + c := p.b[p.i] + p.i++ + return c, nil + } + return 0, io.EOF +} + +func (p *arrayParser) unreadByte() { + p.i-- +} + +func (p *arrayParser) peek() byte { + if p.valid() { + return p.b[p.i] + } + return 0 +} + +func (p *arrayParser) skipNext() { + p.i++ +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/array_scan.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_scan.go new file mode 100644 index 000000000..33d31f325 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_scan.go @@ -0,0 +1,302 @@ +package pgdialect + +import ( + "fmt" + "io" + "reflect" + "strconv" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +func arrayScanner(typ reflect.Type) schema.ScannerFunc { + kind := typ.Kind() + if kind == reflect.Ptr { + typ = typ.Elem() + kind = typ.Kind() + } + + switch kind { + case reflect.Slice, reflect.Array: + // ok: + default: + return nil + } + + elemType := typ.Elem() + + if kind == reflect.Slice { + switch elemType { + case stringType: + return scanStringSliceValue + case intType: + return scanIntSliceValue + case int64Type: + return scanInt64SliceValue + case float64Type: + return scanFloat64SliceValue + } + } + + scanElem := schema.Scanner(elemType) + return func(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + kind := dest.Kind() + + if src == nil { + if kind != reflect.Slice || !dest.IsNil() { + dest.Set(reflect.Zero(dest.Type())) + } + return nil + } + + if kind == reflect.Slice { + if dest.IsNil() { + dest.Set(reflect.MakeSlice(dest.Type(), 0, 0)) + } else if dest.Len() > 0 { + dest.Set(dest.Slice(0, 0)) + } + } + + b, err := toBytes(src) + if err != nil { + return err + } + + p := newArrayParser(b) + nextValue := internal.MakeSliceNextElemFunc(dest) + for { + elem, err := p.NextElem() + if err != nil { + if err == io.EOF { + break + } + return err + } + + elemValue := nextValue() + if err := scanElem(elemValue, elem); err != nil { + return err + } + } + + return nil + } +} + +func scanStringSliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := decodeStringSlice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func decodeStringSlice(src interface{}) ([]string, error) { + if src == nil { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]string, 0) + + p := newArrayParser(b) + for { + elem, err := p.NextElem() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + slice = append(slice, string(elem)) + } + + return slice, nil +} + +func scanIntSliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := decodeIntSlice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func decodeIntSlice(src interface{}) ([]int, error) { + if src == nil { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]int, 0) + + p := newArrayParser(b) + for { + elem, err := p.NextElem() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := strconv.Atoi(bytesToString(elem)) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + + return slice, nil +} + +func scanInt64SliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := decodeInt64Slice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func decodeInt64Slice(src interface{}) ([]int64, error) { + if src == nil { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]int64, 0) + + p := newArrayParser(b) + for { + elem, err := p.NextElem() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := strconv.ParseInt(bytesToString(elem), 10, 64) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + + return slice, nil +} + +func scanFloat64SliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := scanFloat64Slice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func scanFloat64Slice(src interface{}) ([]float64, error) { + if src == -1 { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]float64, 0) + + p := newArrayParser(b) + for { + elem, err := p.NextElem() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := strconv.ParseFloat(bytesToString(elem), 64) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + + return slice, nil +} + +func toBytes(src interface{}) ([]byte, error) { + switch src := src.(type) { + case string: + return stringToBytes(src), nil + case []byte: + return src, nil + default: + return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src) + } +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go new file mode 100644 index 000000000..fb210751b --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go @@ -0,0 +1,150 @@ +package pgdialect + +import ( + "database/sql" + "reflect" + "strconv" + "sync" + "time" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/schema" +) + +type Dialect struct { + tables *schema.Tables + features feature.Feature + + appenderMap sync.Map + scannerMap sync.Map +} + +func New() *Dialect { + d := new(Dialect) + d.tables = schema.NewTables(d) + d.features = feature.Returning | + feature.DefaultPlaceholder | + feature.DoubleColonCast | + feature.InsertTableAlias | + feature.DeleteTableAlias | + feature.TableCascade | + feature.TableIdentity | + feature.TableTruncate + return d +} + +func (d *Dialect) Init(*sql.DB) {} + +func (d *Dialect) Name() dialect.Name { + return dialect.PG +} + +func (d *Dialect) Features() feature.Feature { + return d.features +} + +func (d *Dialect) Tables() *schema.Tables { + return d.tables +} + +func (d *Dialect) OnTable(table *schema.Table) { + for _, field := range table.FieldMap { + d.onField(field) + } +} + +func (d *Dialect) onField(field *schema.Field) { + field.DiscoveredSQLType = fieldSQLType(field) + + if field.AutoIncrement { + switch field.DiscoveredSQLType { + case sqltype.SmallInt: + field.CreateTableSQLType = pgTypeSmallSerial + case sqltype.Integer: + field.CreateTableSQLType = pgTypeSerial + case sqltype.BigInt: + field.CreateTableSQLType = pgTypeBigSerial + } + } + + if field.Tag.HasOption("array") { + field.Append = arrayAppender(field.IndirectType) + field.Scan = arrayScanner(field.IndirectType) + } +} + +func (d *Dialect) IdentQuote() byte { + return '"' +} + +func (d *Dialect) Append(fmter schema.Formatter, b []byte, v interface{}) []byte { + switch v := v.(type) { + case nil: + return dialect.AppendNull(b) + case bool: + return dialect.AppendBool(b, v) + case int: + return strconv.AppendInt(b, int64(v), 10) + case int32: + return strconv.AppendInt(b, int64(v), 10) + case int64: + return strconv.AppendInt(b, v, 10) + case uint: + return strconv.AppendInt(b, int64(v), 10) + case uint32: + return strconv.AppendInt(b, int64(v), 10) + case uint64: + return strconv.AppendInt(b, int64(v), 10) + case float32: + return dialect.AppendFloat32(b, v) + case float64: + return dialect.AppendFloat64(b, v) + case string: + return dialect.AppendString(b, v) + case time.Time: + return dialect.AppendTime(b, v) + case []byte: + return dialect.AppendBytes(b, v) + case schema.QueryAppender: + return schema.AppendQueryAppender(fmter, b, v) + default: + vv := reflect.ValueOf(v) + if vv.Kind() == reflect.Ptr && vv.IsNil() { + return dialect.AppendNull(b) + } + appender := d.Appender(vv.Type()) + return appender(fmter, b, vv) + } +} + +func (d *Dialect) Appender(typ reflect.Type) schema.AppenderFunc { + if v, ok := d.appenderMap.Load(typ); ok { + return v.(schema.AppenderFunc) + } + + fn := schema.Appender(typ, customAppender) + + if v, ok := d.appenderMap.LoadOrStore(typ, fn); ok { + return v.(schema.AppenderFunc) + } + return fn +} + +func (d *Dialect) FieldAppender(field *schema.Field) schema.AppenderFunc { + return schema.FieldAppender(d, field) +} + +func (d *Dialect) Scanner(typ reflect.Type) schema.ScannerFunc { + if v, ok := d.scannerMap.Load(typ); ok { + return v.(schema.ScannerFunc) + } + + fn := scanner(typ) + + if v, ok := d.scannerMap.LoadOrStore(typ, fn); ok { + return v.(schema.ScannerFunc) + } + return fn +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod new file mode 100644 index 000000000..0cad1ce5b --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod @@ -0,0 +1,7 @@ +module github.com/uptrace/bun/dialect/pgdialect + +go 1.16 + +replace github.com/uptrace/bun => ../.. + +require github.com/uptrace/bun v0.4.3 diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/go.sum b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.sum new file mode 100644 index 000000000..4d0f1c1bb --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.sum @@ -0,0 +1,22 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= +github.com/vmihailenco/msgpack/v5 v5.3.4 h1:qMKAwOV+meBw2Y8k9cVwAy7qErtYCwBzZ2ellBfvnqc= +github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 h1:RqytpXGR1iVNX7psjB3ff8y7sNFinVFvkx1c8SjBkio= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/safe.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/safe.go new file mode 100644 index 000000000..dff30b9c5 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/safe.go @@ -0,0 +1,11 @@ +// +build appengine + +package pgdialect + +func bytesToString(b []byte) string { + return string(b) +} + +func stringToBytes(s string) []byte { + return []byte(s) +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/scan.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/scan.go new file mode 100644 index 000000000..9e22282f5 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/scan.go @@ -0,0 +1,28 @@ +package pgdialect + +import ( + "fmt" + "reflect" + + "github.com/uptrace/bun/schema" +) + +func scanner(typ reflect.Type) schema.ScannerFunc { + if typ.Kind() == reflect.Interface { + return scanInterface + } + return schema.Scanner(typ) +} + +func scanInterface(dest reflect.Value, src interface{}) error { + if dest.IsNil() { + dest.Set(reflect.ValueOf(src)) + return nil + } + + dest = dest.Elem() + if fn := scanner(dest.Type()); fn != nil { + return fn(dest, src) + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go new file mode 100644 index 000000000..4c2d8075d --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go @@ -0,0 +1,104 @@ +package pgdialect + +import ( + "encoding/json" + "net" + "reflect" + "time" + + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/schema" +) + +const ( + // Date / Time + pgTypeTimestampTz = "TIMESTAMPTZ" // Timestamp with a time zone + pgTypeDate = "DATE" // Date + pgTypeTime = "TIME" // Time without a time zone + pgTypeTimeTz = "TIME WITH TIME ZONE" // Time with a time zone + pgTypeInterval = "INTERVAL" // Time Interval + + // Network Addresses + pgTypeInet = "INET" // IPv4 or IPv6 hosts and networks + pgTypeCidr = "CIDR" // IPv4 or IPv6 networks + pgTypeMacaddr = "MACADDR" // MAC addresses + + // Serial Types + pgTypeSmallSerial = "SMALLSERIAL" // 2 byte autoincrementing integer + pgTypeSerial = "SERIAL" // 4 byte autoincrementing integer + pgTypeBigSerial = "BIGSERIAL" // 8 byte autoincrementing integer + + // Character Types + pgTypeChar = "CHAR" // fixed length string (blank padded) + pgTypeText = "TEXT" // variable length string without limit + + // JSON Types + pgTypeJSON = "JSON" // text representation of json data + pgTypeJSONB = "JSONB" // binary representation of json data + + // Binary Data Types + pgTypeBytea = "BYTEA" // binary string +) + +var ( + timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + ipType = reflect.TypeOf((*net.IP)(nil)).Elem() + ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() + jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() +) + +func fieldSQLType(field *schema.Field) string { + if field.UserSQLType != "" { + return field.UserSQLType + } + + if v, ok := field.Tag.Options["composite"]; ok { + return v + } + + if _, ok := field.Tag.Options["hstore"]; ok { + return "hstore" + } + + if _, ok := field.Tag.Options["array"]; ok { + switch field.IndirectType.Kind() { + case reflect.Slice, reflect.Array: + sqlType := sqlType(field.IndirectType.Elem()) + return sqlType + "[]" + } + } + + return sqlType(field.IndirectType) +} + +func sqlType(typ reflect.Type) string { + switch typ { + case ipType: + return pgTypeInet + case ipNetType: + return pgTypeCidr + case jsonRawMessageType: + return pgTypeJSONB + } + + sqlType := schema.DiscoverSQLType(typ) + switch sqlType { + case sqltype.Timestamp: + sqlType = pgTypeTimestampTz + } + + switch typ.Kind() { + case reflect.Map, reflect.Struct: + if sqlType == sqltype.VarChar { + return pgTypeJSONB + } + return sqlType + case reflect.Array, reflect.Slice: + if typ.Elem().Kind() == reflect.Uint8 { + return pgTypeBytea + } + return pgTypeJSONB + } + + return sqlType +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/unsafe.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/unsafe.go new file mode 100644 index 000000000..2a02a20b1 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/unsafe.go @@ -0,0 +1,18 @@ +// +build !appengine + +package pgdialect + +import "unsafe" + +func bytesToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +func stringToBytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer( + &struct { + string + Cap int + }{s, len(s)}, + )) +} diff --git a/vendor/github.com/uptrace/bun/dialect/sqltype/sqltype.go b/vendor/github.com/uptrace/bun/dialect/sqltype/sqltype.go new file mode 100644 index 000000000..84a51d26d --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/sqltype/sqltype.go @@ -0,0 +1,14 @@ +package sqltype + +const ( + Boolean = "BOOLEAN" + SmallInt = "SMALLINT" + Integer = "INTEGER" + BigInt = "BIGINT" + Real = "REAL" + DoublePrecision = "DOUBLE PRECISION" + VarChar = "VARCHAR" + Timestamp = "TIMESTAMP" + JSON = "JSON" + JSONB = "JSONB" +) diff --git a/vendor/github.com/uptrace/bun/extra/bunjson/json.go b/vendor/github.com/uptrace/bun/extra/bunjson/json.go new file mode 100644 index 000000000..eff9d3f0e --- /dev/null +++ b/vendor/github.com/uptrace/bun/extra/bunjson/json.go @@ -0,0 +1,26 @@ +package bunjson + +import ( + "encoding/json" + "io" +) + +var _ Provider = (*StdProvider)(nil) + +type StdProvider struct{} + +func (StdProvider) Marshal(v interface{}) ([]byte, error) { + return json.Marshal(v) +} + +func (StdProvider) Unmarshal(data []byte, v interface{}) error { + return json.Unmarshal(data, v) +} + +func (StdProvider) NewEncoder(w io.Writer) Encoder { + return json.NewEncoder(w) +} + +func (StdProvider) NewDecoder(r io.Reader) Decoder { + return json.NewDecoder(r) +} diff --git a/vendor/github.com/uptrace/bun/extra/bunjson/provider.go b/vendor/github.com/uptrace/bun/extra/bunjson/provider.go new file mode 100644 index 000000000..7f810e122 --- /dev/null +++ b/vendor/github.com/uptrace/bun/extra/bunjson/provider.go @@ -0,0 +1,43 @@ +package bunjson + +import ( + "io" +) + +var provider Provider = StdProvider{} + +func SetProvider(p Provider) { + provider = p +} + +type Provider interface { + Marshal(v interface{}) ([]byte, error) + Unmarshal(data []byte, v interface{}) error + NewEncoder(w io.Writer) Encoder + NewDecoder(r io.Reader) Decoder +} + +type Decoder interface { + Decode(v interface{}) error + UseNumber() +} + +type Encoder interface { + Encode(v interface{}) error +} + +func Marshal(v interface{}) ([]byte, error) { + return provider.Marshal(v) +} + +func Unmarshal(data []byte, v interface{}) error { + return provider.Unmarshal(data, v) +} + +func NewEncoder(w io.Writer) Encoder { + return provider.NewEncoder(w) +} + +func NewDecoder(r io.Reader) Decoder { + return provider.NewDecoder(r) +} diff --git a/vendor/github.com/uptrace/bun/go.mod b/vendor/github.com/uptrace/bun/go.mod new file mode 100644 index 000000000..92def2a3d --- /dev/null +++ b/vendor/github.com/uptrace/bun/go.mod @@ -0,0 +1,12 @@ +module github.com/uptrace/bun + +go 1.16 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/jinzhu/inflection v1.0.0 + github.com/stretchr/testify v1.7.0 + github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc + github.com/vmihailenco/msgpack/v5 v5.3.4 + golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 // indirect +) diff --git a/vendor/github.com/uptrace/bun/go.sum b/vendor/github.com/uptrace/bun/go.sum new file mode 100644 index 000000000..3bf0a4a3f --- /dev/null +++ b/vendor/github.com/uptrace/bun/go.sum @@ -0,0 +1,23 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= +github.com/vmihailenco/msgpack/v5 v5.3.4 h1:qMKAwOV+meBw2Y8k9cVwAy7qErtYCwBzZ2ellBfvnqc= +github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 h1:RqytpXGR1iVNX7psjB3ff8y7sNFinVFvkx1c8SjBkio= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/vendor/github.com/uptrace/bun/hook.go b/vendor/github.com/uptrace/bun/hook.go new file mode 100644 index 000000000..4cfa68fa6 --- /dev/null +++ b/vendor/github.com/uptrace/bun/hook.go @@ -0,0 +1,98 @@ +package bun + +import ( + "context" + "database/sql" + "reflect" + "sync/atomic" + "time" + + "github.com/uptrace/bun/schema" +) + +type QueryEvent struct { + DB *DB + + QueryAppender schema.QueryAppender + Query string + QueryArgs []interface{} + + StartTime time.Time + Result sql.Result + Err error + + Stash map[interface{}]interface{} +} + +type QueryHook interface { + BeforeQuery(context.Context, *QueryEvent) context.Context + AfterQuery(context.Context, *QueryEvent) +} + +func (db *DB) beforeQuery( + ctx context.Context, + queryApp schema.QueryAppender, + query string, + queryArgs []interface{}, +) (context.Context, *QueryEvent) { + atomic.AddUint64(&db.stats.Queries, 1) + + if len(db.queryHooks) == 0 { + return ctx, nil + } + + event := &QueryEvent{ + DB: db, + + QueryAppender: queryApp, + Query: query, + QueryArgs: queryArgs, + + StartTime: time.Now(), + } + + for _, hook := range db.queryHooks { + ctx = hook.BeforeQuery(ctx, event) + } + + return ctx, event +} + +func (db *DB) afterQuery( + ctx context.Context, + event *QueryEvent, + res sql.Result, + err error, +) { + switch err { + case nil, sql.ErrNoRows: + // nothing + default: + atomic.AddUint64(&db.stats.Errors, 1) + } + + if event == nil { + return + } + + event.Result = res + event.Err = err + + db.afterQueryFromIndex(ctx, event, len(db.queryHooks)-1) +} + +func (db *DB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIndex int) { + for ; hookIndex >= 0; hookIndex-- { + db.queryHooks[hookIndex].AfterQuery(ctx, event) + } +} + +//------------------------------------------------------------------------------ + +func callBeforeScanHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(schema.BeforeScanHook).BeforeScan(ctx) +} + +func callAfterScanHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(schema.AfterScanHook).AfterScan(ctx) +} diff --git a/vendor/github.com/uptrace/bun/internal/flag.go b/vendor/github.com/uptrace/bun/internal/flag.go new file mode 100644 index 000000000..b42f59df7 --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/flag.go @@ -0,0 +1,16 @@ +package internal + +type Flag uint64 + +func (flag Flag) Has(other Flag) bool { + return flag&other == other +} + +func (flag Flag) Set(other Flag) Flag { + return flag | other +} + +func (flag Flag) Remove(other Flag) Flag { + flag &= ^other + return flag +} diff --git a/vendor/github.com/uptrace/bun/internal/hex.go b/vendor/github.com/uptrace/bun/internal/hex.go new file mode 100644 index 000000000..6fae2bb78 --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/hex.go @@ -0,0 +1,43 @@ +package internal + +import ( + fasthex "github.com/tmthrgd/go-hex" +) + +type HexEncoder struct { + b []byte + written bool +} + +func NewHexEncoder(b []byte) *HexEncoder { + return &HexEncoder{ + b: b, + } +} + +func (enc *HexEncoder) Bytes() []byte { + return enc.b +} + +func (enc *HexEncoder) Write(b []byte) (int, error) { + if !enc.written { + enc.b = append(enc.b, '\'') + enc.b = append(enc.b, `\x`...) + enc.written = true + } + + i := len(enc.b) + enc.b = append(enc.b, make([]byte, fasthex.EncodedLen(len(b)))...) + fasthex.Encode(enc.b[i:], b) + + return len(b), nil +} + +func (enc *HexEncoder) Close() error { + if enc.written { + enc.b = append(enc.b, '\'') + } else { + enc.b = append(enc.b, "NULL"...) + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/internal/logger.go b/vendor/github.com/uptrace/bun/internal/logger.go new file mode 100644 index 000000000..2e22a0893 --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/logger.go @@ -0,0 +1,27 @@ +package internal + +import ( + "fmt" + "log" + "os" +) + +var Warn = log.New(os.Stderr, "WARN: bun: ", log.LstdFlags) + +var Deprecated = log.New(os.Stderr, "DEPRECATED: bun: ", log.LstdFlags) + +type Logging interface { + Printf(format string, v ...interface{}) +} + +type logger struct { + log *log.Logger +} + +func (l *logger) Printf(format string, v ...interface{}) { + _ = l.log.Output(2, fmt.Sprintf(format, v...)) +} + +var Logger Logging = &logger{ + log: log.New(os.Stderr, "bun: ", log.LstdFlags|log.Lshortfile), +} diff --git a/vendor/github.com/uptrace/bun/internal/map_key.go b/vendor/github.com/uptrace/bun/internal/map_key.go new file mode 100644 index 000000000..bb5fcca8c --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/map_key.go @@ -0,0 +1,67 @@ +package internal + +import "reflect" + +var ifaceType = reflect.TypeOf((*interface{})(nil)).Elem() + +type MapKey struct { + iface interface{} +} + +func NewMapKey(is []interface{}) MapKey { + return MapKey{ + iface: newMapKey(is), + } +} + +func newMapKey(is []interface{}) interface{} { + switch len(is) { + case 1: + ptr := new([1]interface{}) + copy((*ptr)[:], is) + return *ptr + case 2: + ptr := new([2]interface{}) + copy((*ptr)[:], is) + return *ptr + case 3: + ptr := new([3]interface{}) + copy((*ptr)[:], is) + return *ptr + case 4: + ptr := new([4]interface{}) + copy((*ptr)[:], is) + return *ptr + case 5: + ptr := new([5]interface{}) + copy((*ptr)[:], is) + return *ptr + case 6: + ptr := new([6]interface{}) + copy((*ptr)[:], is) + return *ptr + case 7: + ptr := new([7]interface{}) + copy((*ptr)[:], is) + return *ptr + case 8: + ptr := new([8]interface{}) + copy((*ptr)[:], is) + return *ptr + case 9: + ptr := new([9]interface{}) + copy((*ptr)[:], is) + return *ptr + case 10: + ptr := new([10]interface{}) + copy((*ptr)[:], is) + return *ptr + default: + } + + at := reflect.New(reflect.ArrayOf(len(is), ifaceType)).Elem() + for i, v := range is { + *(at.Index(i).Addr().Interface().(*interface{})) = v + } + return at.Interface() +} diff --git a/vendor/github.com/uptrace/bun/internal/parser/parser.go b/vendor/github.com/uptrace/bun/internal/parser/parser.go new file mode 100644 index 000000000..cdfc0be16 --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/parser/parser.go @@ -0,0 +1,141 @@ +package parser + +import ( + "bytes" + "strconv" + + "github.com/uptrace/bun/internal" +) + +type Parser struct { + b []byte + i int +} + +func New(b []byte) *Parser { + return &Parser{ + b: b, + } +} + +func NewString(s string) *Parser { + return New(internal.Bytes(s)) +} + +func (p *Parser) Valid() bool { + return p.i < len(p.b) +} + +func (p *Parser) Bytes() []byte { + return p.b[p.i:] +} + +func (p *Parser) Read() byte { + if p.Valid() { + c := p.b[p.i] + p.Advance() + return c + } + return 0 +} + +func (p *Parser) Peek() byte { + if p.Valid() { + return p.b[p.i] + } + return 0 +} + +func (p *Parser) Advance() { + p.i++ +} + +func (p *Parser) Skip(skip byte) bool { + if p.Peek() == skip { + p.Advance() + return true + } + return false +} + +func (p *Parser) SkipBytes(skip []byte) bool { + if len(skip) > len(p.b[p.i:]) { + return false + } + if !bytes.Equal(p.b[p.i:p.i+len(skip)], skip) { + return false + } + p.i += len(skip) + return true +} + +func (p *Parser) ReadSep(sep byte) ([]byte, bool) { + ind := bytes.IndexByte(p.b[p.i:], sep) + if ind == -1 { + b := p.b[p.i:] + p.i = len(p.b) + return b, false + } + + b := p.b[p.i : p.i+ind] + p.i += ind + 1 + return b, true +} + +func (p *Parser) ReadIdentifier() (string, bool) { + if p.i < len(p.b) && p.b[p.i] == '(' { + s := p.i + 1 + if ind := bytes.IndexByte(p.b[s:], ')'); ind != -1 { + b := p.b[s : s+ind] + p.i = s + ind + 1 + return internal.String(b), false + } + } + + ind := len(p.b) - p.i + var alpha bool + for i, c := range p.b[p.i:] { + if isNum(c) { + continue + } + if isAlpha(c) || (i > 0 && alpha && c == '_') { + alpha = true + continue + } + ind = i + break + } + if ind == 0 { + return "", false + } + b := p.b[p.i : p.i+ind] + p.i += ind + return internal.String(b), !alpha +} + +func (p *Parser) ReadNumber() int { + ind := len(p.b) - p.i + for i, c := range p.b[p.i:] { + if !isNum(c) { + ind = i + break + } + } + if ind == 0 { + return 0 + } + n, err := strconv.Atoi(string(p.b[p.i : p.i+ind])) + if err != nil { + panic(err) + } + p.i += ind + return n +} + +func isNum(c byte) bool { + return c >= '0' && c <= '9' +} + +func isAlpha(c byte) bool { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') +} diff --git a/vendor/github.com/uptrace/bun/internal/safe.go b/vendor/github.com/uptrace/bun/internal/safe.go new file mode 100644 index 000000000..862ff0eb3 --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/safe.go @@ -0,0 +1,11 @@ +// +build appengine + +package internal + +func String(b []byte) string { + return string(b) +} + +func Bytes(s string) []byte { + return []byte(s) +} diff --git a/vendor/github.com/uptrace/bun/internal/tagparser/parser.go b/vendor/github.com/uptrace/bun/internal/tagparser/parser.go new file mode 100644 index 000000000..8ef89248c --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/tagparser/parser.go @@ -0,0 +1,147 @@ +package tagparser + +import ( + "strings" +) + +type Tag struct { + Name string + Options map[string]string +} + +func (t Tag) HasOption(name string) bool { + _, ok := t.Options[name] + return ok +} + +func Parse(s string) Tag { + p := parser{ + s: s, + } + p.parse() + return p.tag +} + +type parser struct { + s string + i int + + tag Tag + seenName bool // for empty names +} + +func (p *parser) setName(name string) { + if p.seenName { + p.addOption(name, "") + } else { + p.seenName = true + p.tag.Name = name + } +} + +func (p *parser) addOption(key, value string) { + p.seenName = true + if key == "" { + return + } + if p.tag.Options == nil { + p.tag.Options = make(map[string]string) + } + p.tag.Options[key] = value +} + +func (p *parser) parse() { + for p.valid() { + p.parseKeyValue() + if p.peek() == ',' { + p.i++ + } + } +} + +func (p *parser) parseKeyValue() { + start := p.i + + for p.valid() { + switch c := p.read(); c { + case ',': + key := p.s[start : p.i-1] + p.setName(key) + return + case ':': + key := p.s[start : p.i-1] + value := p.parseValue() + p.addOption(key, value) + return + case '"': + key := p.parseQuotedValue() + p.setName(key) + return + } + } + + key := p.s[start:p.i] + p.setName(key) +} + +func (p *parser) parseValue() string { + start := p.i + + for p.valid() { + switch c := p.read(); c { + case '"': + return p.parseQuotedValue() + case ',': + return p.s[start : p.i-1] + } + } + + if p.i == start { + return "" + } + return p.s[start:p.i] +} + +func (p *parser) parseQuotedValue() string { + if i := strings.IndexByte(p.s[p.i:], '"'); i >= 0 && p.s[p.i+i-1] != '\\' { + s := p.s[p.i : p.i+i] + p.i += i + 1 + return s + } + + b := make([]byte, 0, 16) + + for p.valid() { + switch c := p.read(); c { + case '\\': + b = append(b, p.read()) + case '"': + return string(b) + default: + b = append(b, c) + } + } + + return "" +} + +func (p *parser) valid() bool { + return p.i < len(p.s) +} + +func (p *parser) read() byte { + if !p.valid() { + return 0 + } + c := p.s[p.i] + p.i++ + return c +} + +func (p *parser) peek() byte { + if !p.valid() { + return 0 + } + c := p.s[p.i] + return c +} diff --git a/vendor/github.com/uptrace/bun/internal/time.go b/vendor/github.com/uptrace/bun/internal/time.go new file mode 100644 index 000000000..e4e0804b0 --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/time.go @@ -0,0 +1,41 @@ +package internal + +import ( + "fmt" + "time" +) + +const ( + dateFormat = "2006-01-02" + timeFormat = "15:04:05.999999999" + timestampFormat = "2006-01-02 15:04:05.999999999" + timestamptzFormat = "2006-01-02 15:04:05.999999999-07:00:00" + timestamptzFormat2 = "2006-01-02 15:04:05.999999999-07:00" + timestamptzFormat3 = "2006-01-02 15:04:05.999999999-07" +) + +func ParseTime(s string) (time.Time, error) { + switch l := len(s); { + case l < len("15:04:05"): + return time.Time{}, fmt.Errorf("bun: can't parse time=%q", s) + case l <= len(timeFormat): + if s[2] == ':' { + return time.ParseInLocation(timeFormat, s, time.UTC) + } + return time.ParseInLocation(dateFormat, s, time.UTC) + default: + if s[10] == 'T' { + return time.Parse(time.RFC3339Nano, s) + } + if c := s[l-9]; c == '+' || c == '-' { + return time.Parse(timestamptzFormat, s) + } + if c := s[l-6]; c == '+' || c == '-' { + return time.Parse(timestamptzFormat2, s) + } + if c := s[l-3]; c == '+' || c == '-' { + return time.Parse(timestamptzFormat3, s) + } + return time.ParseInLocation(timestampFormat, s, time.UTC) + } +} diff --git a/vendor/github.com/uptrace/bun/internal/underscore.go b/vendor/github.com/uptrace/bun/internal/underscore.go new file mode 100644 index 000000000..9de52fb7b --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/underscore.go @@ -0,0 +1,67 @@ +package internal + +func IsUpper(c byte) bool { + return c >= 'A' && c <= 'Z' +} + +func IsLower(c byte) bool { + return c >= 'a' && c <= 'z' +} + +func ToUpper(c byte) byte { + return c - 32 +} + +func ToLower(c byte) byte { + return c + 32 +} + +// Underscore converts "CamelCasedString" to "camel_cased_string". +func Underscore(s string) string { + r := make([]byte, 0, len(s)+5) + for i := 0; i < len(s); i++ { + c := s[i] + if IsUpper(c) { + if i > 0 && i+1 < len(s) && (IsLower(s[i-1]) || IsLower(s[i+1])) { + r = append(r, '_', ToLower(c)) + } else { + r = append(r, ToLower(c)) + } + } else { + r = append(r, c) + } + } + return string(r) +} + +func CamelCased(s string) string { + r := make([]byte, 0, len(s)) + upperNext := true + for i := 0; i < len(s); i++ { + c := s[i] + if c == '_' { + upperNext = true + continue + } + if upperNext { + if IsLower(c) { + c = ToUpper(c) + } + upperNext = false + } + r = append(r, c) + } + return string(r) +} + +func ToExported(s string) string { + if len(s) == 0 { + return s + } + if c := s[0]; IsLower(c) { + b := []byte(s) + b[0] = ToUpper(c) + return string(b) + } + return s +} diff --git a/vendor/github.com/uptrace/bun/internal/unsafe.go b/vendor/github.com/uptrace/bun/internal/unsafe.go new file mode 100644 index 000000000..4bc79701f --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/unsafe.go @@ -0,0 +1,20 @@ +// +build !appengine + +package internal + +import "unsafe" + +// String converts byte slice to string. +func String(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// Bytes converts string to byte slice. +func Bytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer( + &struct { + string + Cap int + }{s, len(s)}, + )) +} diff --git a/vendor/github.com/uptrace/bun/internal/util.go b/vendor/github.com/uptrace/bun/internal/util.go new file mode 100644 index 000000000..c831dc659 --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/util.go @@ -0,0 +1,57 @@ +package internal + +import ( + "reflect" +) + +func MakeSliceNextElemFunc(v reflect.Value) func() reflect.Value { + if v.Kind() == reflect.Array { + var pos int + return func() reflect.Value { + v := v.Index(pos) + pos++ + return v + } + } + + elemType := v.Type().Elem() + + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + return func() reflect.Value { + if v.Len() < v.Cap() { + v.Set(v.Slice(0, v.Len()+1)) + elem := v.Index(v.Len() - 1) + if elem.IsNil() { + elem.Set(reflect.New(elemType)) + } + return elem.Elem() + } + + elem := reflect.New(elemType) + v.Set(reflect.Append(v, elem)) + return elem.Elem() + } + } + + zero := reflect.Zero(elemType) + return func() reflect.Value { + if v.Len() < v.Cap() { + v.Set(v.Slice(0, v.Len()+1)) + return v.Index(v.Len() - 1) + } + + v.Set(reflect.Append(v, zero)) + return v.Index(v.Len() - 1) + } +} + +func Unwrap(err error) error { + u, ok := err.(interface { + Unwrap() error + }) + if !ok { + return nil + } + return u.Unwrap() +} diff --git a/vendor/github.com/uptrace/bun/join.go b/vendor/github.com/uptrace/bun/join.go new file mode 100644 index 000000000..4557f5bc0 --- /dev/null +++ b/vendor/github.com/uptrace/bun/join.go @@ -0,0 +1,308 @@ +package bun + +import ( + "context" + "reflect" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type join struct { + Parent *join + BaseModel tableModel + JoinModel tableModel + Relation *schema.Relation + + ApplyQueryFunc func(*SelectQuery) *SelectQuery + columns []schema.QueryWithArgs +} + +func (j *join) applyQuery(q *SelectQuery) { + if j.ApplyQueryFunc == nil { + return + } + + var table *schema.Table + var columns []schema.QueryWithArgs + + // Save state. + table, q.table = q.table, j.JoinModel.Table() + columns, q.columns = q.columns, nil + + q = j.ApplyQueryFunc(q) + + // Restore state. + q.table = table + j.columns, q.columns = q.columns, columns +} + +func (j *join) Select(ctx context.Context, q *SelectQuery) error { + switch j.Relation.Type { + case schema.HasManyRelation: + return j.selectMany(ctx, q) + case schema.ManyToManyRelation: + return j.selectM2M(ctx, q) + } + panic("not reached") +} + +func (j *join) selectMany(ctx context.Context, q *SelectQuery) error { + q = j.manyQuery(q) + if q == nil { + return nil + } + return q.Scan(ctx) +} + +func (j *join) manyQuery(q *SelectQuery) *SelectQuery { + hasManyModel := newHasManyModel(j) + if hasManyModel == nil { + return nil + } + + q = q.Model(hasManyModel) + + var where []byte + if len(j.Relation.JoinFields) > 1 { + where = append(where, '(') + } + where = appendColumns(where, j.JoinModel.Table().SQLAlias, j.Relation.JoinFields) + if len(j.Relation.JoinFields) > 1 { + where = append(where, ')') + } + where = append(where, " IN ("...) + where = appendChildValues( + q.db.Formatter(), + where, + j.JoinModel.Root(), + j.JoinModel.ParentIndex(), + j.Relation.BaseFields, + ) + where = append(where, ")"...) + q = q.Where(internal.String(where)) + + if j.Relation.PolymorphicField != nil { + q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue) + } + + j.applyQuery(q) + q = q.Apply(j.hasManyColumns) + + return q +} + +func (j *join) hasManyColumns(q *SelectQuery) *SelectQuery { + if j.Relation.M2MTable != nil { + q = q.ColumnExpr(string(j.Relation.M2MTable.SQLAlias) + ".*") + } + + b := make([]byte, 0, 32) + + if len(j.columns) > 0 { + for i, col := range j.columns { + if i > 0 { + b = append(b, ", "...) + } + + var err error + b, err = col.AppendQuery(q.db.fmter, b) + if err != nil { + q.err = err + return q + } + } + } else { + joinTable := j.JoinModel.Table() + b = appendColumns(b, joinTable.SQLAlias, joinTable.Fields) + } + + q = q.ColumnExpr(internal.String(b)) + + return q +} + +func (j *join) selectM2M(ctx context.Context, q *SelectQuery) error { + q = j.m2mQuery(q) + if q == nil { + return nil + } + return q.Scan(ctx) +} + +func (j *join) m2mQuery(q *SelectQuery) *SelectQuery { + fmter := q.db.fmter + + m2mModel := newM2MModel(j) + if m2mModel == nil { + return nil + } + q = q.Model(m2mModel) + + index := j.JoinModel.ParentIndex() + baseTable := j.BaseModel.Table() + + //nolint + var join []byte + join = append(join, "JOIN "...) + join = fmter.AppendQuery(join, string(j.Relation.M2MTable.Name)) + join = append(join, " AS "...) + join = append(join, j.Relation.M2MTable.SQLAlias...) + join = append(join, " ON ("...) + for i, col := range j.Relation.M2MBaseFields { + if i > 0 { + join = append(join, ", "...) + } + join = append(join, j.Relation.M2MTable.SQLAlias...) + join = append(join, '.') + join = append(join, col.SQLName...) + } + join = append(join, ") IN ("...) + join = appendChildValues(fmter, join, j.BaseModel.Root(), index, baseTable.PKs) + join = append(join, ")"...) + q = q.Join(internal.String(join)) + + joinTable := j.JoinModel.Table() + for i, m2mJoinField := range j.Relation.M2MJoinFields { + joinField := j.Relation.JoinFields[i] + q = q.Where("?.? = ?.?", + joinTable.SQLAlias, joinField.SQLName, + j.Relation.M2MTable.SQLAlias, m2mJoinField.SQLName) + } + + j.applyQuery(q) + q = q.Apply(j.hasManyColumns) + + return q +} + +func (j *join) hasParent() bool { + if j.Parent != nil { + switch j.Parent.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + return true + } + } + return false +} + +func (j *join) appendAlias(fmter schema.Formatter, b []byte) []byte { + quote := fmter.IdentQuote() + + b = append(b, quote) + b = appendAlias(b, j) + b = append(b, quote) + return b +} + +func (j *join) appendAliasColumn(fmter schema.Formatter, b []byte, column string) []byte { + quote := fmter.IdentQuote() + + b = append(b, quote) + b = appendAlias(b, j) + b = append(b, "__"...) + b = append(b, column...) + b = append(b, quote) + return b +} + +func (j *join) appendBaseAlias(fmter schema.Formatter, b []byte) []byte { + quote := fmter.IdentQuote() + + if j.hasParent() { + b = append(b, quote) + b = appendAlias(b, j.Parent) + b = append(b, quote) + return b + } + return append(b, j.BaseModel.Table().SQLAlias...) +} + +func (j *join) appendSoftDelete(b []byte, flags internal.Flag) []byte { + b = append(b, '.') + b = append(b, j.JoinModel.Table().SoftDeleteField.SQLName...) + if flags.Has(deletedFlag) { + b = append(b, " IS NOT NULL"...) + } else { + b = append(b, " IS NULL"...) + } + return b +} + +func appendAlias(b []byte, j *join) []byte { + if j.hasParent() { + b = appendAlias(b, j.Parent) + b = append(b, "__"...) + } + b = append(b, j.Relation.Field.Name...) + return b +} + +func (j *join) appendHasOneJoin( + fmter schema.Formatter, b []byte, q *SelectQuery, +) (_ []byte, err error) { + isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.flags.Has(allWithDeletedFlag) + + b = append(b, "LEFT JOIN "...) + b = fmter.AppendQuery(b, string(j.JoinModel.Table().SQLNameForSelects)) + b = append(b, " AS "...) + b = j.appendAlias(fmter, b) + + b = append(b, " ON "...) + + b = append(b, '(') + for i, baseField := range j.Relation.BaseFields { + if i > 0 { + b = append(b, " AND "...) + } + b = j.appendAlias(fmter, b) + b = append(b, '.') + b = append(b, j.Relation.JoinFields[i].SQLName...) + b = append(b, " = "...) + b = j.appendBaseAlias(fmter, b) + b = append(b, '.') + b = append(b, baseField.SQLName...) + } + b = append(b, ')') + + if isSoftDelete { + b = append(b, " AND "...) + b = j.appendAlias(fmter, b) + b = j.appendSoftDelete(b, q.flags) + } + + return b, nil +} + +func appendChildValues( + fmter schema.Formatter, b []byte, v reflect.Value, index []int, fields []*schema.Field, +) []byte { + seen := make(map[string]struct{}) + walk(v, index, func(v reflect.Value) { + start := len(b) + + if len(fields) > 1 { + b = append(b, '(') + } + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + b = f.AppendValue(fmter, b, v) + } + if len(fields) > 1 { + b = append(b, ')') + } + b = append(b, ", "...) + + if _, ok := seen[string(b[start:])]; ok { + b = b[:start] + } else { + seen[string(b[start:])] = struct{}{} + } + }) + if len(seen) > 0 { + b = b[:len(b)-2] // trim ", " + } + return b +} diff --git a/vendor/github.com/uptrace/bun/model.go b/vendor/github.com/uptrace/bun/model.go new file mode 100644 index 000000000..c9f0f3583 --- /dev/null +++ b/vendor/github.com/uptrace/bun/model.go @@ -0,0 +1,207 @@ +package bun + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "time" + + "github.com/uptrace/bun/schema" +) + +var errNilModel = errors.New("bun: Model(nil)") + +var timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + +type Model interface { + ScanRows(ctx context.Context, rows *sql.Rows) (int, error) + Value() interface{} +} + +type rowScanner interface { + ScanRow(ctx context.Context, rows *sql.Rows) error +} + +type model interface { + Model +} + +type tableModel interface { + model + + schema.BeforeScanHook + schema.AfterScanHook + ScanColumn(column string, src interface{}) error + + Table() *schema.Table + Relation() *schema.Relation + + Join(string, func(*SelectQuery) *SelectQuery) *join + GetJoin(string) *join + GetJoins() []join + AddJoin(join) *join + + Root() reflect.Value + ParentIndex() []int + Mount(reflect.Value) + + updateSoftDeleteField() error +} + +func newModel(db *DB, dest []interface{}) (model, error) { + if len(dest) == 1 { + return _newModel(db, dest[0], true) + } + + values := make([]reflect.Value, len(dest)) + + for i, el := range dest { + v := reflect.ValueOf(el) + if v.Kind() != reflect.Ptr { + return nil, fmt.Errorf("bun: Scan(non-pointer %T)", dest) + } + + v = v.Elem() + if v.Kind() != reflect.Slice { + return newScanModel(db, dest), nil + } + + values[i] = v + } + + return newSliceModel(db, dest, values), nil +} + +func newSingleModel(db *DB, dest interface{}) (model, error) { + return _newModel(db, dest, false) +} + +func _newModel(db *DB, dest interface{}, scan bool) (model, error) { + switch dest := dest.(type) { + case nil: + return nil, errNilModel + case Model: + return dest, nil + case sql.Scanner: + if !scan { + return nil, fmt.Errorf("bun: Model(unsupported %T)", dest) + } + return newScanModel(db, []interface{}{dest}), nil + } + + v := reflect.ValueOf(dest) + if !v.IsValid() { + return nil, errNilModel + } + if v.Kind() != reflect.Ptr { + return nil, fmt.Errorf("bun: Model(non-pointer %T)", dest) + } + + if v.IsNil() { + typ := v.Type().Elem() + if typ.Kind() == reflect.Struct { + return newStructTableModel(db, dest, db.Table(typ)), nil + } + return nil, fmt.Errorf("bun: Model(nil %T)", dest) + } + + v = v.Elem() + + switch v.Kind() { + case reflect.Map: + typ := v.Type() + if err := validMap(typ); err != nil { + return nil, err + } + mapPtr := v.Addr().Interface().(*map[string]interface{}) + return newMapModel(db, mapPtr), nil + case reflect.Struct: + if v.Type() != timeType { + return newStructTableModelValue(db, dest, v), nil + } + case reflect.Slice: + switch elemType := sliceElemType(v); elemType.Kind() { + case reflect.Struct: + if elemType != timeType { + return newSliceTableModel(db, dest, v, elemType), nil + } + case reflect.Map: + if err := validMap(elemType); err != nil { + return nil, err + } + slicePtr := v.Addr().Interface().(*[]map[string]interface{}) + return newMapSliceModel(db, slicePtr), nil + } + return newSliceModel(db, []interface{}{dest}, []reflect.Value{v}), nil + } + + if scan { + return newScanModel(db, []interface{}{dest}), nil + } + + return nil, fmt.Errorf("bun: Model(unsupported %T)", dest) +} + +func newTableModelIndex( + db *DB, + table *schema.Table, + root reflect.Value, + index []int, + rel *schema.Relation, +) (tableModel, error) { + typ := typeByIndex(table.Type, index) + + if typ.Kind() == reflect.Struct { + return &structTableModel{ + db: db, + table: table.Dialect().Tables().Get(typ), + rel: rel, + + root: root, + index: index, + }, nil + } + + if typ.Kind() == reflect.Slice { + structType := indirectType(typ.Elem()) + if structType.Kind() == reflect.Struct { + m := sliceTableModel{ + structTableModel: structTableModel{ + db: db, + table: table.Dialect().Tables().Get(structType), + rel: rel, + + root: root, + index: index, + }, + } + m.init(typ) + return &m, nil + } + } + + return nil, fmt.Errorf("bun: NewModel(%s)", typ) +} + +func validMap(typ reflect.Type) error { + if typ.Key().Kind() != reflect.String || typ.Elem().Kind() != reflect.Interface { + return fmt.Errorf("bun: Model(unsupported %s) (expected *map[string]interface{})", + typ) + } + return nil +} + +//------------------------------------------------------------------------------ + +func isSingleRowModel(m model) bool { + switch m.(type) { + case *mapModel, + *structTableModel, + *scanModel: + return true + default: + return false + } +} diff --git a/vendor/github.com/uptrace/bun/model_map.go b/vendor/github.com/uptrace/bun/model_map.go new file mode 100644 index 000000000..81c1a4a3b --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_map.go @@ -0,0 +1,183 @@ +package bun + +import ( + "context" + "database/sql" + "reflect" + "sort" + + "github.com/uptrace/bun/schema" +) + +type mapModel struct { + db *DB + + dest *map[string]interface{} + m map[string]interface{} + + rows *sql.Rows + columns []string + _columnTypes []*sql.ColumnType + scanIndex int +} + +var _ model = (*mapModel)(nil) + +func newMapModel(db *DB, dest *map[string]interface{}) *mapModel { + m := &mapModel{ + db: db, + dest: dest, + } + if dest != nil { + m.m = *dest + } + return m +} + +func (m *mapModel) Value() interface{} { + return m.dest +} + +func (m *mapModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + if !rows.Next() { + return 0, rows.Err() + } + + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.rows = rows + m.columns = columns + dest := makeDest(m, len(columns)) + + if m.m == nil { + m.m = make(map[string]interface{}, len(m.columns)) + } + + m.scanIndex = 0 + if err := rows.Scan(dest...); err != nil { + return 0, err + } + + *m.dest = m.m + + return 1, nil +} + +func (m *mapModel) Scan(src interface{}) error { + if _, ok := src.([]byte); !ok { + return m.scanRaw(src) + } + + columnTypes, err := m.columnTypes() + if err != nil { + return err + } + + scanType := columnTypes[m.scanIndex].ScanType() + switch scanType.Kind() { + case reflect.Interface: + return m.scanRaw(src) + case reflect.Slice: + if scanType.Elem().Kind() == reflect.Uint8 { + return m.scanRaw(src) + } + } + + dest := reflect.New(scanType).Elem() + if err := schema.Scanner(scanType)(dest, src); err != nil { + return err + } + + return m.scanRaw(dest.Interface()) +} + +func (m *mapModel) columnTypes() ([]*sql.ColumnType, error) { + if m._columnTypes == nil { + columnTypes, err := m.rows.ColumnTypes() + if err != nil { + return nil, err + } + m._columnTypes = columnTypes + } + return m._columnTypes, nil +} + +func (m *mapModel) scanRaw(src interface{}) error { + columnName := m.columns[m.scanIndex] + m.scanIndex++ + m.m[columnName] = src + return nil +} + +func (m *mapModel) appendColumnsValues(fmter schema.Formatter, b []byte) []byte { + keys := make([]string, 0, len(m.m)) + + for k := range m.m { + keys = append(keys, k) + } + sort.Strings(keys) + + b = append(b, " ("...) + + for i, k := range keys { + if i > 0 { + b = append(b, ", "...) + } + b = fmter.AppendIdent(b, k) + } + + b = append(b, ") VALUES ("...) + + isTemplate := fmter.IsNop() + for i, k := range keys { + if i > 0 { + b = append(b, ", "...) + } + if isTemplate { + b = append(b, '?') + } else { + b = fmter.Dialect().Append(fmter, b, m.m[k]) + } + } + + b = append(b, ")"...) + + return b +} + +func (m *mapModel) appendSet(fmter schema.Formatter, b []byte) []byte { + keys := make([]string, 0, len(m.m)) + + for k := range m.m { + keys = append(keys, k) + } + sort.Strings(keys) + + isTemplate := fmter.IsNop() + for i, k := range keys { + if i > 0 { + b = append(b, ", "...) + } + + b = fmter.AppendIdent(b, k) + b = append(b, " = "...) + if isTemplate { + b = append(b, '?') + } else { + b = fmter.Dialect().Append(fmter, b, m.m[k]) + } + } + + return b +} + +func makeDest(v interface{}, n int) []interface{} { + dest := make([]interface{}, n) + for i := range dest { + dest[i] = v + } + return dest +} diff --git a/vendor/github.com/uptrace/bun/model_map_slice.go b/vendor/github.com/uptrace/bun/model_map_slice.go new file mode 100644 index 000000000..5c6f48e44 --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_map_slice.go @@ -0,0 +1,162 @@ +package bun + +import ( + "context" + "database/sql" + "errors" + "sort" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/schema" +) + +type mapSliceModel struct { + mapModel + dest *[]map[string]interface{} + + keys []string +} + +var _ model = (*mapSliceModel)(nil) + +func newMapSliceModel(db *DB, dest *[]map[string]interface{}) *mapSliceModel { + return &mapSliceModel{ + mapModel: mapModel{ + db: db, + }, + dest: dest, + } +} + +func (m *mapSliceModel) Value() interface{} { + return m.dest +} + +func (m *mapSliceModel) SetCap(cap int) { + if cap > 100 { + cap = 100 + } + if slice := *m.dest; len(slice) < cap { + *m.dest = make([]map[string]interface{}, 0, cap) + } +} + +func (m *mapSliceModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.rows = rows + m.columns = columns + dest := makeDest(m, len(columns)) + + slice := *m.dest + if len(slice) > 0 { + slice = slice[:0] + } + + var n int + + for rows.Next() { + m.m = make(map[string]interface{}, len(m.columns)) + + m.scanIndex = 0 + if err := rows.Scan(dest...); err != nil { + return 0, err + } + + slice = append(slice, m.m) + n++ + } + if err := rows.Err(); err != nil { + return 0, err + } + + *m.dest = slice + return n, nil +} + +func (m *mapSliceModel) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if err := m.initKeys(); err != nil { + return nil, err + } + + for i, k := range m.keys { + if i > 0 { + b = append(b, ", "...) + } + b = fmter.AppendIdent(b, k) + } + + return b, nil +} + +func (m *mapSliceModel) appendValues(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if err := m.initKeys(); err != nil { + return nil, err + } + slice := *m.dest + + b = append(b, "VALUES "...) + if m.db.features.Has(feature.ValuesRow) { + b = append(b, "ROW("...) + } else { + b = append(b, '(') + } + + if fmter.IsNop() { + for i := range m.keys { + if i > 0 { + b = append(b, ", "...) + } + b = append(b, '?') + } + return b, nil + } + + for i, el := range slice { + if i > 0 { + b = append(b, "), "...) + if m.db.features.Has(feature.ValuesRow) { + b = append(b, "ROW("...) + } else { + b = append(b, '(') + } + } + + for j, key := range m.keys { + if j > 0 { + b = append(b, ", "...) + } + b = fmter.Dialect().Append(fmter, b, el[key]) + } + } + + b = append(b, ')') + + return b, nil +} + +func (m *mapSliceModel) initKeys() error { + if m.keys != nil { + return nil + } + + slice := *m.dest + if len(slice) == 0 { + return errors.New("bun: map slice is empty") + } + + first := slice[0] + keys := make([]string, 0, len(first)) + + for k := range first { + keys = append(keys, k) + } + + sort.Strings(keys) + m.keys = keys + + return nil +} diff --git a/vendor/github.com/uptrace/bun/model_scan.go b/vendor/github.com/uptrace/bun/model_scan.go new file mode 100644 index 000000000..6dd061fb2 --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_scan.go @@ -0,0 +1,54 @@ +package bun + +import ( + "context" + "database/sql" + "reflect" +) + +type scanModel struct { + db *DB + + dest []interface{} + scanIndex int +} + +var _ model = (*scanModel)(nil) + +func newScanModel(db *DB, dest []interface{}) *scanModel { + return &scanModel{ + db: db, + dest: dest, + } +} + +func (m *scanModel) Value() interface{} { + return m.dest +} + +func (m *scanModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + if !rows.Next() { + return 0, rows.Err() + } + + dest := makeDest(m, len(m.dest)) + + m.scanIndex = 0 + if err := rows.Scan(dest...); err != nil { + return 0, err + } + + return 1, nil +} + +func (m *scanModel) ScanRow(ctx context.Context, rows *sql.Rows) error { + return rows.Scan(m.dest...) +} + +func (m *scanModel) Scan(src interface{}) error { + dest := reflect.ValueOf(m.dest[m.scanIndex]) + m.scanIndex++ + + scanner := m.db.dialect.Scanner(dest.Type()) + return scanner(dest, src) +} diff --git a/vendor/github.com/uptrace/bun/model_slice.go b/vendor/github.com/uptrace/bun/model_slice.go new file mode 100644 index 000000000..afe804382 --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_slice.go @@ -0,0 +1,82 @@ +package bun + +import ( + "context" + "database/sql" + "reflect" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type sliceInfo struct { + nextElem func() reflect.Value + scan schema.ScannerFunc +} + +type sliceModel struct { + dest []interface{} + values []reflect.Value + scanIndex int + info []sliceInfo +} + +var _ model = (*sliceModel)(nil) + +func newSliceModel(db *DB, dest []interface{}, values []reflect.Value) *sliceModel { + return &sliceModel{ + dest: dest, + values: values, + } +} + +func (m *sliceModel) Value() interface{} { + return m.dest +} + +func (m *sliceModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.info = make([]sliceInfo, len(m.values)) + for i, v := range m.values { + if v.IsValid() && v.Len() > 0 { + v.Set(v.Slice(0, 0)) + } + + m.info[i] = sliceInfo{ + nextElem: internal.MakeSliceNextElemFunc(v), + scan: schema.Scanner(v.Type().Elem()), + } + } + + if len(columns) == 0 { + return 0, nil + } + dest := makeDest(m, len(columns)) + + var n int + + for rows.Next() { + m.scanIndex = 0 + if err := rows.Scan(dest...); err != nil { + return 0, err + } + n++ + } + if err := rows.Err(); err != nil { + return 0, err + } + + return n, nil +} + +func (m *sliceModel) Scan(src interface{}) error { + info := m.info[m.scanIndex] + m.scanIndex++ + + dest := info.nextElem() + return info.scan(dest, src) +} diff --git a/vendor/github.com/uptrace/bun/model_table_has_many.go b/vendor/github.com/uptrace/bun/model_table_has_many.go new file mode 100644 index 000000000..e64b7088d --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_table_has_many.go @@ -0,0 +1,149 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + "reflect" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type hasManyModel struct { + *sliceTableModel + baseTable *schema.Table + rel *schema.Relation + + baseValues map[internal.MapKey][]reflect.Value + structKey []interface{} +} + +var _ tableModel = (*hasManyModel)(nil) + +func newHasManyModel(j *join) *hasManyModel { + baseTable := j.BaseModel.Table() + joinModel := j.JoinModel.(*sliceTableModel) + baseValues := baseValues(joinModel, j.Relation.BaseFields) + if len(baseValues) == 0 { + return nil + } + m := hasManyModel{ + sliceTableModel: joinModel, + baseTable: baseTable, + rel: j.Relation, + + baseValues: baseValues, + } + if !m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } + return &m +} + +func (m *hasManyModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.columns = columns + dest := makeDest(m, len(columns)) + + var n int + + for rows.Next() { + if m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } else { + m.strct.Set(m.table.ZeroValue) + } + m.structInited = false + + m.scanIndex = 0 + m.structKey = m.structKey[:0] + if err := rows.Scan(dest...); err != nil { + return 0, err + } + + if err := m.parkStruct(); err != nil { + return 0, err + } + + n++ + } + if err := rows.Err(); err != nil { + return 0, err + } + + return n, nil +} + +func (m *hasManyModel) Scan(src interface{}) error { + column := m.columns[m.scanIndex] + m.scanIndex++ + + field, err := m.table.Field(column) + if err != nil { + return err + } + + if err := field.ScanValue(m.strct, src); err != nil { + return err + } + + for _, f := range m.rel.JoinFields { + if f.Name == field.Name { + m.structKey = append(m.structKey, field.Value(m.strct).Interface()) + break + } + } + + return nil +} + +func (m *hasManyModel) parkStruct() error { + baseValues, ok := m.baseValues[internal.NewMapKey(m.structKey)] + if !ok { + return fmt.Errorf( + "bun: has-many relation=%s does not have base %s with id=%q (check join conditions)", + m.rel.Field.GoName, m.baseTable, m.structKey) + } + + for i, v := range baseValues { + if !m.sliceOfPtr { + v.Set(reflect.Append(v, m.strct)) + continue + } + + if i == 0 { + v.Set(reflect.Append(v, m.strct.Addr())) + continue + } + + clone := reflect.New(m.strct.Type()).Elem() + clone.Set(m.strct) + v.Set(reflect.Append(v, clone.Addr())) + } + + return nil +} + +func baseValues(model tableModel, fields []*schema.Field) map[internal.MapKey][]reflect.Value { + fieldIndex := model.Relation().Field.Index + m := make(map[internal.MapKey][]reflect.Value) + key := make([]interface{}, 0, len(fields)) + walk(model.Root(), model.ParentIndex(), func(v reflect.Value) { + key = modelKey(key[:0], v, fields) + mapKey := internal.NewMapKey(key) + m[mapKey] = append(m[mapKey], v.FieldByIndex(fieldIndex)) + }) + return m +} + +func modelKey(key []interface{}, strct reflect.Value, fields []*schema.Field) []interface{} { + for _, f := range fields { + key = append(key, f.Value(strct).Interface()) + } + return key +} diff --git a/vendor/github.com/uptrace/bun/model_table_m2m.go b/vendor/github.com/uptrace/bun/model_table_m2m.go new file mode 100644 index 000000000..4357e3a8e --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_table_m2m.go @@ -0,0 +1,138 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + "reflect" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type m2mModel struct { + *sliceTableModel + baseTable *schema.Table + rel *schema.Relation + + baseValues map[internal.MapKey][]reflect.Value + structKey []interface{} +} + +var _ tableModel = (*m2mModel)(nil) + +func newM2MModel(j *join) *m2mModel { + baseTable := j.BaseModel.Table() + joinModel := j.JoinModel.(*sliceTableModel) + baseValues := baseValues(joinModel, baseTable.PKs) + if len(baseValues) == 0 { + return nil + } + m := &m2mModel{ + sliceTableModel: joinModel, + baseTable: baseTable, + rel: j.Relation, + + baseValues: baseValues, + } + if !m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } + return m +} + +func (m *m2mModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.columns = columns + dest := makeDest(m, len(columns)) + + var n int + + for rows.Next() { + if m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } else { + m.strct.Set(m.table.ZeroValue) + } + m.structInited = false + + m.scanIndex = 0 + m.structKey = m.structKey[:0] + if err := rows.Scan(dest...); err != nil { + return 0, err + } + + if err := m.parkStruct(); err != nil { + return 0, err + } + + n++ + } + if err := rows.Err(); err != nil { + return 0, err + } + + return n, nil +} + +func (m *m2mModel) Scan(src interface{}) error { + column := m.columns[m.scanIndex] + m.scanIndex++ + + field, ok := m.table.FieldMap[column] + if !ok { + return m.scanM2MColumn(column, src) + } + + if err := field.ScanValue(m.strct, src); err != nil { + return err + } + + for _, fk := range m.rel.M2MBaseFields { + if fk.Name == field.Name { + m.structKey = append(m.structKey, field.Value(m.strct).Interface()) + break + } + } + + return nil +} + +func (m *m2mModel) scanM2MColumn(column string, src interface{}) error { + for _, field := range m.rel.M2MBaseFields { + if field.Name == column { + dest := reflect.New(field.IndirectType).Elem() + if err := field.Scan(dest, src); err != nil { + return err + } + m.structKey = append(m.structKey, dest.Interface()) + break + } + } + + _, err := m.scanColumn(column, src) + return err +} + +func (m *m2mModel) parkStruct() error { + baseValues, ok := m.baseValues[internal.NewMapKey(m.structKey)] + if !ok { + return fmt.Errorf( + "bun: m2m relation=%s does not have base %s with key=%q (check join conditions)", + m.rel.Field.GoName, m.baseTable, m.structKey) + } + + for _, v := range baseValues { + if m.sliceOfPtr { + v.Set(reflect.Append(v, m.strct.Addr())) + } else { + v.Set(reflect.Append(v, m.strct)) + } + } + + return nil +} diff --git a/vendor/github.com/uptrace/bun/model_table_slice.go b/vendor/github.com/uptrace/bun/model_table_slice.go new file mode 100644 index 000000000..67e7c71e7 --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_table_slice.go @@ -0,0 +1,113 @@ +package bun + +import ( + "context" + "database/sql" + "reflect" + + "github.com/uptrace/bun/schema" +) + +type sliceTableModel struct { + structTableModel + + slice reflect.Value + sliceLen int + sliceOfPtr bool + nextElem func() reflect.Value +} + +var _ tableModel = (*sliceTableModel)(nil) + +func newSliceTableModel( + db *DB, dest interface{}, slice reflect.Value, elemType reflect.Type, +) *sliceTableModel { + m := &sliceTableModel{ + structTableModel: structTableModel{ + db: db, + table: db.Table(elemType), + dest: dest, + root: slice, + }, + + slice: slice, + sliceLen: slice.Len(), + nextElem: makeSliceNextElemFunc(slice), + } + m.init(slice.Type()) + return m +} + +func (m *sliceTableModel) init(sliceType reflect.Type) { + switch sliceType.Elem().Kind() { + case reflect.Ptr, reflect.Interface: + m.sliceOfPtr = true + } +} + +func (m *sliceTableModel) Join(name string, apply func(*SelectQuery) *SelectQuery) *join { + return m.join(m.slice, name, apply) +} + +func (m *sliceTableModel) Bind(bind reflect.Value) { + m.slice = bind.Field(m.index[len(m.index)-1]) +} + +func (m *sliceTableModel) SetCap(cap int) { + if cap > 100 { + cap = 100 + } + if m.slice.Cap() < cap { + m.slice.Set(reflect.MakeSlice(m.slice.Type(), 0, cap)) + } +} + +func (m *sliceTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.columns = columns + dest := makeDest(m, len(columns)) + + if m.slice.IsValid() && m.slice.Len() > 0 { + m.slice.Set(m.slice.Slice(0, 0)) + } + + var n int + + for rows.Next() { + m.strct = m.nextElem() + m.structInited = false + + if err := m.scanRow(ctx, rows, dest); err != nil { + return 0, err + } + + n++ + } + if err := rows.Err(); err != nil { + return 0, err + } + + return n, nil +} + +// Inherit these hooks from structTableModel. +var ( + _ schema.BeforeScanHook = (*sliceTableModel)(nil) + _ schema.AfterScanHook = (*sliceTableModel)(nil) +) + +func (m *sliceTableModel) updateSoftDeleteField() error { + sliceLen := m.slice.Len() + for i := 0; i < sliceLen; i++ { + strct := indirect(m.slice.Index(i)) + fv := m.table.SoftDeleteField.Value(strct) + if err := m.table.UpdateSoftDeleteField(fv); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/model_table_struct.go b/vendor/github.com/uptrace/bun/model_table_struct.go new file mode 100644 index 000000000..3bb0c14dd --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_table_struct.go @@ -0,0 +1,345 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "strings" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/schema" +) + +type structTableModel struct { + db *DB + table *schema.Table + + rel *schema.Relation + joins []join + + dest interface{} + root reflect.Value + index []int + + strct reflect.Value + structInited bool + structInitErr error + + columns []string + scanIndex int +} + +var _ tableModel = (*structTableModel)(nil) + +func newStructTableModel(db *DB, dest interface{}, table *schema.Table) *structTableModel { + return &structTableModel{ + db: db, + table: table, + dest: dest, + } +} + +func newStructTableModelValue(db *DB, dest interface{}, v reflect.Value) *structTableModel { + return &structTableModel{ + db: db, + table: db.Table(v.Type()), + dest: dest, + root: v, + strct: v, + } +} + +func (m *structTableModel) Value() interface{} { + return m.dest +} + +func (m *structTableModel) Table() *schema.Table { + return m.table +} + +func (m *structTableModel) Relation() *schema.Relation { + return m.rel +} + +func (m *structTableModel) Root() reflect.Value { + return m.root +} + +func (m *structTableModel) Index() []int { + return m.index +} + +func (m *structTableModel) ParentIndex() []int { + return m.index[:len(m.index)-len(m.rel.Field.Index)] +} + +func (m *structTableModel) Mount(host reflect.Value) { + m.strct = host.FieldByIndex(m.rel.Field.Index) + m.structInited = false +} + +func (m *structTableModel) initStruct() error { + if m.structInited { + return m.structInitErr + } + m.structInited = true + + switch m.strct.Kind() { + case reflect.Invalid: + m.structInitErr = errNilModel + return m.structInitErr + case reflect.Interface: + m.strct = m.strct.Elem() + } + + if m.strct.Kind() == reflect.Ptr { + if m.strct.IsNil() { + m.strct.Set(reflect.New(m.strct.Type().Elem())) + m.strct = m.strct.Elem() + } else { + m.strct = m.strct.Elem() + } + } + + m.mountJoins() + + return nil +} + +func (m *structTableModel) mountJoins() { + for i := range m.joins { + j := &m.joins[i] + switch j.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + j.JoinModel.Mount(m.strct) + } + } +} + +var _ schema.BeforeScanHook = (*structTableModel)(nil) + +func (m *structTableModel) BeforeScan(ctx context.Context) error { + if !m.table.HasBeforeScanHook() { + return nil + } + return callBeforeScanHook(ctx, m.strct.Addr()) +} + +var _ schema.AfterScanHook = (*structTableModel)(nil) + +func (m *structTableModel) AfterScan(ctx context.Context) error { + if !m.table.HasAfterScanHook() || !m.structInited { + return nil + } + + var firstErr error + + if err := callAfterScanHook(ctx, m.strct.Addr()); err != nil && firstErr == nil { + firstErr = err + } + + for _, j := range m.joins { + switch j.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + if err := j.JoinModel.AfterScan(ctx); err != nil && firstErr == nil { + firstErr = err + } + } + } + + return firstErr +} + +func (m *structTableModel) GetJoin(name string) *join { + for i := range m.joins { + j := &m.joins[i] + if j.Relation.Field.Name == name || j.Relation.Field.GoName == name { + return j + } + } + return nil +} + +func (m *structTableModel) GetJoins() []join { + return m.joins +} + +func (m *structTableModel) AddJoin(j join) *join { + m.joins = append(m.joins, j) + return &m.joins[len(m.joins)-1] +} + +func (m *structTableModel) Join(name string, apply func(*SelectQuery) *SelectQuery) *join { + return m.join(m.strct, name, apply) +} + +func (m *structTableModel) join( + bind reflect.Value, name string, apply func(*SelectQuery) *SelectQuery, +) *join { + path := strings.Split(name, ".") + index := make([]int, 0, len(path)) + + currJoin := join{ + BaseModel: m, + JoinModel: m, + } + var lastJoin *join + + for _, name := range path { + relation, ok := currJoin.JoinModel.Table().Relations[name] + if !ok { + return nil + } + + currJoin.Relation = relation + index = append(index, relation.Field.Index...) + + if j := currJoin.JoinModel.GetJoin(name); j != nil { + currJoin.BaseModel = j.BaseModel + currJoin.JoinModel = j.JoinModel + + lastJoin = j + } else { + model, err := newTableModelIndex(m.db, m.table, bind, index, relation) + if err != nil { + return nil + } + + currJoin.Parent = lastJoin + currJoin.BaseModel = currJoin.JoinModel + currJoin.JoinModel = model + + lastJoin = currJoin.BaseModel.AddJoin(currJoin) + } + } + + // No joins with such name. + if lastJoin == nil { + return nil + } + if apply != nil { + lastJoin.ApplyQueryFunc = apply + } + + return lastJoin +} + +func (m *structTableModel) updateSoftDeleteField() error { + fv := m.table.SoftDeleteField.Value(m.strct) + return m.table.UpdateSoftDeleteField(fv) +} + +func (m *structTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + if !rows.Next() { + return 0, rows.Err() + } + + if err := m.ScanRow(ctx, rows); err != nil { + return 0, err + } + + // For inserts, SQLite3 can return a row like it was inserted sucessfully and then + // an actual error for the next row. See issues/100. + if m.db.dialect.Name() == dialect.SQLite { + _ = rows.Next() + if err := rows.Err(); err != nil { + return 0, err + } + } + + return 1, nil +} + +func (m *structTableModel) ScanRow(ctx context.Context, rows *sql.Rows) error { + columns, err := rows.Columns() + if err != nil { + return err + } + + m.columns = columns + dest := makeDest(m, len(columns)) + + return m.scanRow(ctx, rows, dest) +} + +func (m *structTableModel) scanRow(ctx context.Context, rows *sql.Rows, dest []interface{}) error { + if err := m.BeforeScan(ctx); err != nil { + return err + } + + m.scanIndex = 0 + if err := rows.Scan(dest...); err != nil { + return err + } + + if err := m.AfterScan(ctx); err != nil { + return err + } + + return nil +} + +func (m *structTableModel) Scan(src interface{}) error { + column := m.columns[m.scanIndex] + m.scanIndex++ + + return m.ScanColumn(unquote(column), src) +} + +func (m *structTableModel) ScanColumn(column string, src interface{}) error { + if ok, err := m.scanColumn(column, src); ok { + return err + } + if column == "" || column[0] == '_' || m.db.flags.Has(discardUnknownColumns) { + return nil + } + return fmt.Errorf("bun: %s does not have column %q", m.table.TypeName, column) +} + +func (m *structTableModel) scanColumn(column string, src interface{}) (bool, error) { + if src != nil { + if err := m.initStruct(); err != nil { + return true, err + } + } + + if field, ok := m.table.FieldMap[column]; ok { + return true, field.ScanValue(m.strct, src) + } + + if joinName, column := splitColumn(column); joinName != "" { + if join := m.GetJoin(joinName); join != nil { + return true, join.JoinModel.ScanColumn(column, src) + } + if m.table.ModelName == joinName { + return true, m.ScanColumn(column, src) + } + } + + return false, nil +} + +func (m *structTableModel) AppendNamedArg( + fmter schema.Formatter, b []byte, name string, +) ([]byte, bool) { + return m.table.AppendNamedArg(fmter, b, name, m.strct) +} + +// sqlite3 sometimes does not unquote columns. +func unquote(s string) string { + if s == "" { + return s + } + if s[0] == '"' && s[len(s)-1] == '"' { + return s[1 : len(s)-1] + } + return s +} + +func splitColumn(s string) (string, string) { + if i := strings.Index(s, "__"); i >= 0 { + return s[:i], s[i+2:] + } + return "", s +} diff --git a/vendor/github.com/uptrace/bun/query_base.go b/vendor/github.com/uptrace/bun/query_base.go new file mode 100644 index 000000000..1a7c32720 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_base.go @@ -0,0 +1,874 @@ +package bun + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +const ( + wherePKFlag internal.Flag = 1 << iota + forceDeleteFlag + deletedFlag + allWithDeletedFlag +) + +type withQuery struct { + name string + query schema.QueryAppender +} + +// IConn is a common interface for *sql.DB, *sql.Conn, and *sql.Tx. +type IConn interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} + +var ( + _ IConn = (*sql.DB)(nil) + _ IConn = (*sql.Conn)(nil) + _ IConn = (*sql.Tx)(nil) + _ IConn = (*DB)(nil) + _ IConn = (*Conn)(nil) + _ IConn = (*Tx)(nil) +) + +// IDB is a common interface for *bun.DB, bun.Conn, and bun.Tx. +type IDB interface { + IConn + + NewValues(model interface{}) *ValuesQuery + NewSelect() *SelectQuery + NewInsert() *InsertQuery + NewUpdate() *UpdateQuery + NewDelete() *DeleteQuery + NewCreateTable() *CreateTableQuery + NewDropTable() *DropTableQuery + NewCreateIndex() *CreateIndexQuery + NewDropIndex() *DropIndexQuery + NewTruncateTable() *TruncateTableQuery + NewAddColumn() *AddColumnQuery + NewDropColumn() *DropColumnQuery +} + +var ( + _ IConn = (*DB)(nil) + _ IConn = (*Conn)(nil) + _ IConn = (*Tx)(nil) +) + +type baseQuery struct { + db *DB + conn IConn + + model model + err error + + tableModel tableModel + table *schema.Table + + with []withQuery + modelTable schema.QueryWithArgs + tables []schema.QueryWithArgs + columns []schema.QueryWithArgs + + flags internal.Flag +} + +func (q *baseQuery) DB() *DB { + return q.db +} + +func (q *baseQuery) GetModel() Model { + return q.model +} + +func (q *baseQuery) setConn(db IConn) { + // Unwrap Bun wrappers to not call query hooks twice. + switch db := db.(type) { + case *DB: + q.conn = db.DB + case Conn: + q.conn = db.Conn + case Tx: + q.conn = db.Tx + default: + q.conn = db + } +} + +// TODO: rename to setModel +func (q *baseQuery) setTableModel(modeli interface{}) { + model, err := newSingleModel(q.db, modeli) + if err != nil { + q.setErr(err) + return + } + + q.model = model + if tm, ok := model.(tableModel); ok { + q.tableModel = tm + q.table = tm.Table() + } +} + +func (q *baseQuery) setErr(err error) { + if q.err == nil { + q.err = err + } +} + +func (q *baseQuery) getModel(dest []interface{}) (model, error) { + if len(dest) == 0 { + if q.model != nil { + return q.model, nil + } + return nil, errNilModel + } + return newModel(q.db, dest) +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) checkSoftDelete() error { + if q.table == nil { + return errors.New("bun: can't use soft deletes without a table") + } + if q.table.SoftDeleteField == nil { + return fmt.Errorf("%s does not have a soft delete field", q.table) + } + if q.tableModel == nil { + return errors.New("bun: can't use soft deletes without a table model") + } + return nil +} + +// Deleted adds `WHERE deleted_at IS NOT NULL` clause for soft deleted models. +func (q *baseQuery) whereDeleted() { + if err := q.checkSoftDelete(); err != nil { + q.setErr(err) + return + } + q.flags = q.flags.Set(deletedFlag) + q.flags = q.flags.Remove(allWithDeletedFlag) +} + +// AllWithDeleted changes query to return all rows including soft deleted ones. +func (q *baseQuery) whereAllWithDeleted() { + if err := q.checkSoftDelete(); err != nil { + q.setErr(err) + return + } + q.flags = q.flags.Set(allWithDeletedFlag) + q.flags = q.flags.Remove(deletedFlag) +} + +func (q *baseQuery) isSoftDelete() bool { + if q.table != nil { + return q.table.SoftDeleteField != nil && !q.flags.Has(allWithDeletedFlag) + } + return false +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) addWith(name string, query schema.QueryAppender) { + q.with = append(q.with, withQuery{ + name: name, + query: query, + }) +} + +func (q *baseQuery) appendWith(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if len(q.with) == 0 { + return b, nil + } + + b = append(b, "WITH "...) + for i, with := range q.with { + if i > 0 { + b = append(b, ", "...) + } + + b = fmter.AppendIdent(b, with.name) + if q, ok := with.query.(schema.ColumnsAppender); ok { + b = append(b, " ("...) + b, err = q.AppendColumns(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ")"...) + } + + b = append(b, " AS ("...) + + b, err = with.query.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, ')') + } + b = append(b, ' ') + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) addTable(table schema.QueryWithArgs) { + q.tables = append(q.tables, table) +} + +func (q *baseQuery) addColumn(column schema.QueryWithArgs) { + q.columns = append(q.columns, column) +} + +func (q *baseQuery) excludeColumn(columns []string) { + if q.columns == nil { + for _, f := range q.table.Fields { + q.columns = append(q.columns, schema.UnsafeIdent(f.Name)) + } + } + + if len(columns) == 1 && columns[0] == "*" { + q.columns = make([]schema.QueryWithArgs, 0) + return + } + + for _, column := range columns { + if !q._excludeColumn(column) { + q.setErr(fmt.Errorf("bun: can't find column=%q", column)) + return + } + } +} + +func (q *baseQuery) _excludeColumn(column string) bool { + for i, col := range q.columns { + if col.Args == nil && col.Query == column { + q.columns = append(q.columns[:i], q.columns[i+1:]...) + return true + } + } + return false +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) modelHasTableName() bool { + return !q.modelTable.IsZero() || q.table != nil +} + +func (q *baseQuery) hasTables() bool { + return q.modelHasTableName() || len(q.tables) > 0 +} + +func (q *baseQuery) appendTables( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + return q._appendTables(fmter, b, false) +} + +func (q *baseQuery) appendTablesWithAlias( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + return q._appendTables(fmter, b, true) +} + +func (q *baseQuery) _appendTables( + fmter schema.Formatter, b []byte, withAlias bool, +) (_ []byte, err error) { + startLen := len(b) + + if q.modelHasTableName() { + if !q.modelTable.IsZero() { + b, err = q.modelTable.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } else { + b = fmter.AppendQuery(b, string(q.table.SQLNameForSelects)) + if withAlias && q.table.SQLAlias != q.table.SQLNameForSelects { + b = append(b, " AS "...) + b = append(b, q.table.SQLAlias...) + } + } + } + + for _, table := range q.tables { + if len(b) > startLen { + b = append(b, ", "...) + } + b, err = table.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *baseQuery) appendFirstTable(fmter schema.Formatter, b []byte) ([]byte, error) { + return q._appendFirstTable(fmter, b, false) +} + +func (q *baseQuery) appendFirstTableWithAlias( + fmter schema.Formatter, b []byte, +) ([]byte, error) { + return q._appendFirstTable(fmter, b, true) +} + +func (q *baseQuery) _appendFirstTable( + fmter schema.Formatter, b []byte, withAlias bool, +) ([]byte, error) { + if !q.modelTable.IsZero() { + return q.modelTable.AppendQuery(fmter, b) + } + + if q.table != nil { + b = fmter.AppendQuery(b, string(q.table.SQLName)) + if withAlias { + b = append(b, " AS "...) + b = append(b, q.table.SQLAlias...) + } + return b, nil + } + + if len(q.tables) > 0 { + return q.tables[0].AppendQuery(fmter, b) + } + + return nil, errors.New("bun: query does not have a table") +} + +func (q *baseQuery) hasMultiTables() bool { + if q.modelHasTableName() { + return len(q.tables) >= 1 + } + return len(q.tables) >= 2 +} + +func (q *baseQuery) appendOtherTables(fmter schema.Formatter, b []byte) (_ []byte, err error) { + tables := q.tables + if !q.modelHasTableName() { + tables = tables[1:] + } + for i, table := range tables { + if i > 0 { + b = append(b, ", "...) + } + b, err = table.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) { + for i, f := range q.columns { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *baseQuery) getFields() ([]*schema.Field, error) { + table := q.tableModel.Table() + + if len(q.columns) == 0 { + return table.Fields, nil + } + + fields, err := q._getFields(false) + if err != nil { + return nil, err + } + + return fields, nil +} + +func (q *baseQuery) getDataFields() ([]*schema.Field, error) { + if len(q.columns) == 0 { + return q.table.DataFields, nil + } + return q._getFields(true) +} + +func (q *baseQuery) _getFields(omitPK bool) ([]*schema.Field, error) { + fields := make([]*schema.Field, 0, len(q.columns)) + for _, col := range q.columns { + if col.Args != nil { + continue + } + + field, err := q.table.Field(col.Query) + if err != nil { + return nil, err + } + + if omitPK && field.IsPK { + continue + } + + fields = append(fields, field) + } + return fields, nil +} + +func (q *baseQuery) scan( + ctx context.Context, + queryApp schema.QueryAppender, + query string, + model model, + hasDest bool, +) (res result, _ error) { + ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil) + + rows, err := q.conn.QueryContext(ctx, query) + if err != nil { + q.db.afterQuery(ctx, event, nil, err) + return res, err + } + defer rows.Close() + + n, err := model.ScanRows(ctx, rows) + if err != nil { + q.db.afterQuery(ctx, event, nil, err) + return res, err + } + + res.n = n + if n == 0 && hasDest && isSingleRowModel(model) { + err = sql.ErrNoRows + } + + q.db.afterQuery(ctx, event, nil, err) + + return res, err +} + +func (q *baseQuery) exec( + ctx context.Context, + queryApp schema.QueryAppender, + query string, +) (res result, _ error) { + ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil) + + r, err := q.conn.ExecContext(ctx, query) + if err != nil { + q.db.afterQuery(ctx, event, nil, err) + return res, err + } + + res.r = r + + q.db.afterQuery(ctx, event, nil, err) + return res, nil +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) AppendNamedArg(fmter schema.Formatter, b []byte, name string) ([]byte, bool) { + if q.table == nil { + return b, false + } + + if m, ok := q.tableModel.(*structTableModel); ok { + if b, ok := m.AppendNamedArg(fmter, b, name); ok { + return b, ok + } + } + + switch name { + case "TableName": + b = fmter.AppendQuery(b, string(q.table.SQLName)) + return b, true + case "TableAlias": + b = fmter.AppendQuery(b, string(q.table.SQLAlias)) + return b, true + case "PKs": + b = appendColumns(b, "", q.table.PKs) + return b, true + case "TablePKs": + b = appendColumns(b, q.table.SQLAlias, q.table.PKs) + return b, true + case "Columns": + b = appendColumns(b, "", q.table.Fields) + return b, true + case "TableColumns": + b = appendColumns(b, q.table.SQLAlias, q.table.Fields) + return b, true + } + + return b, false +} + +func appendColumns(b []byte, table schema.Safe, fields []*schema.Field) []byte { + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + if len(table) > 0 { + b = append(b, table...) + b = append(b, '.') + } + b = append(b, f.SQLName...) + } + return b +} + +func formatterWithModel(fmter schema.Formatter, model schema.NamedArgAppender) schema.Formatter { + if fmter.IsNop() { + return fmter + } + return fmter.WithArg(model) +} + +//------------------------------------------------------------------------------ + +type whereBaseQuery struct { + baseQuery + + where []schema.QueryWithSep +} + +func (q *whereBaseQuery) addWhere(where schema.QueryWithSep) { + q.where = append(q.where, where) +} + +func (q *whereBaseQuery) addWhereGroup(sep string, where []schema.QueryWithSep) { + if len(where) == 0 { + return + } + + where[0].Sep = "" + + q.addWhere(schema.SafeQueryWithSep("", nil, sep+"(")) + q.where = append(q.where, where...) + q.addWhere(schema.SafeQueryWithSep("", nil, ")")) +} + +func (q *whereBaseQuery) mustAppendWhere( + fmter schema.Formatter, b []byte, withAlias bool, +) ([]byte, error) { + if len(q.where) == 0 && !q.flags.Has(wherePKFlag) { + err := errors.New("bun: Update and Delete queries require at least one Where") + return nil, err + } + return q.appendWhere(fmter, b, withAlias) +} + +func (q *whereBaseQuery) appendWhere( + fmter schema.Formatter, b []byte, withAlias bool, +) (_ []byte, err error) { + if len(q.where) == 0 && !q.isSoftDelete() && !q.flags.Has(wherePKFlag) { + return b, nil + } + + b = append(b, " WHERE "...) + startLen := len(b) + + if len(q.where) > 0 { + b, err = appendWhere(fmter, b, q.where) + if err != nil { + return nil, err + } + } + + if q.isSoftDelete() { + if len(b) > startLen { + b = append(b, " AND "...) + } + if withAlias { + b = append(b, q.tableModel.Table().SQLAlias...) + b = append(b, '.') + } + b = append(b, q.tableModel.Table().SoftDeleteField.SQLName...) + if q.flags.Has(deletedFlag) { + b = append(b, " IS NOT NULL"...) + } else { + b = append(b, " IS NULL"...) + } + } + + if q.flags.Has(wherePKFlag) { + if len(b) > startLen { + b = append(b, " AND "...) + } + b, err = q.appendWherePK(fmter, b, withAlias) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func appendWhere( + fmter schema.Formatter, b []byte, where []schema.QueryWithSep, +) (_ []byte, err error) { + for i, where := range where { + if i > 0 || where.Sep == "(" { + b = append(b, where.Sep...) + } + + if where.Query == "" && where.Args == nil { + continue + } + + b = append(b, '(') + b, err = where.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ')') + } + return b, nil +} + +func (q *whereBaseQuery) appendWherePK( + fmter schema.Formatter, b []byte, withAlias bool, +) (_ []byte, err error) { + if q.table == nil { + err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model) + return nil, err + } + if err := q.table.CheckPKs(); err != nil { + return nil, err + } + + switch model := q.tableModel.(type) { + case *structTableModel: + return q.appendWherePKStruct(fmter, b, model, withAlias) + case *sliceTableModel: + return q.appendWherePKSlice(fmter, b, model, withAlias) + } + + return nil, fmt.Errorf("bun: WherePK does not support %T", q.tableModel) +} + +func (q *whereBaseQuery) appendWherePKStruct( + fmter schema.Formatter, b []byte, model *structTableModel, withAlias bool, +) (_ []byte, err error) { + if !model.strct.IsValid() { + return nil, errNilModel + } + + isTemplate := fmter.IsNop() + b = append(b, '(') + for i, f := range q.table.PKs { + if i > 0 { + b = append(b, " AND "...) + } + if withAlias { + b = append(b, q.table.SQLAlias...) + b = append(b, '.') + } + b = append(b, f.SQLName...) + b = append(b, " = "...) + if isTemplate { + b = append(b, '?') + } else { + b = f.AppendValue(fmter, b, model.strct) + } + } + b = append(b, ')') + return b, nil +} + +func (q *whereBaseQuery) appendWherePKSlice( + fmter schema.Formatter, b []byte, model *sliceTableModel, withAlias bool, +) (_ []byte, err error) { + if len(q.table.PKs) > 1 { + b = append(b, '(') + } + if withAlias { + b = appendColumns(b, q.table.SQLAlias, q.table.PKs) + } else { + b = appendColumns(b, "", q.table.PKs) + } + if len(q.table.PKs) > 1 { + b = append(b, ')') + } + + b = append(b, " IN ("...) + + isTemplate := fmter.IsNop() + slice := model.slice + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + if isTemplate { + break + } + b = append(b, ", "...) + } + + el := indirect(slice.Index(i)) + + if len(q.table.PKs) > 1 { + b = append(b, '(') + } + for i, f := range q.table.PKs { + if i > 0 { + b = append(b, ", "...) + } + if isTemplate { + b = append(b, '?') + } else { + b = f.AppendValue(fmter, b, el) + } + } + if len(q.table.PKs) > 1 { + b = append(b, ')') + } + } + + b = append(b, ')') + + return b, nil +} + +//------------------------------------------------------------------------------ + +type returningQuery struct { + returning []schema.QueryWithArgs + returningFields []*schema.Field +} + +func (q *returningQuery) addReturning(ret schema.QueryWithArgs) { + q.returning = append(q.returning, ret) +} + +func (q *returningQuery) addReturningField(field *schema.Field) { + if len(q.returning) > 0 { + return + } + for _, f := range q.returningFields { + if f == field { + return + } + } + q.returningFields = append(q.returningFields, field) +} + +func (q *returningQuery) hasReturning() bool { + if len(q.returning) == 1 { + switch q.returning[0].Query { + case "null", "NULL": + return false + } + } + return len(q.returning) > 0 || len(q.returningFields) > 0 +} + +func (q *returningQuery) appendReturning( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + if !q.hasReturning() { + return b, nil + } + + b = append(b, " RETURNING "...) + + for i, f := range q.returning { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + if len(q.returning) > 0 { + return b, nil + } + + b = appendColumns(b, "", q.returningFields) + return b, nil +} + +//------------------------------------------------------------------------------ + +type columnValue struct { + column string + value schema.QueryWithArgs +} + +type customValueQuery struct { + modelValues map[string]schema.QueryWithArgs + extraValues []columnValue +} + +func (q *customValueQuery) addValue( + table *schema.Table, column string, value string, args []interface{}, +) { + if _, ok := table.FieldMap[column]; ok { + if q.modelValues == nil { + q.modelValues = make(map[string]schema.QueryWithArgs) + } + q.modelValues[column] = schema.SafeQuery(value, args) + } else { + q.extraValues = append(q.extraValues, columnValue{ + column: column, + value: schema.SafeQuery(value, args), + }) + } +} + +//------------------------------------------------------------------------------ + +type setQuery struct { + set []schema.QueryWithArgs +} + +func (q *setQuery) addSet(set schema.QueryWithArgs) { + q.set = append(q.set, set) +} + +func (q setQuery) appendSet(fmter schema.Formatter, b []byte) (_ []byte, err error) { + for i, f := range q.set { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +//------------------------------------------------------------------------------ + +type cascadeQuery struct { + restrict bool +} + +func (q cascadeQuery) appendCascade(fmter schema.Formatter, b []byte) []byte { + if !fmter.HasFeature(feature.TableCascade) { + return b + } + if q.restrict { + b = append(b, " RESTRICT"...) + } else { + b = append(b, " CASCADE"...) + } + return b +} diff --git a/vendor/github.com/uptrace/bun/query_column_add.go b/vendor/github.com/uptrace/bun/query_column_add.go new file mode 100644 index 000000000..ce2f60bf0 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_column_add.go @@ -0,0 +1,105 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type AddColumnQuery struct { + baseQuery +} + +func NewAddColumnQuery(db *DB) *AddColumnQuery { + q := &AddColumnQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *AddColumnQuery) Conn(db IConn) *AddColumnQuery { + q.setConn(db) + return q +} + +func (q *AddColumnQuery) Model(model interface{}) *AddColumnQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *AddColumnQuery) Table(tables ...string) *AddColumnQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *AddColumnQuery) TableExpr(query string, args ...interface{}) *AddColumnQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *AddColumnQuery) ModelTableExpr(query string, args ...interface{}) *AddColumnQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *AddColumnQuery) ColumnExpr(query string, args ...interface{}) *AddColumnQuery { + q.addColumn(schema.SafeQuery(query, args)) + return q +} + +//------------------------------------------------------------------------------ + +func (q *AddColumnQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + if len(q.columns) != 1 { + return nil, fmt.Errorf("bun: AddColumnQuery requires exactly one column") + } + + b = append(b, "ALTER TABLE "...) + + b, err = q.appendFirstTable(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, " ADD "...) + + b, err = q.columns[0].AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *AddColumnQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} diff --git a/vendor/github.com/uptrace/bun/query_column_drop.go b/vendor/github.com/uptrace/bun/query_column_drop.go new file mode 100644 index 000000000..5684beeb3 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_column_drop.go @@ -0,0 +1,112 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type DropColumnQuery struct { + baseQuery +} + +func NewDropColumnQuery(db *DB) *DropColumnQuery { + q := &DropColumnQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *DropColumnQuery) Conn(db IConn) *DropColumnQuery { + q.setConn(db) + return q +} + +func (q *DropColumnQuery) Model(model interface{}) *DropColumnQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropColumnQuery) Table(tables ...string) *DropColumnQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *DropColumnQuery) TableExpr(query string, args ...interface{}) *DropColumnQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *DropColumnQuery) ModelTableExpr(query string, args ...interface{}) *DropColumnQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropColumnQuery) Column(columns ...string) *DropColumnQuery { + for _, column := range columns { + q.addColumn(schema.UnsafeIdent(column)) + } + return q +} + +func (q *DropColumnQuery) ColumnExpr(query string, args ...interface{}) *DropColumnQuery { + q.addColumn(schema.SafeQuery(query, args)) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropColumnQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + if len(q.columns) != 1 { + return nil, fmt.Errorf("bun: DropColumnQuery requires exactly one column") + } + + b = append(b, "ALTER TABLE "...) + + b, err = q.appendFirstTable(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, " DROP COLUMN "...) + + b, err = q.columns[0].AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *DropColumnQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} diff --git a/vendor/github.com/uptrace/bun/query_delete.go b/vendor/github.com/uptrace/bun/query_delete.go new file mode 100644 index 000000000..c0c5039c7 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_delete.go @@ -0,0 +1,256 @@ +package bun + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type DeleteQuery struct { + whereBaseQuery + returningQuery +} + +func NewDeleteQuery(db *DB) *DeleteQuery { + q := &DeleteQuery{ + whereBaseQuery: whereBaseQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + }, + } + return q +} + +func (q *DeleteQuery) Conn(db IConn) *DeleteQuery { + q.setConn(db) + return q +} + +func (q *DeleteQuery) Model(model interface{}) *DeleteQuery { + q.setTableModel(model) + return q +} + +// Apply calls the fn passing the DeleteQuery as an argument. +func (q *DeleteQuery) Apply(fn func(*DeleteQuery) *DeleteQuery) *DeleteQuery { + return fn(q) +} + +func (q *DeleteQuery) With(name string, query schema.QueryAppender) *DeleteQuery { + q.addWith(name, query) + return q +} + +func (q *DeleteQuery) Table(tables ...string) *DeleteQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *DeleteQuery) TableExpr(query string, args ...interface{}) *DeleteQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *DeleteQuery) ModelTableExpr(query string, args ...interface{}) *DeleteQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DeleteQuery) WherePK() *DeleteQuery { + q.flags = q.flags.Set(wherePKFlag) + return q +} + +func (q *DeleteQuery) Where(query string, args ...interface{}) *DeleteQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) + return q +} + +func (q *DeleteQuery) WhereOr(query string, args ...interface{}) *DeleteQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) + return q +} + +func (q *DeleteQuery) WhereGroup(sep string, fn func(*DeleteQuery) *DeleteQuery) *DeleteQuery { + saved := q.where + q.where = nil + + q = fn(q) + + where := q.where + q.where = saved + + q.addWhereGroup(sep, where) + + return q +} + +func (q *DeleteQuery) WhereDeleted() *DeleteQuery { + q.whereDeleted() + return q +} + +func (q *DeleteQuery) WhereAllWithDeleted() *DeleteQuery { + q.whereAllWithDeleted() + return q +} + +func (q *DeleteQuery) ForceDelete() *DeleteQuery { + q.flags = q.flags.Set(forceDeleteFlag) + return q +} + +//------------------------------------------------------------------------------ + +// Returning adds a RETURNING clause to the query. +// +// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`. +func (q *DeleteQuery) Returning(query string, args ...interface{}) *DeleteQuery { + q.addReturning(schema.SafeQuery(query, args)) + return q +} + +func (q *DeleteQuery) hasReturning() bool { + if !q.db.features.Has(feature.Returning) { + return false + } + return q.returningQuery.hasReturning() +} + +//------------------------------------------------------------------------------ + +func (q *DeleteQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + fmter = formatterWithModel(fmter, q) + + if q.isSoftDelete() { + if err := q.tableModel.updateSoftDeleteField(); err != nil { + return nil, err + } + + upd := UpdateQuery{ + whereBaseQuery: q.whereBaseQuery, + returningQuery: q.returningQuery, + } + upd.Column(q.table.SoftDeleteField.Name) + return upd.AppendQuery(fmter, b) + } + + q = q.WhereAllWithDeleted() + withAlias := q.db.features.Has(feature.DeleteTableAlias) + + b, err = q.appendWith(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, "DELETE FROM "...) + + if withAlias { + b, err = q.appendFirstTableWithAlias(fmter, b) + } else { + b, err = q.appendFirstTable(fmter, b) + } + if err != nil { + return nil, err + } + + if q.hasMultiTables() { + b = append(b, " USING "...) + b, err = q.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + } + + b, err = q.mustAppendWhere(fmter, b, withAlias) + if err != nil { + return nil, err + } + + if len(q.returning) > 0 { + b, err = q.appendReturning(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *DeleteQuery) isSoftDelete() bool { + return q.tableModel != nil && q.table.SoftDeleteField != nil && !q.flags.Has(forceDeleteFlag) +} + +//------------------------------------------------------------------------------ + +func (q *DeleteQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + if q.table != nil { + if err := q.beforeDeleteHook(ctx); err != nil { + return nil, err + } + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + var res sql.Result + + if hasDest := len(dest) > 0; hasDest || q.hasReturning() { + model, err := q.getModel(dest) + if err != nil { + return nil, err + } + + res, err = q.scan(ctx, q, query, model, hasDest) + if err != nil { + return nil, err + } + } else { + res, err = q.exec(ctx, q, query) + if err != nil { + return nil, err + } + } + + if q.table != nil { + if err := q.afterDeleteHook(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *DeleteQuery) beforeDeleteHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeDeleteHook); ok { + if err := hook.BeforeDelete(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *DeleteQuery) afterDeleteHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterDeleteHook); ok { + if err := hook.AfterDelete(ctx, q); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/query_index_create.go b/vendor/github.com/uptrace/bun/query_index_create.go new file mode 100644 index 000000000..de7eb7aa0 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_index_create.go @@ -0,0 +1,242 @@ +package bun + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type CreateIndexQuery struct { + whereBaseQuery + + unique bool + fulltext bool + spatial bool + concurrently bool + ifNotExists bool + + index schema.QueryWithArgs + using schema.QueryWithArgs + include []schema.QueryWithArgs +} + +func NewCreateIndexQuery(db *DB) *CreateIndexQuery { + q := &CreateIndexQuery{ + whereBaseQuery: whereBaseQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + }, + } + return q +} + +func (q *CreateIndexQuery) Conn(db IConn) *CreateIndexQuery { + q.setConn(db) + return q +} + +func (q *CreateIndexQuery) Model(model interface{}) *CreateIndexQuery { + q.setTableModel(model) + return q +} + +func (q *CreateIndexQuery) Unique() *CreateIndexQuery { + q.unique = true + return q +} + +func (q *CreateIndexQuery) Concurrently() *CreateIndexQuery { + q.concurrently = true + return q +} + +func (q *CreateIndexQuery) IfNotExists() *CreateIndexQuery { + q.ifNotExists = true + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Index(query string) *CreateIndexQuery { + q.index = schema.UnsafeIdent(query) + return q +} + +func (q *CreateIndexQuery) IndexExpr(query string, args ...interface{}) *CreateIndexQuery { + q.index = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Table(tables ...string) *CreateIndexQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *CreateIndexQuery) TableExpr(query string, args ...interface{}) *CreateIndexQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *CreateIndexQuery) ModelTableExpr(query string, args ...interface{}) *CreateIndexQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +func (q *CreateIndexQuery) Using(query string, args ...interface{}) *CreateIndexQuery { + q.using = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Column(columns ...string) *CreateIndexQuery { + for _, column := range columns { + q.addColumn(schema.UnsafeIdent(column)) + } + return q +} + +func (q *CreateIndexQuery) ColumnExpr(query string, args ...interface{}) *CreateIndexQuery { + q.addColumn(schema.SafeQuery(query, args)) + return q +} + +func (q *CreateIndexQuery) ExcludeColumn(columns ...string) *CreateIndexQuery { + q.excludeColumn(columns) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Include(columns ...string) *CreateIndexQuery { + for _, column := range columns { + q.include = append(q.include, schema.UnsafeIdent(column)) + } + return q +} + +func (q *CreateIndexQuery) IncludeExpr(query string, args ...interface{}) *CreateIndexQuery { + q.include = append(q.include, schema.SafeQuery(query, args)) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Where(query string, args ...interface{}) *CreateIndexQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) + return q +} + +func (q *CreateIndexQuery) WhereOr(query string, args ...interface{}) *CreateIndexQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + + b = append(b, "CREATE "...) + + if q.unique { + b = append(b, "UNIQUE "...) + } + if q.fulltext { + b = append(b, "FULLTEXT "...) + } + if q.spatial { + b = append(b, "SPATIAL "...) + } + + b = append(b, "INDEX "...) + + if q.concurrently { + b = append(b, "CONCURRENTLY "...) + } + if q.ifNotExists { + b = append(b, "IF NOT EXISTS "...) + } + + b, err = q.index.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, " ON "...) + b, err = q.appendFirstTable(fmter, b) + if err != nil { + return nil, err + } + + if !q.using.IsZero() { + b = append(b, " USING "...) + b, err = q.using.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + b = append(b, " ("...) + for i, col := range q.columns { + if i > 0 { + b = append(b, ", "...) + } + b, err = col.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + b = append(b, ')') + + if len(q.include) > 0 { + b = append(b, " INCLUDE ("...) + for i, col := range q.include { + if i > 0 { + b = append(b, ", "...) + } + b, err = col.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + b = append(b, ')') + } + + if len(q.where) > 0 { + b, err = appendWhere(fmter, b, q.where) + if err != nil { + return nil, err + } + } + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} diff --git a/vendor/github.com/uptrace/bun/query_index_drop.go b/vendor/github.com/uptrace/bun/query_index_drop.go new file mode 100644 index 000000000..c922ff04f --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_index_drop.go @@ -0,0 +1,105 @@ +package bun + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type DropIndexQuery struct { + baseQuery + cascadeQuery + + concurrently bool + ifExists bool + + index schema.QueryWithArgs +} + +func NewDropIndexQuery(db *DB) *DropIndexQuery { + q := &DropIndexQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *DropIndexQuery) Conn(db IConn) *DropIndexQuery { + q.setConn(db) + return q +} + +func (q *DropIndexQuery) Model(model interface{}) *DropIndexQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropIndexQuery) Concurrently() *DropIndexQuery { + q.concurrently = true + return q +} + +func (q *DropIndexQuery) IfExists() *DropIndexQuery { + q.ifExists = true + return q +} + +func (q *DropIndexQuery) Restrict() *DropIndexQuery { + q.restrict = true + return q +} + +func (q *DropIndexQuery) Index(query string, args ...interface{}) *DropIndexQuery { + q.index = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropIndexQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + + b = append(b, "DROP INDEX "...) + + if q.concurrently { + b = append(b, "CONCURRENTLY "...) + } + if q.ifExists { + b = append(b, "IF EXISTS "...) + } + + b, err = q.index.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + b = q.appendCascade(fmter, b) + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *DropIndexQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} diff --git a/vendor/github.com/uptrace/bun/query_insert.go b/vendor/github.com/uptrace/bun/query_insert.go new file mode 100644 index 000000000..efddee407 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_insert.go @@ -0,0 +1,551 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + "reflect" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type InsertQuery struct { + whereBaseQuery + returningQuery + customValueQuery + + onConflict schema.QueryWithArgs + setQuery + + ignore bool + replace bool +} + +func NewInsertQuery(db *DB) *InsertQuery { + q := &InsertQuery{ + whereBaseQuery: whereBaseQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + }, + } + return q +} + +func (q *InsertQuery) Conn(db IConn) *InsertQuery { + q.setConn(db) + return q +} + +func (q *InsertQuery) Model(model interface{}) *InsertQuery { + q.setTableModel(model) + return q +} + +// Apply calls the fn passing the SelectQuery as an argument. +func (q *InsertQuery) Apply(fn func(*InsertQuery) *InsertQuery) *InsertQuery { + return fn(q) +} + +func (q *InsertQuery) With(name string, query schema.QueryAppender) *InsertQuery { + q.addWith(name, query) + return q +} + +//------------------------------------------------------------------------------ + +func (q *InsertQuery) Table(tables ...string) *InsertQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *InsertQuery) TableExpr(query string, args ...interface{}) *InsertQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *InsertQuery) ModelTableExpr(query string, args ...interface{}) *InsertQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *InsertQuery) Column(columns ...string) *InsertQuery { + for _, column := range columns { + q.addColumn(schema.UnsafeIdent(column)) + } + return q +} + +func (q *InsertQuery) ExcludeColumn(columns ...string) *InsertQuery { + q.excludeColumn(columns) + return q +} + +// Value overwrites model value for the column in INSERT and UPDATE queries. +func (q *InsertQuery) Value(column string, value string, args ...interface{}) *InsertQuery { + if q.table == nil { + q.err = errNilModel + return q + } + q.addValue(q.table, column, value, args) + return q +} + +func (q *InsertQuery) Where(query string, args ...interface{}) *InsertQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) + return q +} + +func (q *InsertQuery) WhereOr(query string, args ...interface{}) *InsertQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) + return q +} + +//------------------------------------------------------------------------------ + +// Returning adds a RETURNING clause to the query. +// +// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`. +func (q *InsertQuery) Returning(query string, args ...interface{}) *InsertQuery { + q.addReturning(schema.SafeQuery(query, args)) + return q +} + +func (q *InsertQuery) hasReturning() bool { + if !q.db.features.Has(feature.Returning) { + return false + } + return q.returningQuery.hasReturning() +} + +//------------------------------------------------------------------------------ + +// Ignore generates an `INSERT IGNORE INTO` query (MySQL). +func (q *InsertQuery) Ignore() *InsertQuery { + q.ignore = true + return q +} + +// Replaces generates a `REPLACE INTO` query (MySQL). +func (q *InsertQuery) Replace() *InsertQuery { + q.replace = true + return q +} + +//------------------------------------------------------------------------------ + +func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + fmter = formatterWithModel(fmter, q) + + b, err = q.appendWith(fmter, b) + if err != nil { + return nil, err + } + + if q.replace { + b = append(b, "REPLACE "...) + } else { + b = append(b, "INSERT "...) + if q.ignore { + b = append(b, "IGNORE "...) + } + } + b = append(b, "INTO "...) + + if q.db.features.Has(feature.InsertTableAlias) && !q.onConflict.IsZero() { + b, err = q.appendFirstTableWithAlias(fmter, b) + } else { + b, err = q.appendFirstTable(fmter, b) + } + if err != nil { + return nil, err + } + + b, err = q.appendColumnsValues(fmter, b) + if err != nil { + return nil, err + } + + b, err = q.appendOn(fmter, b) + if err != nil { + return nil, err + } + + if q.hasReturning() { + b, err = q.appendReturning(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *InsertQuery) appendColumnsValues( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + if q.hasMultiTables() { + if q.columns != nil { + b = append(b, " ("...) + b, err = q.appendColumns(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ")"...) + } + + b = append(b, " SELECT * FROM "...) + b, err = q.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + + return b, nil + } + + if m, ok := q.model.(*mapModel); ok { + return m.appendColumnsValues(fmter, b), nil + } + if _, ok := q.model.(*mapSliceModel); ok { + return nil, fmt.Errorf("Insert(*[]map[string]interface{}) is not supported") + } + + if q.model == nil { + return nil, errNilModel + } + + fields, err := q.getFields() + if err != nil { + return nil, err + } + + b = append(b, " ("...) + b = q.appendFields(fmter, b, fields) + b = append(b, ") VALUES ("...) + + switch model := q.tableModel.(type) { + case *structTableModel: + b, err = q.appendStructValues(fmter, b, fields, model.strct) + if err != nil { + return nil, err + } + case *sliceTableModel: + b, err = q.appendSliceValues(fmter, b, fields, model.slice) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("bun: Insert does not support %T", q.tableModel) + } + + b = append(b, ')') + + return b, nil +} + +func (q *InsertQuery) appendStructValues( + fmter schema.Formatter, b []byte, fields []*schema.Field, strct reflect.Value, +) (_ []byte, err error) { + isTemplate := fmter.IsNop() + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + app, ok := q.modelValues[f.Name] + if ok { + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + q.addReturningField(f) + continue + } + + switch { + case isTemplate: + b = append(b, '?') + case f.NullZero && f.HasZeroValue(strct): + if q.db.features.Has(feature.DefaultPlaceholder) { + b = append(b, "DEFAULT"...) + } else if f.SQLDefault != "" { + b = append(b, f.SQLDefault...) + } else { + b = append(b, "NULL"...) + } + q.addReturningField(f) + default: + b = f.AppendValue(fmter, b, strct) + } + } + + for i, v := range q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + + b, err = v.value.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *InsertQuery) appendSliceValues( + fmter schema.Formatter, b []byte, fields []*schema.Field, slice reflect.Value, +) (_ []byte, err error) { + if fmter.IsNop() { + return q.appendStructValues(fmter, b, fields, reflect.Value{}) + } + + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, "), ("...) + } + el := indirect(slice.Index(i)) + b, err = q.appendStructValues(fmter, b, fields, el) + if err != nil { + return nil, err + } + } + + for i, v := range q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + + b, err = v.value.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *InsertQuery) getFields() ([]*schema.Field, error) { + if q.db.features.Has(feature.DefaultPlaceholder) || len(q.columns) > 0 { + return q.baseQuery.getFields() + } + + var strct reflect.Value + + switch model := q.tableModel.(type) { + case *structTableModel: + strct = model.strct + case *sliceTableModel: + if model.sliceLen == 0 { + return nil, fmt.Errorf("bun: Insert(empty %T)", model.slice.Type()) + } + strct = indirect(model.slice.Index(0)) + } + + fields := make([]*schema.Field, 0, len(q.table.Fields)) + + for _, f := range q.table.Fields { + if f.NotNull && f.NullZero && f.SQLDefault == "" && f.HasZeroValue(strct) { + q.addReturningField(f) + continue + } + fields = append(fields, f) + } + + return fields, nil +} + +func (q *InsertQuery) appendFields( + fmter schema.Formatter, b []byte, fields []*schema.Field, +) []byte { + b = appendColumns(b, "", fields) + for i, v := range q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + b = fmter.AppendIdent(b, v.column) + } + return b +} + +//------------------------------------------------------------------------------ + +func (q *InsertQuery) On(s string, args ...interface{}) *InsertQuery { + q.onConflict = schema.SafeQuery(s, args) + return q +} + +func (q *InsertQuery) Set(query string, args ...interface{}) *InsertQuery { + q.addSet(schema.SafeQuery(query, args)) + return q +} + +func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.onConflict.IsZero() { + return b, nil + } + + b = append(b, " ON "...) + b, err = q.onConflict.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + if len(q.set) > 0 { + if fmter.HasFeature(feature.OnDuplicateKey) { + b = append(b, ' ') + } else { + b = append(b, " SET "...) + } + + b, err = q.appendSet(fmter, b) + if err != nil { + return nil, err + } + } else if len(q.columns) > 0 { + fields, err := q.getDataFields() + if err != nil { + return nil, err + } + + if len(fields) == 0 { + fields = q.tableModel.Table().DataFields + } + + b = q.appendSetExcluded(b, fields) + } + + b, err = q.appendWhere(fmter, b, true) + if err != nil { + return nil, err + } + + return b, nil +} + +func (q *InsertQuery) appendSetExcluded(b []byte, fields []*schema.Field) []byte { + b = append(b, " SET "...) + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + b = append(b, f.SQLName...) + b = append(b, " = EXCLUDED."...) + b = append(b, f.SQLName...) + } + return b +} + +//------------------------------------------------------------------------------ + +func (q *InsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + if q.table != nil { + if err := q.beforeInsertHook(ctx); err != nil { + return nil, err + } + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + var res sql.Result + + if hasDest := len(dest) > 0; hasDest || q.hasReturning() { + model, err := q.getModel(dest) + if err != nil { + return nil, err + } + + res, err = q.scan(ctx, q, query, model, hasDest) + if err != nil { + return nil, err + } + } else { + res, err = q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + if err := q.tryLastInsertID(res, dest); err != nil { + return nil, err + } + } + + if q.table != nil { + if err := q.afterInsertHook(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *InsertQuery) beforeInsertHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeInsertHook); ok { + if err := hook.BeforeInsert(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *InsertQuery) afterInsertHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterInsertHook); ok { + if err := hook.AfterInsert(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *InsertQuery) tryLastInsertID(res sql.Result, dest []interface{}) error { + if q.db.features.Has(feature.Returning) || q.table == nil || len(q.table.PKs) != 1 { + return nil + } + + id, err := res.LastInsertId() + if err != nil { + return err + } + if id == 0 { + return nil + } + + model, err := q.getModel(dest) + if err != nil { + return err + } + + pk := q.table.PKs[0] + switch model := model.(type) { + case *structTableModel: + if err := pk.ScanValue(model.strct, id); err != nil { + return err + } + case *sliceTableModel: + sliceLen := model.slice.Len() + for i := 0; i < sliceLen; i++ { + strct := indirect(model.slice.Index(i)) + if err := pk.ScanValue(strct, id); err != nil { + return err + } + id++ + } + } + + return nil +} diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go new file mode 100644 index 000000000..1f63686ad --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_select.go @@ -0,0 +1,830 @@ +package bun + +import ( + "bytes" + "context" + "database/sql" + "errors" + "fmt" + "strconv" + "strings" + "sync" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type union struct { + expr string + query *SelectQuery +} + +type SelectQuery struct { + whereBaseQuery + + distinctOn []schema.QueryWithArgs + joins []joinQuery + group []schema.QueryWithArgs + having []schema.QueryWithArgs + order []schema.QueryWithArgs + limit int32 + offset int32 + selFor schema.QueryWithArgs + + union []union +} + +func NewSelectQuery(db *DB) *SelectQuery { + return &SelectQuery{ + whereBaseQuery: whereBaseQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + }, + } +} + +func (q *SelectQuery) Conn(db IConn) *SelectQuery { + q.setConn(db) + return q +} + +func (q *SelectQuery) Model(model interface{}) *SelectQuery { + q.setTableModel(model) + return q +} + +// Apply calls the fn passing the SelectQuery as an argument. +func (q *SelectQuery) Apply(fn func(*SelectQuery) *SelectQuery) *SelectQuery { + return fn(q) +} + +func (q *SelectQuery) With(name string, query schema.QueryAppender) *SelectQuery { + q.addWith(name, query) + return q +} + +func (q *SelectQuery) Distinct() *SelectQuery { + q.distinctOn = make([]schema.QueryWithArgs, 0) + return q +} + +func (q *SelectQuery) DistinctOn(query string, args ...interface{}) *SelectQuery { + q.distinctOn = append(q.distinctOn, schema.SafeQuery(query, args)) + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Table(tables ...string) *SelectQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *SelectQuery) TableExpr(query string, args ...interface{}) *SelectQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *SelectQuery) ModelTableExpr(query string, args ...interface{}) *SelectQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Column(columns ...string) *SelectQuery { + for _, column := range columns { + q.addColumn(schema.UnsafeIdent(column)) + } + return q +} + +func (q *SelectQuery) ColumnExpr(query string, args ...interface{}) *SelectQuery { + q.addColumn(schema.SafeQuery(query, args)) + return q +} + +func (q *SelectQuery) ExcludeColumn(columns ...string) *SelectQuery { + q.excludeColumn(columns) + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) WherePK() *SelectQuery { + q.flags = q.flags.Set(wherePKFlag) + return q +} + +func (q *SelectQuery) Where(query string, args ...interface{}) *SelectQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) + return q +} + +func (q *SelectQuery) WhereOr(query string, args ...interface{}) *SelectQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) + return q +} + +func (q *SelectQuery) WhereGroup(sep string, fn func(*SelectQuery) *SelectQuery) *SelectQuery { + saved := q.where + q.where = nil + + q = fn(q) + + where := q.where + q.where = saved + + q.addWhereGroup(sep, where) + + return q +} + +func (q *SelectQuery) WhereDeleted() *SelectQuery { + q.whereDeleted() + return q +} + +func (q *SelectQuery) WhereAllWithDeleted() *SelectQuery { + q.whereAllWithDeleted() + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Group(columns ...string) *SelectQuery { + for _, column := range columns { + q.group = append(q.group, schema.UnsafeIdent(column)) + } + return q +} + +func (q *SelectQuery) GroupExpr(group string, args ...interface{}) *SelectQuery { + q.group = append(q.group, schema.SafeQuery(group, args)) + return q +} + +func (q *SelectQuery) Having(having string, args ...interface{}) *SelectQuery { + q.having = append(q.having, schema.SafeQuery(having, args)) + return q +} + +func (q *SelectQuery) Order(orders ...string) *SelectQuery { + for _, order := range orders { + if order == "" { + continue + } + + index := strings.IndexByte(order, ' ') + if index == -1 { + q.order = append(q.order, schema.UnsafeIdent(order)) + continue + } + + field := order[:index] + sort := order[index+1:] + + switch strings.ToUpper(sort) { + case "ASC", "DESC", "ASC NULLS FIRST", "DESC NULLS FIRST", + "ASC NULLS LAST", "DESC NULLS LAST": + q.order = append(q.order, schema.SafeQuery("? ?", []interface{}{ + Ident(field), + Safe(sort), + })) + default: + q.order = append(q.order, schema.UnsafeIdent(order)) + } + } + return q +} + +func (q *SelectQuery) OrderExpr(query string, args ...interface{}) *SelectQuery { + q.order = append(q.order, schema.SafeQuery(query, args)) + return q +} + +func (q *SelectQuery) Limit(n int) *SelectQuery { + q.limit = int32(n) + return q +} + +func (q *SelectQuery) Offset(n int) *SelectQuery { + q.offset = int32(n) + return q +} + +func (q *SelectQuery) For(s string, args ...interface{}) *SelectQuery { + q.selFor = schema.SafeQuery(s, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Union(other *SelectQuery) *SelectQuery { + return q.addUnion(" UNION ", other) +} + +func (q *SelectQuery) UnionAll(other *SelectQuery) *SelectQuery { + return q.addUnion(" UNION ALL ", other) +} + +func (q *SelectQuery) Intersect(other *SelectQuery) *SelectQuery { + return q.addUnion(" INTERSECT ", other) +} + +func (q *SelectQuery) IntersectAll(other *SelectQuery) *SelectQuery { + return q.addUnion(" INTERSECT ALL ", other) +} + +func (q *SelectQuery) Except(other *SelectQuery) *SelectQuery { + return q.addUnion(" EXCEPT ", other) +} + +func (q *SelectQuery) ExceptAll(other *SelectQuery) *SelectQuery { + return q.addUnion(" EXCEPT ALL ", other) +} + +func (q *SelectQuery) addUnion(expr string, other *SelectQuery) *SelectQuery { + q.union = append(q.union, union{ + expr: expr, + query: other, + }) + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Join(join string, args ...interface{}) *SelectQuery { + q.joins = append(q.joins, joinQuery{ + join: schema.SafeQuery(join, args), + }) + return q +} + +func (q *SelectQuery) JoinOn(cond string, args ...interface{}) *SelectQuery { + return q.joinOn(cond, args, " AND ") +} + +func (q *SelectQuery) JoinOnOr(cond string, args ...interface{}) *SelectQuery { + return q.joinOn(cond, args, " OR ") +} + +func (q *SelectQuery) joinOn(cond string, args []interface{}, sep string) *SelectQuery { + if len(q.joins) == 0 { + q.err = errors.New("bun: query has no joins") + return q + } + j := &q.joins[len(q.joins)-1] + j.on = append(j.on, schema.SafeQueryWithSep(cond, args, sep)) + return q +} + +//------------------------------------------------------------------------------ + +// Relation adds a relation to the query. Relation name can be: +// - RelationName to select all columns, +// - RelationName.column_name, +// - RelationName._ to join relation without selecting relation columns. +func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQuery) *SelectQuery { + if q.tableModel == nil { + q.setErr(errNilModel) + return q + } + + var fn func(*SelectQuery) *SelectQuery + + if len(apply) == 1 { + fn = apply[0] + } else if len(apply) > 1 { + panic("only one apply function is supported") + } + + join := q.tableModel.Join(name, fn) + if join == nil { + q.setErr(fmt.Errorf("%s does not have relation=%q", q.table, name)) + return q + } + + return q +} + +func (q *SelectQuery) forEachHasOneJoin(fn func(*join) error) error { + if q.tableModel == nil { + return nil + } + return q._forEachHasOneJoin(fn, q.tableModel.GetJoins()) +} + +func (q *SelectQuery) _forEachHasOneJoin(fn func(*join) error, joins []join) error { + for i := range joins { + j := &joins[i] + switch j.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + if err := fn(j); err != nil { + return err + } + if err := q._forEachHasOneJoin(fn, j.JoinModel.GetJoins()); err != nil { + return err + } + } + } + return nil +} + +func (q *SelectQuery) selectJoins(ctx context.Context, joins []join) error { + var err error + for i := range joins { + j := &joins[i] + switch j.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + err = q.selectJoins(ctx, j.JoinModel.GetJoins()) + default: + err = j.Select(ctx, q.db.NewSelect()) + } + if err != nil { + return err + } + } + return nil +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + return q.appendQuery(fmter, b, false) +} + +func (q *SelectQuery) appendQuery( + fmter schema.Formatter, b []byte, count bool, +) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + fmter = formatterWithModel(fmter, q) + + cteCount := count && (len(q.group) > 0 || q.distinctOn != nil) + if cteCount { + b = append(b, "WITH _count_wrapper AS ("...) + } + + if len(q.union) > 0 { + b = append(b, '(') + } + + b, err = q.appendWith(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, "SELECT "...) + + if len(q.distinctOn) > 0 { + b = append(b, "DISTINCT ON ("...) + for i, app := range q.distinctOn { + if i > 0 { + b = append(b, ", "...) + } + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + b = append(b, ") "...) + } else if q.distinctOn != nil { + b = append(b, "DISTINCT "...) + } + + if count && !cteCount { + b = append(b, "count(*)"...) + } else { + b, err = q.appendColumns(fmter, b) + if err != nil { + return nil, err + } + } + + if q.hasTables() { + b, err = q.appendTables(fmter, b) + if err != nil { + return nil, err + } + } + + if err := q.forEachHasOneJoin(func(j *join) error { + b = append(b, ' ') + b, err = j.appendHasOneJoin(fmter, b, q) + return err + }); err != nil { + return nil, err + } + + for _, j := range q.joins { + b, err = j.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + b, err = q.appendWhere(fmter, b, true) + if err != nil { + return nil, err + } + + if len(q.group) > 0 { + b = append(b, " GROUP BY "...) + for i, f := range q.group { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + } + + if len(q.having) > 0 { + b = append(b, " HAVING "...) + for i, f := range q.having { + if i > 0 { + b = append(b, " AND "...) + } + b = append(b, '(') + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ')') + } + } + + if !count { + b, err = q.appendOrder(fmter, b) + if err != nil { + return nil, err + } + + if q.limit != 0 { + b = append(b, " LIMIT "...) + b = strconv.AppendInt(b, int64(q.limit), 10) + } + + if q.offset != 0 { + b = append(b, " OFFSET "...) + b = strconv.AppendInt(b, int64(q.offset), 10) + } + + if !q.selFor.IsZero() { + b = append(b, " FOR "...) + b, err = q.selFor.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + } + + if len(q.union) > 0 { + b = append(b, ')') + + for _, u := range q.union { + b = append(b, u.expr...) + b = append(b, '(') + b, err = u.query.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ')') + } + } + + if cteCount { + b = append(b, ") SELECT count(*) FROM _count_wrapper"...) + } + + return b, nil +} + +func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) { + start := len(b) + + switch { + case q.columns != nil: + for i, col := range q.columns { + if i > 0 { + b = append(b, ", "...) + } + + if col.Args == nil { + if field, ok := q.table.FieldMap[col.Query]; ok { + b = append(b, q.table.SQLAlias...) + b = append(b, '.') + b = append(b, field.SQLName...) + continue + } + } + + b, err = col.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + case q.table != nil: + if len(q.table.Fields) > 10 && fmter.IsNop() { + b = append(b, q.table.SQLAlias...) + b = append(b, '.') + b = dialect.AppendString(b, fmt.Sprintf("%d columns", len(q.table.Fields))) + } else { + b = appendColumns(b, q.table.SQLAlias, q.table.Fields) + } + default: + b = append(b, '*') + } + + if err := q.forEachHasOneJoin(func(j *join) error { + if len(b) != start { + b = append(b, ", "...) + start = len(b) + } + + b, err = q.appendHasOneColumns(fmter, b, j) + if err != nil { + return err + } + + return nil + }); err != nil { + return nil, err + } + + b = bytes.TrimSuffix(b, []byte(", ")) + + return b, nil +} + +func (q *SelectQuery) appendHasOneColumns( + fmter schema.Formatter, b []byte, join *join, +) (_ []byte, err error) { + join.applyQuery(q) + + if join.columns != nil { + for i, col := range join.columns { + if i > 0 { + b = append(b, ", "...) + } + + if col.Args == nil { + if field, ok := q.table.FieldMap[col.Query]; ok { + b = join.appendAlias(fmter, b) + b = append(b, '.') + b = append(b, field.SQLName...) + b = append(b, " AS "...) + b = join.appendAliasColumn(fmter, b, field.Name) + continue + } + } + + b, err = col.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil + } + + for i, field := range join.JoinModel.Table().Fields { + if i > 0 { + b = append(b, ", "...) + } + b = join.appendAlias(fmter, b) + b = append(b, '.') + b = append(b, field.SQLName...) + b = append(b, " AS "...) + b = join.appendAliasColumn(fmter, b, field.Name) + } + return b, nil +} + +func (q *SelectQuery) appendTables(fmter schema.Formatter, b []byte) (_ []byte, err error) { + b = append(b, " FROM "...) + return q.appendTablesWithAlias(fmter, b) +} + +func (q *SelectQuery) appendOrder(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if len(q.order) > 0 { + b = append(b, " ORDER BY "...) + + for i, f := range q.order { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil + } + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + return q.conn.QueryContext(ctx, query) +} + +func (q *SelectQuery) Exec(ctx context.Context) (res sql.Result, err error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err = q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} + +func (q *SelectQuery) Scan(ctx context.Context, dest ...interface{}) error { + model, err := q.getModel(dest) + if err != nil { + return err + } + + if q.limit > 1 { + if model, ok := model.(interface{ SetCap(int) }); ok { + model.SetCap(int(q.limit)) + } + } + + if q.table != nil { + if err := q.beforeSelectHook(ctx); err != nil { + return err + } + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return err + } + + query := internal.String(queryBytes) + + res, err := q.scan(ctx, q, query, model, true) + if err != nil { + return err + } + + if res.n > 0 { + if tableModel, ok := model.(tableModel); ok { + if err := q.selectJoins(ctx, tableModel.GetJoins()); err != nil { + return err + } + } + } + + if q.table != nil { + if err := q.afterSelectHook(ctx); err != nil { + return err + } + } + + return nil +} + +func (q *SelectQuery) beforeSelectHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeSelectHook); ok { + if err := hook.BeforeSelect(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *SelectQuery) afterSelectHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterSelectHook); ok { + if err := hook.AfterSelect(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *SelectQuery) Count(ctx context.Context) (int, error) { + qq := countQuery{q} + + queryBytes, err := qq.appendQuery(q.db.fmter, nil, true) + if err != nil { + return 0, err + } + + query := internal.String(queryBytes) + ctx, event := q.db.beforeQuery(ctx, qq, query, nil) + + var num int + err = q.conn.QueryRowContext(ctx, query).Scan(&num) + + q.db.afterQuery(ctx, event, nil, err) + + return num, err +} + +func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (int, error) { + var count int + var wg sync.WaitGroup + var mu sync.Mutex + var firstErr error + + if q.limit >= 0 { + wg.Add(1) + go func() { + defer wg.Done() + + if err := q.Scan(ctx, dest...); err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + } + + wg.Add(1) + go func() { + defer wg.Done() + + var err error + count, err = q.Count(ctx) + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + + wg.Wait() + return count, firstErr +} + +//------------------------------------------------------------------------------ + +type joinQuery struct { + join schema.QueryWithArgs + on []schema.QueryWithSep +} + +func (j *joinQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + b = append(b, ' ') + + b, err = j.join.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + if len(j.on) > 0 { + b = append(b, " ON "...) + for i, on := range j.on { + if i > 0 { + b = append(b, on.Sep...) + } + + b = append(b, '(') + b, err = on.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ')') + } + } + + return b, nil +} + +//------------------------------------------------------------------------------ + +type countQuery struct { + *SelectQuery +} + +func (q countQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + return q.appendQuery(fmter, b, true) +} diff --git a/vendor/github.com/uptrace/bun/query_table_create.go b/vendor/github.com/uptrace/bun/query_table_create.go new file mode 100644 index 000000000..0a4b3567c --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_table_create.go @@ -0,0 +1,275 @@ +package bun + +import ( + "context" + "database/sql" + "sort" + "strconv" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type CreateTableQuery struct { + baseQuery + + temp bool + ifNotExists bool + varchar int + + fks []schema.QueryWithArgs + partitionBy schema.QueryWithArgs + tablespace schema.QueryWithArgs +} + +func NewCreateTableQuery(db *DB) *CreateTableQuery { + q := &CreateTableQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *CreateTableQuery) Conn(db IConn) *CreateTableQuery { + q.setConn(db) + return q +} + +func (q *CreateTableQuery) Model(model interface{}) *CreateTableQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateTableQuery) Table(tables ...string) *CreateTableQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *CreateTableQuery) TableExpr(query string, args ...interface{}) *CreateTableQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *CreateTableQuery) ModelTableExpr(query string, args ...interface{}) *CreateTableQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateTableQuery) Temp() *CreateTableQuery { + q.temp = true + return q +} + +func (q *CreateTableQuery) IfNotExists() *CreateTableQuery { + q.ifNotExists = true + return q +} + +func (q *CreateTableQuery) Varchar(n int) *CreateTableQuery { + q.varchar = n + return q +} + +func (q *CreateTableQuery) ForeignKey(query string, args ...interface{}) *CreateTableQuery { + q.fks = append(q.fks, schema.SafeQuery(query, args)) + return q +} + +func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + if q.table == nil { + return nil, errNilModel + } + + b = append(b, "CREATE "...) + if q.temp { + b = append(b, "TEMP "...) + } + b = append(b, "TABLE "...) + if q.ifNotExists { + b = append(b, "IF NOT EXISTS "...) + } + b, err = q.appendFirstTable(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, " ("...) + + for i, field := range q.table.Fields { + if i > 0 { + b = append(b, ", "...) + } + + b = append(b, field.SQLName...) + b = append(b, " "...) + b = q.appendSQLType(b, field) + if field.NotNull { + b = append(b, " NOT NULL"...) + } + if q.db.features.Has(feature.AutoIncrement) && field.AutoIncrement { + b = append(b, " AUTO_INCREMENT"...) + } + if field.SQLDefault != "" { + b = append(b, " DEFAULT "...) + b = append(b, field.SQLDefault...) + } + } + + b = q.appendPKConstraint(b, q.table.PKs) + b = q.appendUniqueConstraints(fmter, b) + b, err = q.appenFKConstraints(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, ")"...) + + if !q.partitionBy.IsZero() { + b = append(b, " PARTITION BY "...) + b, err = q.partitionBy.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + if !q.tablespace.IsZero() { + b = append(b, " TABLESPACE "...) + b, err = q.tablespace.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *CreateTableQuery) appendSQLType(b []byte, field *schema.Field) []byte { + if field.CreateTableSQLType != field.DiscoveredSQLType { + return append(b, field.CreateTableSQLType...) + } + + if q.varchar > 0 && + field.CreateTableSQLType == sqltype.VarChar { + b = append(b, "varchar("...) + b = strconv.AppendInt(b, int64(q.varchar), 10) + b = append(b, ")"...) + return b + } + + return append(b, field.CreateTableSQLType...) +} + +func (q *CreateTableQuery) appendUniqueConstraints(fmter schema.Formatter, b []byte) []byte { + unique := q.table.Unique + + keys := make([]string, 0, len(unique)) + for key := range unique { + keys = append(keys, key) + } + sort.Strings(keys) + + for _, key := range keys { + b = q.appendUniqueConstraint(fmter, b, key, unique[key]) + } + + return b +} + +func (q *CreateTableQuery) appendUniqueConstraint( + fmter schema.Formatter, b []byte, name string, fields []*schema.Field, +) []byte { + if name != "" { + b = append(b, ", CONSTRAINT "...) + b = fmter.AppendIdent(b, name) + } else { + b = append(b, ","...) + } + b = append(b, " UNIQUE ("...) + b = appendColumns(b, "", fields) + b = append(b, ")"...) + + return b +} + +func (q *CreateTableQuery) appenFKConstraints( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + for _, fk := range q.fks { + b = append(b, ", FOREIGN KEY "...) + b, err = fk.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *CreateTableQuery) appendPKConstraint(b []byte, pks []*schema.Field) []byte { + if len(pks) == 0 { + return b + } + + b = append(b, ", PRIMARY KEY ("...) + b = appendColumns(b, "", pks) + b = append(b, ")"...) + return b +} + +//------------------------------------------------------------------------------ + +func (q *CreateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + if err := q.beforeCreateTableHook(ctx); err != nil { + return nil, err + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + if q.table != nil { + if err := q.afterCreateTableHook(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *CreateTableQuery) beforeCreateTableHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeCreateTableHook); ok { + if err := hook.BeforeCreateTable(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *CreateTableQuery) afterCreateTableHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterCreateTableHook); ok { + if err := hook.AfterCreateTable(ctx, q); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/query_table_drop.go b/vendor/github.com/uptrace/bun/query_table_drop.go new file mode 100644 index 000000000..2c30171c1 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_table_drop.go @@ -0,0 +1,137 @@ +package bun + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type DropTableQuery struct { + baseQuery + cascadeQuery + + ifExists bool +} + +func NewDropTableQuery(db *DB) *DropTableQuery { + q := &DropTableQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *DropTableQuery) Conn(db IConn) *DropTableQuery { + q.setConn(db) + return q +} + +func (q *DropTableQuery) Model(model interface{}) *DropTableQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropTableQuery) Table(tables ...string) *DropTableQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *DropTableQuery) TableExpr(query string, args ...interface{}) *DropTableQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *DropTableQuery) ModelTableExpr(query string, args ...interface{}) *DropTableQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropTableQuery) IfExists() *DropTableQuery { + q.ifExists = true + return q +} + +func (q *DropTableQuery) Restrict() *DropTableQuery { + q.restrict = true + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + + b = append(b, "DROP TABLE "...) + if q.ifExists { + b = append(b, "IF EXISTS "...) + } + + b, err = q.appendTables(fmter, b) + if err != nil { + return nil, err + } + + b = q.appendCascade(fmter, b) + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *DropTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + if q.table != nil { + if err := q.beforeDropTableHook(ctx); err != nil { + return nil, err + } + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + if q.table != nil { + if err := q.afterDropTableHook(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *DropTableQuery) beforeDropTableHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeDropTableHook); ok { + if err := hook.BeforeDropTable(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *DropTableQuery) afterDropTableHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterDropTableHook); ok { + if err := hook.AfterDropTable(ctx, q); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/query_table_truncate.go b/vendor/github.com/uptrace/bun/query_table_truncate.go new file mode 100644 index 000000000..1e4bef7f6 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_table_truncate.go @@ -0,0 +1,121 @@ +package bun + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type TruncateTableQuery struct { + baseQuery + cascadeQuery + + continueIdentity bool +} + +func NewTruncateTableQuery(db *DB) *TruncateTableQuery { + q := &TruncateTableQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *TruncateTableQuery) Conn(db IConn) *TruncateTableQuery { + q.setConn(db) + return q +} + +func (q *TruncateTableQuery) Model(model interface{}) *TruncateTableQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *TruncateTableQuery) Table(tables ...string) *TruncateTableQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *TruncateTableQuery) TableExpr(query string, args ...interface{}) *TruncateTableQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +//------------------------------------------------------------------------------ + +func (q *TruncateTableQuery) ContinueIdentity() *TruncateTableQuery { + q.continueIdentity = true + return q +} + +func (q *TruncateTableQuery) Restrict() *TruncateTableQuery { + q.restrict = true + return q +} + +//------------------------------------------------------------------------------ + +func (q *TruncateTableQuery) AppendQuery( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + + if !fmter.HasFeature(feature.TableTruncate) { + b = append(b, "DELETE FROM "...) + + b, err = q.appendTables(fmter, b) + if err != nil { + return nil, err + } + + return b, nil + } + + b = append(b, "TRUNCATE TABLE "...) + + b, err = q.appendTables(fmter, b) + if err != nil { + return nil, err + } + + if q.db.features.Has(feature.TableIdentity) { + if q.continueIdentity { + b = append(b, " CONTINUE IDENTITY"...) + } else { + b = append(b, " RESTART IDENTITY"...) + } + } + + b = q.appendCascade(fmter, b) + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *TruncateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} diff --git a/vendor/github.com/uptrace/bun/query_update.go b/vendor/github.com/uptrace/bun/query_update.go new file mode 100644 index 000000000..ea74e1419 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_update.go @@ -0,0 +1,432 @@ +package bun + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type UpdateQuery struct { + whereBaseQuery + returningQuery + customValueQuery + setQuery + + omitZero bool +} + +func NewUpdateQuery(db *DB) *UpdateQuery { + q := &UpdateQuery{ + whereBaseQuery: whereBaseQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + }, + } + return q +} + +func (q *UpdateQuery) Conn(db IConn) *UpdateQuery { + q.setConn(db) + return q +} + +func (q *UpdateQuery) Model(model interface{}) *UpdateQuery { + q.setTableModel(model) + return q +} + +// Apply calls the fn passing the SelectQuery as an argument. +func (q *UpdateQuery) Apply(fn func(*UpdateQuery) *UpdateQuery) *UpdateQuery { + return fn(q) +} + +func (q *UpdateQuery) With(name string, query schema.QueryAppender) *UpdateQuery { + q.addWith(name, query) + return q +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) Table(tables ...string) *UpdateQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *UpdateQuery) TableExpr(query string, args ...interface{}) *UpdateQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *UpdateQuery) ModelTableExpr(query string, args ...interface{}) *UpdateQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) Column(columns ...string) *UpdateQuery { + for _, column := range columns { + q.addColumn(schema.UnsafeIdent(column)) + } + return q +} + +func (q *UpdateQuery) ExcludeColumn(columns ...string) *UpdateQuery { + q.excludeColumn(columns) + return q +} + +func (q *UpdateQuery) Set(query string, args ...interface{}) *UpdateQuery { + q.addSet(schema.SafeQuery(query, args)) + return q +} + +// Value overwrites model value for the column in INSERT and UPDATE queries. +func (q *UpdateQuery) Value(column string, value string, args ...interface{}) *UpdateQuery { + if q.table == nil { + q.err = errNilModel + return q + } + q.addValue(q.table, column, value, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) WherePK() *UpdateQuery { + q.flags = q.flags.Set(wherePKFlag) + return q +} + +func (q *UpdateQuery) Where(query string, args ...interface{}) *UpdateQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) + return q +} + +func (q *UpdateQuery) WhereOr(query string, args ...interface{}) *UpdateQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) + return q +} + +func (q *UpdateQuery) WhereGroup(sep string, fn func(*UpdateQuery) *UpdateQuery) *UpdateQuery { + saved := q.where + q.where = nil + + q = fn(q) + + where := q.where + q.where = saved + + q.addWhereGroup(sep, where) + + return q +} + +func (q *UpdateQuery) WhereDeleted() *UpdateQuery { + q.whereDeleted() + return q +} + +func (q *UpdateQuery) WhereAllWithDeleted() *UpdateQuery { + q.whereAllWithDeleted() + return q +} + +//------------------------------------------------------------------------------ + +// Returning adds a RETURNING clause to the query. +// +// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`. +func (q *UpdateQuery) Returning(query string, args ...interface{}) *UpdateQuery { + q.addReturning(schema.SafeQuery(query, args)) + return q +} + +func (q *UpdateQuery) hasReturning() bool { + if !q.db.features.Has(feature.Returning) { + return false + } + return q.returningQuery.hasReturning() +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + fmter = formatterWithModel(fmter, q) + + withAlias := fmter.HasFeature(feature.UpdateMultiTable) + + b, err = q.appendWith(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, "UPDATE "...) + + if withAlias { + b, err = q.appendTablesWithAlias(fmter, b) + } else { + b, err = q.appendFirstTableWithAlias(fmter, b) + } + if err != nil { + return nil, err + } + + b, err = q.mustAppendSet(fmter, b) + if err != nil { + return nil, err + } + + if !fmter.HasFeature(feature.UpdateMultiTable) { + b, err = q.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + } + + b, err = q.mustAppendWhere(fmter, b, withAlias) + if err != nil { + return nil, err + } + + if len(q.returning) > 0 { + b, err = q.appendReturning(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *UpdateQuery) mustAppendSet(fmter schema.Formatter, b []byte) (_ []byte, err error) { + b = append(b, " SET "...) + + if len(q.set) > 0 { + return q.appendSet(fmter, b) + } + + if m, ok := q.model.(*mapModel); ok { + return m.appendSet(fmter, b), nil + } + + if q.tableModel == nil { + return nil, errNilModel + } + + switch model := q.tableModel.(type) { + case *structTableModel: + b, err = q.appendSetStruct(fmter, b, model) + if err != nil { + return nil, err + } + case *sliceTableModel: + return nil, errors.New("bun: to bulk Update, use CTE and VALUES") + default: + return nil, fmt.Errorf("bun: Update does not support %T", q.tableModel) + } + + return b, nil +} + +func (q *UpdateQuery) appendSetStruct( + fmter schema.Formatter, b []byte, model *structTableModel, +) ([]byte, error) { + fields, err := q.getDataFields() + if err != nil { + return nil, err + } + + isTemplate := fmter.IsNop() + pos := len(b) + for _, f := range fields { + if q.omitZero && f.NullZero && f.HasZeroValue(model.strct) { + continue + } + + if len(b) != pos { + b = append(b, ", "...) + pos = len(b) + } + + b = append(b, f.SQLName...) + b = append(b, " = "...) + + if isTemplate { + b = append(b, '?') + continue + } + + app, ok := q.modelValues[f.Name] + if ok { + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } else { + b = f.AppendValue(fmter, b, model.strct) + } + } + + for i, v := range q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + + b = append(b, v.column...) + b = append(b, " = "...) + + b, err = v.value.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *UpdateQuery) appendOtherTables(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if !q.hasMultiTables() { + return b, nil + } + + b = append(b, " FROM "...) + + b, err = q.whereBaseQuery.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) Bulk() *UpdateQuery { + model, ok := q.model.(*sliceTableModel) + if !ok { + q.setErr(fmt.Errorf("bun: Bulk requires a slice, got %T", q.model)) + return q + } + + return q.With("_data", q.db.NewValues(model)). + Model(model). + TableExpr("_data"). + Set(q.updateSliceSet(model)). + Where(q.updateSliceWhere(model)) +} + +func (q *UpdateQuery) updateSliceSet(model *sliceTableModel) string { + var b []byte + for i, field := range model.table.DataFields { + if i > 0 { + b = append(b, ", "...) + } + if q.db.fmter.HasFeature(feature.UpdateMultiTable) { + b = append(b, model.table.SQLAlias...) + b = append(b, '.') + } + b = append(b, field.SQLName...) + b = append(b, " = _data."...) + b = append(b, field.SQLName...) + } + return internal.String(b) +} + +func (db *UpdateQuery) updateSliceWhere(model *sliceTableModel) string { + var b []byte + for i, pk := range model.table.PKs { + if i > 0 { + b = append(b, " AND "...) + } + b = append(b, model.table.SQLAlias...) + b = append(b, '.') + b = append(b, pk.SQLName...) + b = append(b, " = _data."...) + b = append(b, pk.SQLName...) + } + return internal.String(b) +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + if q.table != nil { + if err := q.beforeUpdateHook(ctx); err != nil { + return nil, err + } + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + var res sql.Result + + if hasDest := len(dest) > 0; hasDest || q.hasReturning() { + model, err := q.getModel(dest) + if err != nil { + return nil, err + } + + res, err = q.scan(ctx, q, query, model, hasDest) + if err != nil { + return nil, err + } + } else { + res, err = q.exec(ctx, q, query) + if err != nil { + return nil, err + } + } + + if q.table != nil { + if err := q.afterUpdateHook(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *UpdateQuery) beforeUpdateHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeUpdateHook); ok { + if err := hook.BeforeUpdate(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *UpdateQuery) afterUpdateHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterUpdateHook); ok { + if err := hook.AfterUpdate(ctx, q); err != nil { + return err + } + } + return nil +} + +// FQN returns a fully qualified column name. For MySQL, it returns the column name with +// the table alias. For other RDBMS, it returns just the column name. +func (q *UpdateQuery) FQN(name string) Ident { + if q.db.fmter.HasFeature(feature.UpdateMultiTable) { + return Ident(q.table.Alias + "." + name) + } + return Ident(name) +} diff --git a/vendor/github.com/uptrace/bun/query_values.go b/vendor/github.com/uptrace/bun/query_values.go new file mode 100644 index 000000000..323ac68ef --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_values.go @@ -0,0 +1,198 @@ +package bun + +import ( + "fmt" + "reflect" + "strconv" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/schema" +) + +type ValuesQuery struct { + baseQuery + customValueQuery + + withOrder bool +} + +var _ schema.NamedArgAppender = (*ValuesQuery)(nil) + +func NewValuesQuery(db *DB, model interface{}) *ValuesQuery { + q := &ValuesQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + q.setTableModel(model) + return q +} + +func (q *ValuesQuery) Conn(db IConn) *ValuesQuery { + q.setConn(db) + return q +} + +func (q *ValuesQuery) WithOrder() *ValuesQuery { + q.withOrder = true + return q +} + +func (q *ValuesQuery) AppendNamedArg(fmter schema.Formatter, b []byte, name string) ([]byte, bool) { + switch name { + case "Columns": + bb, err := q.AppendColumns(fmter, b) + if err != nil { + q.setErr(err) + return b, true + } + return bb, true + } + return b, false +} + +// AppendColumns appends the table columns. It is used by CTE. +func (q *ValuesQuery) AppendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + if q.model == nil { + return nil, errNilModel + } + + if q.tableModel != nil { + fields, err := q.getFields() + if err != nil { + return nil, err + } + + b = appendColumns(b, "", fields) + + if q.withOrder { + b = append(b, ", _order"...) + } + + return b, nil + } + + switch model := q.model.(type) { + case *mapSliceModel: + return model.appendColumns(fmter, b) + } + + return nil, fmt.Errorf("bun: Values does not support %T", q.model) +} + +func (q *ValuesQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + if q.model == nil { + return nil, errNilModel + } + + fmter = formatterWithModel(fmter, q) + + if q.tableModel != nil { + fields, err := q.getFields() + if err != nil { + return nil, err + } + return q.appendQuery(fmter, b, fields) + } + + switch model := q.model.(type) { + case *mapSliceModel: + return model.appendValues(fmter, b) + } + + return nil, fmt.Errorf("bun: Values does not support %T", q.model) +} + +func (q *ValuesQuery) appendQuery( + fmter schema.Formatter, + b []byte, + fields []*schema.Field, +) (_ []byte, err error) { + b = append(b, "VALUES "...) + if q.db.features.Has(feature.ValuesRow) { + b = append(b, "ROW("...) + } else { + b = append(b, '(') + } + + switch model := q.tableModel.(type) { + case *structTableModel: + b, err = q.appendValues(fmter, b, fields, model.strct) + if err != nil { + return nil, err + } + + if q.withOrder { + b = append(b, ", "...) + b = strconv.AppendInt(b, 0, 10) + } + case *sliceTableModel: + slice := model.slice + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, "), "...) + if q.db.features.Has(feature.ValuesRow) { + b = append(b, "ROW("...) + } else { + b = append(b, '(') + } + } + + b, err = q.appendValues(fmter, b, fields, slice.Index(i)) + if err != nil { + return nil, err + } + + if q.withOrder { + b = append(b, ", "...) + b = strconv.AppendInt(b, int64(i), 10) + } + } + default: + return nil, fmt.Errorf("bun: Values does not support %T", q.model) + } + + b = append(b, ')') + + return b, nil +} + +func (q *ValuesQuery) appendValues( + fmter schema.Formatter, b []byte, fields []*schema.Field, strct reflect.Value, +) (_ []byte, err error) { + isTemplate := fmter.IsNop() + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + app, ok := q.modelValues[f.Name] + if ok { + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + continue + } + + if isTemplate { + b = append(b, '?') + } else { + b = f.AppendValue(fmter, b, indirect(strct)) + } + + if fmter.HasFeature(feature.DoubleColonCast) { + b = append(b, "::"...) + b = append(b, f.UserSQLType...) + } + } + return b, nil +} diff --git a/vendor/github.com/uptrace/bun/schema/append.go b/vendor/github.com/uptrace/bun/schema/append.go new file mode 100644 index 000000000..68f7071c8 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/append.go @@ -0,0 +1,93 @@ +package schema + +import ( + "reflect" + "strconv" + "strings" + "time" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/internal" +) + +func FieldAppender(dialect Dialect, field *Field) AppenderFunc { + if field.Tag.HasOption("msgpack") { + return appendMsgpack + } + + switch strings.ToUpper(field.UserSQLType) { + case sqltype.JSON, sqltype.JSONB: + return AppendJSONValue + } + + return dialect.Appender(field.StructField.Type) +} + +func Append(fmter Formatter, b []byte, v interface{}, custom CustomAppender) []byte { + switch v := v.(type) { + case nil: + return dialect.AppendNull(b) + case bool: + return dialect.AppendBool(b, v) + case int: + return strconv.AppendInt(b, int64(v), 10) + case int32: + return strconv.AppendInt(b, int64(v), 10) + case int64: + return strconv.AppendInt(b, v, 10) + case uint: + return strconv.AppendUint(b, uint64(v), 10) + case uint32: + return strconv.AppendUint(b, uint64(v), 10) + case uint64: + return strconv.AppendUint(b, v, 10) + case float32: + return dialect.AppendFloat32(b, v) + case float64: + return dialect.AppendFloat64(b, v) + case string: + return dialect.AppendString(b, v) + case time.Time: + return dialect.AppendTime(b, v) + case []byte: + return dialect.AppendBytes(b, v) + case QueryAppender: + return AppendQueryAppender(fmter, b, v) + default: + vv := reflect.ValueOf(v) + if vv.Kind() == reflect.Ptr && vv.IsNil() { + return dialect.AppendNull(b) + } + appender := Appender(vv.Type(), custom) + return appender(fmter, b, vv) + } +} + +func appendMsgpack(fmter Formatter, b []byte, v reflect.Value) []byte { + hexEnc := internal.NewHexEncoder(b) + + enc := msgpack.GetEncoder() + defer msgpack.PutEncoder(enc) + + enc.Reset(hexEnc) + if err := enc.EncodeValue(v); err != nil { + return dialect.AppendError(b, err) + } + + if err := hexEnc.Close(); err != nil { + return dialect.AppendError(b, err) + } + + return hexEnc.Bytes() +} + +func AppendQueryAppender(fmter Formatter, b []byte, app QueryAppender) []byte { + bb, err := app.AppendQuery(fmter, b) + if err != nil { + return dialect.AppendError(b, err) + } + return bb +} diff --git a/vendor/github.com/uptrace/bun/schema/append_value.go b/vendor/github.com/uptrace/bun/schema/append_value.go new file mode 100644 index 000000000..0c4677069 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/append_value.go @@ -0,0 +1,237 @@ +package schema + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "net" + "reflect" + "strconv" + "time" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/extra/bunjson" + "github.com/uptrace/bun/internal" +) + +var ( + timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + ipType = reflect.TypeOf((*net.IP)(nil)).Elem() + ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() + jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() + + driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem() +) + +type ( + AppenderFunc func(fmter Formatter, b []byte, v reflect.Value) []byte + CustomAppender func(typ reflect.Type) AppenderFunc +) + +var appenders = []AppenderFunc{ + reflect.Bool: AppendBoolValue, + reflect.Int: AppendIntValue, + reflect.Int8: AppendIntValue, + reflect.Int16: AppendIntValue, + reflect.Int32: AppendIntValue, + reflect.Int64: AppendIntValue, + reflect.Uint: AppendUintValue, + reflect.Uint8: AppendUintValue, + reflect.Uint16: AppendUintValue, + reflect.Uint32: AppendUintValue, + reflect.Uint64: AppendUintValue, + reflect.Uintptr: nil, + reflect.Float32: AppendFloat32Value, + reflect.Float64: AppendFloat64Value, + reflect.Complex64: nil, + reflect.Complex128: nil, + reflect.Array: AppendJSONValue, + reflect.Chan: nil, + reflect.Func: nil, + reflect.Interface: nil, + reflect.Map: AppendJSONValue, + reflect.Ptr: nil, + reflect.Slice: AppendJSONValue, + reflect.String: AppendStringValue, + reflect.Struct: AppendJSONValue, + reflect.UnsafePointer: nil, +} + +func Appender(typ reflect.Type, custom CustomAppender) AppenderFunc { + switch typ { + case timeType: + return appendTimeValue + case ipType: + return appendIPValue + case ipNetType: + return appendIPNetValue + case jsonRawMessageType: + return appendJSONRawMessageValue + } + + if typ.Implements(queryAppenderType) { + return appendQueryAppenderValue + } + if typ.Implements(driverValuerType) { + return driverValueAppender(custom) + } + + kind := typ.Kind() + + if kind != reflect.Ptr { + ptr := reflect.PtrTo(typ) + if ptr.Implements(queryAppenderType) { + return addrAppender(appendQueryAppenderValue, custom) + } + if ptr.Implements(driverValuerType) { + return addrAppender(driverValueAppender(custom), custom) + } + } + + switch kind { + case reflect.Interface: + return ifaceAppenderFunc(typ, custom) + case reflect.Ptr: + return ptrAppenderFunc(typ, custom) + case reflect.Slice: + if typ.Elem().Kind() == reflect.Uint8 { + return appendBytesValue + } + case reflect.Array: + if typ.Elem().Kind() == reflect.Uint8 { + return appendArrayBytesValue + } + } + + if custom != nil { + if fn := custom(typ); fn != nil { + return fn + } + } + return appenders[typ.Kind()] +} + +func ifaceAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc { + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + if v.IsNil() { + return dialect.AppendNull(b) + } + elem := v.Elem() + appender := Appender(elem.Type(), custom) + return appender(fmter, b, elem) + } +} + +func ptrAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc { + appender := Appender(typ.Elem(), custom) + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + if v.IsNil() { + return dialect.AppendNull(b) + } + return appender(fmter, b, v.Elem()) + } +} + +func AppendBoolValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendBool(b, v.Bool()) +} + +func AppendIntValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return strconv.AppendInt(b, v.Int(), 10) +} + +func AppendUintValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return strconv.AppendUint(b, v.Uint(), 10) +} + +func AppendFloat32Value(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendFloat32(b, float32(v.Float())) +} + +func AppendFloat64Value(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendFloat64(b, float64(v.Float())) +} + +func appendBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendBytes(b, v.Bytes()) +} + +func appendArrayBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte { + if v.CanAddr() { + return dialect.AppendBytes(b, v.Slice(0, v.Len()).Bytes()) + } + + tmp := make([]byte, v.Len()) + reflect.Copy(reflect.ValueOf(tmp), v) + b = dialect.AppendBytes(b, tmp) + return b +} + +func AppendStringValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendString(b, v.String()) +} + +func AppendJSONValue(fmter Formatter, b []byte, v reflect.Value) []byte { + bb, err := bunjson.Marshal(v.Interface()) + if err != nil { + return dialect.AppendError(b, err) + } + + if len(bb) > 0 && bb[len(bb)-1] == '\n' { + bb = bb[:len(bb)-1] + } + + return dialect.AppendJSON(b, bb) +} + +func appendTimeValue(fmter Formatter, b []byte, v reflect.Value) []byte { + tm := v.Interface().(time.Time) + return dialect.AppendTime(b, tm) +} + +func appendIPValue(fmter Formatter, b []byte, v reflect.Value) []byte { + ip := v.Interface().(net.IP) + return dialect.AppendString(b, ip.String()) +} + +func appendIPNetValue(fmter Formatter, b []byte, v reflect.Value) []byte { + ipnet := v.Interface().(net.IPNet) + return dialect.AppendString(b, ipnet.String()) +} + +func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byte { + bytes := v.Bytes() + if bytes == nil { + return dialect.AppendNull(b) + } + return dialect.AppendString(b, internal.String(bytes)) +} + +func appendQueryAppenderValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return AppendQueryAppender(fmter, b, v.Interface().(QueryAppender)) +} + +func driverValueAppender(custom CustomAppender) AppenderFunc { + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + return appendDriverValue(fmter, b, v.Interface().(driver.Valuer), custom) + } +} + +func appendDriverValue(fmter Formatter, b []byte, v driver.Valuer, custom CustomAppender) []byte { + value, err := v.Value() + if err != nil { + return dialect.AppendError(b, err) + } + return Append(fmter, b, value, custom) +} + +func addrAppender(fn AppenderFunc, custom CustomAppender) AppenderFunc { + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + if !v.CanAddr() { + err := fmt.Errorf("bun: Append(nonaddressable %T)", v.Interface()) + return dialect.AppendError(b, err) + } + return fn(fmter, b, v.Addr()) + } +} diff --git a/vendor/github.com/uptrace/bun/schema/dialect.go b/vendor/github.com/uptrace/bun/schema/dialect.go new file mode 100644 index 000000000..c50de715a --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/dialect.go @@ -0,0 +1,99 @@ +package schema + +import ( + "database/sql" + "reflect" + "sync" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/feature" +) + +type Dialect interface { + Init(db *sql.DB) + + Name() dialect.Name + Features() feature.Feature + + Tables() *Tables + OnTable(table *Table) + + IdentQuote() byte + Append(fmter Formatter, b []byte, v interface{}) []byte + Appender(typ reflect.Type) AppenderFunc + FieldAppender(field *Field) AppenderFunc + Scanner(typ reflect.Type) ScannerFunc +} + +//------------------------------------------------------------------------------ + +type nopDialect struct { + tables *Tables + features feature.Feature + + appenderMap sync.Map + scannerMap sync.Map +} + +func newNopDialect() *nopDialect { + d := new(nopDialect) + d.tables = NewTables(d) + d.features = feature.Returning + return d +} + +func (d *nopDialect) Init(*sql.DB) {} + +func (d *nopDialect) Name() dialect.Name { + return dialect.Invalid +} + +func (d *nopDialect) Features() feature.Feature { + return d.features +} + +func (d *nopDialect) Tables() *Tables { + return d.tables +} + +func (d *nopDialect) OnField(field *Field) {} + +func (d *nopDialect) OnTable(table *Table) {} + +func (d *nopDialect) IdentQuote() byte { + return '"' +} + +func (d *nopDialect) Append(fmter Formatter, b []byte, v interface{}) []byte { + return Append(fmter, b, v, nil) +} + +func (d *nopDialect) Appender(typ reflect.Type) AppenderFunc { + if v, ok := d.appenderMap.Load(typ); ok { + return v.(AppenderFunc) + } + + fn := Appender(typ, nil) + + if v, ok := d.appenderMap.LoadOrStore(typ, fn); ok { + return v.(AppenderFunc) + } + return fn +} + +func (d *nopDialect) FieldAppender(field *Field) AppenderFunc { + return FieldAppender(d, field) +} + +func (d *nopDialect) Scanner(typ reflect.Type) ScannerFunc { + if v, ok := d.scannerMap.Load(typ); ok { + return v.(ScannerFunc) + } + + fn := Scanner(typ) + + if v, ok := d.scannerMap.LoadOrStore(typ, fn); ok { + return v.(ScannerFunc) + } + return fn +} diff --git a/vendor/github.com/uptrace/bun/schema/field.go b/vendor/github.com/uptrace/bun/schema/field.go new file mode 100644 index 000000000..1e069b82f --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/field.go @@ -0,0 +1,117 @@ +package schema + +import ( + "fmt" + "reflect" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/internal/tagparser" +) + +type Field struct { + StructField reflect.StructField + + Tag tagparser.Tag + IndirectType reflect.Type + Index []int + + Name string // SQL name, .e.g. id + SQLName Safe // escaped SQL name, e.g. "id" + GoName string // struct field name, e.g. Id + + DiscoveredSQLType string + UserSQLType string + CreateTableSQLType string + SQLDefault string + + OnDelete string + OnUpdate string + + IsPK bool + NotNull bool + NullZero bool + AutoIncrement bool + + Append AppenderFunc + Scan ScannerFunc + IsZero IsZeroerFunc +} + +func (f *Field) String() string { + return f.Name +} + +func (f *Field) Clone() *Field { + cp := *f + cp.Index = cp.Index[:len(f.Index):len(f.Index)] + return &cp +} + +func (f *Field) Value(strct reflect.Value) reflect.Value { + return fieldByIndexAlloc(strct, f.Index) +} + +func (f *Field) HasZeroValue(v reflect.Value) bool { + for _, idx := range f.Index { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return true + } + v = v.Elem() + } + v = v.Field(idx) + } + return f.IsZero(v) +} + +func (f *Field) AppendValue(fmter Formatter, b []byte, strct reflect.Value) []byte { + fv, ok := fieldByIndex(strct, f.Index) + if !ok { + return dialect.AppendNull(b) + } + + if f.NullZero && f.IsZero(fv) { + return dialect.AppendNull(b) + } + if f.Append == nil { + panic(fmt.Errorf("bun: AppendValue(unsupported %s)", fv.Type())) + } + return f.Append(fmter, b, fv) +} + +func (f *Field) ScanWithCheck(fv reflect.Value, src interface{}) error { + if f.Scan == nil { + return fmt.Errorf("bun: Scan(unsupported %s)", f.IndirectType) + } + return f.Scan(fv, src) +} + +func (f *Field) ScanValue(strct reflect.Value, src interface{}) error { + if src == nil { + if fv, ok := fieldByIndex(strct, f.Index); ok { + return f.ScanWithCheck(fv, src) + } + return nil + } + + fv := fieldByIndexAlloc(strct, f.Index) + return f.ScanWithCheck(fv, src) +} + +func (f *Field) markAsPK() { + f.IsPK = true + f.NotNull = true + f.NullZero = true +} + +func indexEqual(ind1, ind2 []int) bool { + if len(ind1) != len(ind2) { + return false + } + for i, ind := range ind1 { + if ind != ind2[i] { + return false + } + } + return true +} diff --git a/vendor/github.com/uptrace/bun/schema/formatter.go b/vendor/github.com/uptrace/bun/schema/formatter.go new file mode 100644 index 000000000..7b26fbaca --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/formatter.go @@ -0,0 +1,248 @@ +package schema + +import ( + "reflect" + "strconv" + "strings" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/internal/parser" +) + +var nopFormatter = Formatter{ + dialect: newNopDialect(), +} + +type Formatter struct { + dialect Dialect + args *namedArgList +} + +func NewFormatter(dialect Dialect) Formatter { + return Formatter{ + dialect: dialect, + } +} + +func NewNopFormatter() Formatter { + return nopFormatter +} + +func (f Formatter) IsNop() bool { + return f.dialect.Name() == dialect.Invalid +} + +func (f Formatter) Dialect() Dialect { + return f.dialect +} + +func (f Formatter) IdentQuote() byte { + return f.dialect.IdentQuote() +} + +func (f Formatter) AppendIdent(b []byte, ident string) []byte { + return dialect.AppendIdent(b, ident, f.IdentQuote()) +} + +func (f Formatter) AppendValue(b []byte, v reflect.Value) []byte { + if v.Kind() == reflect.Ptr && v.IsNil() { + return dialect.AppendNull(b) + } + appender := f.dialect.Appender(v.Type()) + return appender(f, b, v) +} + +func (f Formatter) HasFeature(feature feature.Feature) bool { + return f.dialect.Features().Has(feature) +} + +func (f Formatter) WithArg(arg NamedArgAppender) Formatter { + return Formatter{ + dialect: f.dialect, + args: f.args.WithArg(arg), + } +} + +func (f Formatter) WithNamedArg(name string, value interface{}) Formatter { + return Formatter{ + dialect: f.dialect, + args: f.args.WithArg(&namedArg{name: name, value: value}), + } +} + +func (f Formatter) FormatQuery(query string, args ...interface{}) string { + if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 { + return query + } + return internal.String(f.AppendQuery(nil, query, args...)) +} + +func (f Formatter) AppendQuery(dst []byte, query string, args ...interface{}) []byte { + if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 { + return append(dst, query...) + } + return f.append(dst, parser.NewString(query), args) +} + +func (f Formatter) append(dst []byte, p *parser.Parser, args []interface{}) []byte { + var namedArgs NamedArgAppender + if len(args) == 1 { + var ok bool + namedArgs, ok = args[0].(NamedArgAppender) + if !ok { + namedArgs, _ = newStructArgs(f, args[0]) + } + } + + var argIndex int + for p.Valid() { + b, ok := p.ReadSep('?') + if !ok { + dst = append(dst, b...) + continue + } + if len(b) > 0 && b[len(b)-1] == '\\' { + dst = append(dst, b[:len(b)-1]...) + dst = append(dst, '?') + continue + } + dst = append(dst, b...) + + name, numeric := p.ReadIdentifier() + if name != "" { + if numeric { + idx, err := strconv.Atoi(name) + if err != nil { + goto restore_arg + } + + if idx >= len(args) { + goto restore_arg + } + + dst = f.appendArg(dst, args[idx]) + continue + } + + if namedArgs != nil { + dst, ok = namedArgs.AppendNamedArg(f, dst, name) + if ok { + continue + } + } + + dst, ok = f.args.AppendNamedArg(f, dst, name) + if ok { + continue + } + + restore_arg: + dst = append(dst, '?') + dst = append(dst, name...) + continue + } + + if argIndex >= len(args) { + dst = append(dst, '?') + continue + } + + arg := args[argIndex] + argIndex++ + + dst = f.appendArg(dst, arg) + } + + return dst +} + +func (f Formatter) appendArg(b []byte, arg interface{}) []byte { + switch arg := arg.(type) { + case QueryAppender: + bb, err := arg.AppendQuery(f, b) + if err != nil { + return dialect.AppendError(b, err) + } + return bb + default: + return f.dialect.Append(f, b, arg) + } +} + +//------------------------------------------------------------------------------ + +type NamedArgAppender interface { + AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) +} + +//------------------------------------------------------------------------------ + +type namedArgList struct { + arg NamedArgAppender + next *namedArgList +} + +func (l *namedArgList) WithArg(arg NamedArgAppender) *namedArgList { + return &namedArgList{ + arg: arg, + next: l, + } +} + +func (l *namedArgList) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) { + for l != nil && l.arg != nil { + if b, ok := l.arg.AppendNamedArg(fmter, b, name); ok { + return b, true + } + l = l.next + } + return b, false +} + +//------------------------------------------------------------------------------ + +type namedArg struct { + name string + value interface{} +} + +var _ NamedArgAppender = (*namedArg)(nil) + +func (a *namedArg) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) { + if a.name == name { + return fmter.appendArg(b, a.value), true + } + return b, false +} + +//------------------------------------------------------------------------------ + +var _ NamedArgAppender = (*structArgs)(nil) + +type structArgs struct { + table *Table + strct reflect.Value +} + +func newStructArgs(fmter Formatter, strct interface{}) (*structArgs, bool) { + v := reflect.ValueOf(strct) + if !v.IsValid() { + return nil, false + } + + v = reflect.Indirect(v) + if v.Kind() != reflect.Struct { + return nil, false + } + + return &structArgs{ + table: fmter.Dialect().Tables().Get(v.Type()), + strct: v, + }, true +} + +func (m *structArgs) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) { + return m.table.AppendNamedArg(fmter, b, name, m.strct) +} diff --git a/vendor/github.com/uptrace/bun/schema/hook.go b/vendor/github.com/uptrace/bun/schema/hook.go new file mode 100644 index 000000000..5391981d5 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/hook.go @@ -0,0 +1,20 @@ +package schema + +import ( + "context" + "reflect" +) + +type BeforeScanHook interface { + BeforeScan(context.Context) error +} + +var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem() + +//------------------------------------------------------------------------------ + +type AfterScanHook interface { + AfterScan(context.Context) error +} + +var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem() diff --git a/vendor/github.com/uptrace/bun/schema/relation.go b/vendor/github.com/uptrace/bun/schema/relation.go new file mode 100644 index 000000000..8d1baeb3f --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/relation.go @@ -0,0 +1,32 @@ +package schema + +import ( + "fmt" +) + +const ( + InvalidRelation = iota + HasOneRelation + BelongsToRelation + HasManyRelation + ManyToManyRelation +) + +type Relation struct { + Type int + Field *Field + JoinTable *Table + BaseFields []*Field + JoinFields []*Field + + PolymorphicField *Field + PolymorphicValue string + + M2MTable *Table + M2MBaseFields []*Field + M2MJoinFields []*Field +} + +func (r *Relation) String() string { + return fmt.Sprintf("relation=%s", r.Field.GoName) +} diff --git a/vendor/github.com/uptrace/bun/schema/scan.go b/vendor/github.com/uptrace/bun/schema/scan.go new file mode 100644 index 000000000..0e66a860f --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/scan.go @@ -0,0 +1,392 @@ +package schema + +import ( + "bytes" + "database/sql" + "fmt" + "net" + "reflect" + "strconv" + "time" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/uptrace/bun/extra/bunjson" + "github.com/uptrace/bun/internal" +) + +var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +type ScannerFunc func(dest reflect.Value, src interface{}) error + +var scanners = []ScannerFunc{ + reflect.Bool: scanBool, + reflect.Int: scanInt64, + reflect.Int8: scanInt64, + reflect.Int16: scanInt64, + reflect.Int32: scanInt64, + reflect.Int64: scanInt64, + reflect.Uint: scanUint64, + reflect.Uint8: scanUint64, + reflect.Uint16: scanUint64, + reflect.Uint32: scanUint64, + reflect.Uint64: scanUint64, + reflect.Uintptr: scanUint64, + reflect.Float32: scanFloat64, + reflect.Float64: scanFloat64, + reflect.Complex64: nil, + reflect.Complex128: nil, + reflect.Array: nil, + reflect.Chan: nil, + reflect.Func: nil, + reflect.Map: scanJSON, + reflect.Ptr: nil, + reflect.Slice: scanJSON, + reflect.String: scanString, + reflect.Struct: scanJSON, + reflect.UnsafePointer: nil, +} + +func FieldScanner(dialect Dialect, field *Field) ScannerFunc { + if field.Tag.HasOption("msgpack") { + return scanMsgpack + } + if field.Tag.HasOption("json_use_number") { + return scanJSONUseNumber + } + return dialect.Scanner(field.StructField.Type) +} + +func Scanner(typ reflect.Type) ScannerFunc { + kind := typ.Kind() + + if kind == reflect.Ptr { + if fn := Scanner(typ.Elem()); fn != nil { + return ptrScanner(fn) + } + } + + if typ.Implements(scannerType) { + return scanScanner + } + + if kind != reflect.Ptr { + ptr := reflect.PtrTo(typ) + if ptr.Implements(scannerType) { + return addrScanner(scanScanner) + } + } + + switch typ { + case timeType: + return scanTime + case ipType: + return scanIP + case ipNetType: + return scanIPNet + case jsonRawMessageType: + return scanJSONRawMessage + } + + return scanners[kind] +} + +func scanBool(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetBool(false) + return nil + case bool: + dest.SetBool(src) + return nil + case int64: + dest.SetBool(src != 0) + return nil + case []byte: + if len(src) == 1 { + dest.SetBool(src[0] != '0') + return nil + } + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanInt64(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetInt(0) + return nil + case int64: + dest.SetInt(src) + return nil + case uint64: + dest.SetInt(int64(src)) + return nil + case []byte: + n, err := strconv.ParseInt(internal.String(src), 10, 64) + if err != nil { + return err + } + dest.SetInt(n) + return nil + case string: + n, err := strconv.ParseInt(src, 10, 64) + if err != nil { + return err + } + dest.SetInt(n) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanUint64(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetUint(0) + return nil + case uint64: + dest.SetUint(src) + return nil + case int64: + dest.SetUint(uint64(src)) + return nil + case []byte: + n, err := strconv.ParseUint(internal.String(src), 10, 64) + if err != nil { + return err + } + dest.SetUint(n) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanFloat64(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetFloat(0) + return nil + case float64: + dest.SetFloat(src) + return nil + case []byte: + f, err := strconv.ParseFloat(internal.String(src), 64) + if err != nil { + return err + } + dest.SetFloat(f) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanString(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetString("") + return nil + case string: + dest.SetString(src) + return nil + case []byte: + dest.SetString(string(src)) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanTime(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + destTime := dest.Addr().Interface().(*time.Time) + *destTime = time.Time{} + return nil + case time.Time: + destTime := dest.Addr().Interface().(*time.Time) + *destTime = src + return nil + case string: + srcTime, err := internal.ParseTime(src) + if err != nil { + return err + } + destTime := dest.Addr().Interface().(*time.Time) + *destTime = srcTime + return nil + case []byte: + srcTime, err := internal.ParseTime(internal.String(src)) + if err != nil { + return err + } + destTime := dest.Addr().Interface().(*time.Time) + *destTime = srcTime + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanScanner(dest reflect.Value, src interface{}) error { + return dest.Interface().(sql.Scanner).Scan(src) +} + +func scanMsgpack(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + dec := msgpack.GetDecoder() + defer msgpack.PutDecoder(dec) + + dec.Reset(bytes.NewReader(b)) + return dec.DecodeValue(dest) +} + +func scanJSON(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + return bunjson.Unmarshal(b, dest.Addr().Interface()) +} + +func scanJSONUseNumber(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + dec := bunjson.NewDecoder(bytes.NewReader(b)) + dec.UseNumber() + return dec.Decode(dest.Addr().Interface()) +} + +func scanIP(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + ip := net.ParseIP(internal.String(b)) + if ip == nil { + return fmt.Errorf("bun: invalid ip: %q", b) + } + + ptr := dest.Addr().Interface().(*net.IP) + *ptr = ip + + return nil +} + +func scanIPNet(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + _, ipnet, err := net.ParseCIDR(internal.String(b)) + if err != nil { + return err + } + + ptr := dest.Addr().Interface().(*net.IPNet) + *ptr = *ipnet + + return nil +} + +func scanJSONRawMessage(dest reflect.Value, src interface{}) error { + if src == nil { + dest.SetBytes(nil) + return nil + } + + b, err := toBytes(src) + if err != nil { + return err + } + + dest.SetBytes(b) + return nil +} + +func addrScanner(fn ScannerFunc) ScannerFunc { + return func(dest reflect.Value, src interface{}) error { + if !dest.CanAddr() { + return fmt.Errorf("bun: Scan(nonaddressable %T)", dest.Interface()) + } + return fn(dest.Addr(), src) + } +} + +func toBytes(src interface{}) ([]byte, error) { + switch src := src.(type) { + case string: + return internal.Bytes(src), nil + case []byte: + return src, nil + default: + return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src) + } +} + +func ptrScanner(fn ScannerFunc) ScannerFunc { + return func(dest reflect.Value, src interface{}) error { + if src == nil { + if !dest.CanAddr() { + if dest.IsNil() { + return nil + } + return fn(dest.Elem(), src) + } + + if !dest.IsNil() { + dest.Set(reflect.New(dest.Type().Elem())) + } + return nil + } + + if dest.IsNil() { + dest.Set(reflect.New(dest.Type().Elem())) + } + return fn(dest.Elem(), src) + } +} + +func scanNull(dest reflect.Value) error { + if nilable(dest.Kind()) && dest.IsNil() { + return nil + } + dest.Set(reflect.New(dest.Type()).Elem()) + return nil +} + +func nilable(kind reflect.Kind) bool { + switch kind { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return true + } + return false +} diff --git a/vendor/github.com/uptrace/bun/schema/sqlfmt.go b/vendor/github.com/uptrace/bun/schema/sqlfmt.go new file mode 100644 index 000000000..7b538cd0c --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/sqlfmt.go @@ -0,0 +1,76 @@ +package schema + +type QueryAppender interface { + AppendQuery(fmter Formatter, b []byte) ([]byte, error) +} + +type ColumnsAppender interface { + AppendColumns(fmter Formatter, b []byte) ([]byte, error) +} + +//------------------------------------------------------------------------------ + +// Safe represents a safe SQL query. +type Safe string + +var _ QueryAppender = (*Safe)(nil) + +func (s Safe) AppendQuery(fmter Formatter, b []byte) ([]byte, error) { + return append(b, s...), nil +} + +//------------------------------------------------------------------------------ + +// Ident represents a SQL identifier, for example, table or column name. +type Ident string + +var _ QueryAppender = (*Ident)(nil) + +func (s Ident) AppendQuery(fmter Formatter, b []byte) ([]byte, error) { + return fmter.AppendIdent(b, string(s)), nil +} + +//------------------------------------------------------------------------------ + +type QueryWithArgs struct { + Query string + Args []interface{} +} + +var _ QueryAppender = QueryWithArgs{} + +func SafeQuery(query string, args []interface{}) QueryWithArgs { + if query != "" && args == nil { + args = make([]interface{}, 0) + } + return QueryWithArgs{Query: query, Args: args} +} + +func UnsafeIdent(ident string) QueryWithArgs { + return QueryWithArgs{Query: ident} +} + +func (q QueryWithArgs) IsZero() bool { + return q.Query == "" && q.Args == nil +} + +func (q QueryWithArgs) AppendQuery(fmter Formatter, b []byte) ([]byte, error) { + if q.Args == nil { + return fmter.AppendIdent(b, q.Query), nil + } + return fmter.AppendQuery(b, q.Query, q.Args...), nil +} + +//------------------------------------------------------------------------------ + +type QueryWithSep struct { + QueryWithArgs + Sep string +} + +func SafeQueryWithSep(query string, args []interface{}, sep string) QueryWithSep { + return QueryWithSep{ + QueryWithArgs: SafeQuery(query, args), + Sep: sep, + } +} diff --git a/vendor/github.com/uptrace/bun/schema/sqltype.go b/vendor/github.com/uptrace/bun/schema/sqltype.go new file mode 100644 index 000000000..560f695c2 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/sqltype.go @@ -0,0 +1,129 @@ +package schema + +import ( + "bytes" + "database/sql" + "encoding/json" + "fmt" + "reflect" + "time" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/internal" +) + +var ( + bunNullTimeType = reflect.TypeOf((*NullTime)(nil)).Elem() + nullTimeType = reflect.TypeOf((*sql.NullTime)(nil)).Elem() + nullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem() + nullFloatType = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem() + nullIntType = reflect.TypeOf((*sql.NullInt64)(nil)).Elem() + nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem() +) + +var sqlTypes = []string{ + reflect.Bool: sqltype.Boolean, + reflect.Int: sqltype.BigInt, + reflect.Int8: sqltype.SmallInt, + reflect.Int16: sqltype.SmallInt, + reflect.Int32: sqltype.Integer, + reflect.Int64: sqltype.BigInt, + reflect.Uint: sqltype.BigInt, + reflect.Uint8: sqltype.SmallInt, + reflect.Uint16: sqltype.SmallInt, + reflect.Uint32: sqltype.Integer, + reflect.Uint64: sqltype.BigInt, + reflect.Uintptr: sqltype.BigInt, + reflect.Float32: sqltype.Real, + reflect.Float64: sqltype.DoublePrecision, + reflect.Complex64: "", + reflect.Complex128: "", + reflect.Array: "", + reflect.Chan: "", + reflect.Func: "", + reflect.Interface: "", + reflect.Map: sqltype.VarChar, + reflect.Ptr: "", + reflect.Slice: sqltype.VarChar, + reflect.String: sqltype.VarChar, + reflect.Struct: sqltype.VarChar, + reflect.UnsafePointer: "", +} + +func DiscoverSQLType(typ reflect.Type) string { + switch typ { + case timeType, nullTimeType, bunNullTimeType: + return sqltype.Timestamp + case nullBoolType: + return sqltype.Boolean + case nullFloatType: + return sqltype.DoublePrecision + case nullIntType: + return sqltype.BigInt + case nullStringType: + return sqltype.VarChar + } + return sqlTypes[typ.Kind()] +} + +//------------------------------------------------------------------------------ + +var jsonNull = []byte("null") + +// NullTime is a time.Time wrapper that marshals zero time as JSON null and SQL NULL. +type NullTime struct { + time.Time +} + +var ( + _ json.Marshaler = (*NullTime)(nil) + _ json.Unmarshaler = (*NullTime)(nil) + _ sql.Scanner = (*NullTime)(nil) + _ QueryAppender = (*NullTime)(nil) +) + +func (tm NullTime) MarshalJSON() ([]byte, error) { + if tm.IsZero() { + return jsonNull, nil + } + return tm.Time.MarshalJSON() +} + +func (tm *NullTime) UnmarshalJSON(b []byte) error { + if bytes.Equal(b, jsonNull) { + tm.Time = time.Time{} + return nil + } + return tm.Time.UnmarshalJSON(b) +} + +func (tm NullTime) AppendQuery(fmter Formatter, b []byte) ([]byte, error) { + if tm.IsZero() { + return dialect.AppendNull(b), nil + } + return dialect.AppendTime(b, tm.Time), nil +} + +func (tm *NullTime) Scan(src interface{}) error { + if src == nil { + tm.Time = time.Time{} + return nil + } + + switch src := src.(type) { + case []byte: + newtm, err := internal.ParseTime(internal.String(src)) + if err != nil { + return err + } + + tm.Time = newtm + return nil + case time.Time: + tm.Time = src + return nil + default: + return fmt.Errorf("bun: can't scan %#v into NullTime", src) + } +} diff --git a/vendor/github.com/uptrace/bun/schema/table.go b/vendor/github.com/uptrace/bun/schema/table.go new file mode 100644 index 000000000..eca18b781 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/table.go @@ -0,0 +1,948 @@ +package schema + +import ( + "database/sql" + "fmt" + "reflect" + "strings" + "sync" + "time" + + "github.com/jinzhu/inflection" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/internal/tagparser" +) + +const ( + beforeScanHookFlag internal.Flag = 1 << iota + afterScanHookFlag +) + +var ( + baseModelType = reflect.TypeOf((*BaseModel)(nil)).Elem() + tableNameInflector = inflection.Plural +) + +type BaseModel struct{} + +// SetTableNameInflector overrides the default func that pluralizes +// model name to get table name, e.g. my_article becomes my_articles. +func SetTableNameInflector(fn func(string) string) { + tableNameInflector = fn +} + +// Table represents a SQL table created from Go struct. +type Table struct { + dialect Dialect + + Type reflect.Type + ZeroValue reflect.Value // reflect.Struct + ZeroIface interface{} // struct pointer + + TypeName string + ModelName string + + Name string + SQLName Safe + SQLNameForSelects Safe + Alias string + SQLAlias Safe + + Fields []*Field // PKs + DataFields + PKs []*Field + DataFields []*Field + + fieldsMapMu sync.RWMutex + FieldMap map[string]*Field + + Relations map[string]*Relation + Unique map[string][]*Field + + SoftDeleteField *Field + UpdateSoftDeleteField func(fv reflect.Value) error + + allFields []*Field // read only + skippedFields []*Field + + flags internal.Flag +} + +func newTable(dialect Dialect, typ reflect.Type) *Table { + t := new(Table) + t.dialect = dialect + t.Type = typ + t.ZeroValue = reflect.New(t.Type).Elem() + t.ZeroIface = reflect.New(t.Type).Interface() + t.TypeName = internal.ToExported(t.Type.Name()) + t.ModelName = internal.Underscore(t.Type.Name()) + tableName := tableNameInflector(t.ModelName) + t.setName(tableName) + t.Alias = t.ModelName + t.SQLAlias = t.quoteIdent(t.ModelName) + + hooks := []struct { + typ reflect.Type + flag internal.Flag + }{ + {beforeScanHookType, beforeScanHookFlag}, + {afterScanHookType, afterScanHookFlag}, + } + + typ = reflect.PtrTo(t.Type) + for _, hook := range hooks { + if typ.Implements(hook.typ) { + t.flags = t.flags.Set(hook.flag) + } + } + + return t +} + +func (t *Table) init1() { + t.initFields() +} + +func (t *Table) init2() { + t.initInlines() + t.initRelations() + t.skippedFields = nil +} + +func (t *Table) setName(name string) { + t.Name = name + t.SQLName = t.quoteIdent(name) + t.SQLNameForSelects = t.quoteIdent(name) + if t.SQLAlias == "" { + t.Alias = name + t.SQLAlias = t.quoteIdent(name) + } +} + +func (t *Table) String() string { + return "model=" + t.TypeName +} + +func (t *Table) CheckPKs() error { + if len(t.PKs) == 0 { + return fmt.Errorf("bun: %s does not have primary keys", t) + } + return nil +} + +func (t *Table) addField(field *Field) { + t.Fields = append(t.Fields, field) + if field.IsPK { + t.PKs = append(t.PKs, field) + } else { + t.DataFields = append(t.DataFields, field) + } + t.FieldMap[field.Name] = field +} + +func (t *Table) removeField(field *Field) { + t.Fields = removeField(t.Fields, field) + if field.IsPK { + t.PKs = removeField(t.PKs, field) + } else { + t.DataFields = removeField(t.DataFields, field) + } + delete(t.FieldMap, field.Name) +} + +func (t *Table) fieldWithLock(name string) *Field { + t.fieldsMapMu.RLock() + field := t.FieldMap[name] + t.fieldsMapMu.RUnlock() + return field +} + +func (t *Table) HasField(name string) bool { + _, ok := t.FieldMap[name] + return ok +} + +func (t *Table) Field(name string) (*Field, error) { + field, ok := t.FieldMap[name] + if !ok { + return nil, fmt.Errorf("bun: %s does not have column=%s", t, name) + } + return field, nil +} + +func (t *Table) fieldByGoName(name string) *Field { + for _, f := range t.allFields { + if f.GoName == name { + return f + } + } + return nil +} + +func (t *Table) initFields() { + t.Fields = make([]*Field, 0, t.Type.NumField()) + t.FieldMap = make(map[string]*Field, t.Type.NumField()) + t.addFields(t.Type, nil) + + if len(t.PKs) > 0 { + return + } + for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} { + if field, ok := t.FieldMap[name]; ok { + field.markAsPK() + t.PKs = []*Field{field} + t.DataFields = removeField(t.DataFields, field) + break + } + } + if len(t.PKs) == 1 { + switch t.PKs[0].IndirectType.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + t.PKs[0].AutoIncrement = true + } + } +} + +func (t *Table) addFields(typ reflect.Type, baseIndex []int) { + for i := 0; i < typ.NumField(); i++ { + f := typ.Field(i) + + // Make a copy so slice is not shared between fields. + index := make([]int, len(baseIndex)) + copy(index, baseIndex) + + if f.Anonymous { + if f.Tag.Get("bun") == "-" { + continue + } + if f.Name == "BaseModel" && f.Type == baseModelType { + if len(index) == 0 { + t.processBaseModelField(f) + } + continue + } + + fieldType := indirectType(f.Type) + if fieldType.Kind() != reflect.Struct { + continue + } + t.addFields(fieldType, append(index, f.Index...)) + + tag := tagparser.Parse(f.Tag.Get("bun")) + if _, inherit := tag.Options["inherit"]; inherit { + embeddedTable := t.dialect.Tables().Ref(fieldType) + t.TypeName = embeddedTable.TypeName + t.SQLName = embeddedTable.SQLName + t.SQLNameForSelects = embeddedTable.SQLNameForSelects + t.Alias = embeddedTable.Alias + t.SQLAlias = embeddedTable.SQLAlias + t.ModelName = embeddedTable.ModelName + } + + continue + } + + field := t.newField(f, index) + if field != nil { + t.addField(field) + } + } +} + +func (t *Table) processBaseModelField(f reflect.StructField) { + tag := tagparser.Parse(f.Tag.Get("bun")) + + if isKnownTableOption(tag.Name) { + internal.Warn.Printf( + "%s.%s tag name %q is also an option name; is it a mistake?", + t.TypeName, f.Name, tag.Name, + ) + } + + for name := range tag.Options { + if !isKnownTableOption(name) { + internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) + } + } + + if tag.Name != "" { + t.setName(tag.Name) + } + + if s, ok := tag.Options["select"]; ok { + t.SQLNameForSelects = t.quoteTableName(s) + } + + if s, ok := tag.Options["alias"]; ok { + t.Alias = s + t.SQLAlias = t.quoteIdent(s) + } +} + +//nolint +func (t *Table) newField(f reflect.StructField, index []int) *Field { + tag := tagparser.Parse(f.Tag.Get("bun")) + + if f.PkgPath != "" { + return nil + } + + sqlName := internal.Underscore(f.Name) + + if tag.Name != sqlName && isKnownFieldOption(tag.Name) { + internal.Warn.Printf( + "%s.%s tag name %q is also an option name; is it a mistake?", + t.TypeName, f.Name, tag.Name, + ) + } + + for name := range tag.Options { + if !isKnownFieldOption(name) { + internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) + } + } + + skip := tag.Name == "-" + if !skip && tag.Name != "" { + sqlName = tag.Name + } + + index = append(index, f.Index...) + if field := t.fieldWithLock(sqlName); field != nil { + if indexEqual(field.Index, index) { + return field + } + t.removeField(field) + } + + field := &Field{ + StructField: f, + + Tag: tag, + IndirectType: indirectType(f.Type), + Index: index, + + Name: sqlName, + GoName: f.Name, + SQLName: t.quoteIdent(sqlName), + } + + field.NotNull = tag.HasOption("notnull") + field.NullZero = tag.HasOption("nullzero") + field.AutoIncrement = tag.HasOption("autoincrement") + if tag.HasOption("pk") { + field.markAsPK() + } + if tag.HasOption("allowzero") { + if tag.HasOption("nullzero") { + internal.Warn.Printf( + "%s.%s: nullzero and allowzero options are mutually exclusive", + t.TypeName, f.Name, + ) + } + field.NullZero = false + } + + if v, ok := tag.Options["unique"]; ok { + // Split the value by comma, this will allow multiple names to be specified. + // We can use this to create multiple named unique constraints where a single column + // might be included in multiple constraints. + for _, uniqueName := range strings.Split(v, ",") { + if t.Unique == nil { + t.Unique = make(map[string][]*Field) + } + t.Unique[uniqueName] = append(t.Unique[uniqueName], field) + } + } + if s, ok := tag.Options["default"]; ok { + field.SQLDefault = s + } + if s, ok := field.Tag.Options["type"]; ok { + field.UserSQLType = s + } + field.DiscoveredSQLType = DiscoverSQLType(field.IndirectType) + field.Append = t.dialect.FieldAppender(field) + field.Scan = FieldScanner(t.dialect, field) + field.IsZero = FieldZeroChecker(field) + + if v, ok := tag.Options["alt"]; ok { + t.FieldMap[v] = field + } + + t.allFields = append(t.allFields, field) + if skip { + t.skippedFields = append(t.skippedFields, field) + t.FieldMap[field.Name] = field + return nil + } + + if _, ok := tag.Options["soft_delete"]; ok { + field.NullZero = true + t.SoftDeleteField = field + t.UpdateSoftDeleteField = softDeleteFieldUpdater(field) + } + + return field +} + +func (t *Table) initInlines() { + for _, f := range t.skippedFields { + if f.IndirectType.Kind() == reflect.Struct { + t.inlineFields(f, nil) + } + } +} + +//--------------------------------------------------------------------------------------- + +func (t *Table) initRelations() { + for i := 0; i < len(t.Fields); { + f := t.Fields[i] + if t.tryRelation(f) { + t.Fields = removeField(t.Fields, f) + t.DataFields = removeField(t.DataFields, f) + } else { + i++ + } + + if f.IndirectType.Kind() == reflect.Struct { + t.inlineFields(f, nil) + } + } +} + +func (t *Table) tryRelation(field *Field) bool { + if rel, ok := field.Tag.Options["rel"]; ok { + t.initRelation(field, rel) + return true + } + if field.Tag.HasOption("m2m") { + t.addRelation(t.m2mRelation(field)) + return true + } + + if field.Tag.HasOption("join") { + internal.Warn.Printf( + `%s.%s option "join" requires a relation type`, + t.TypeName, field.GoName, + ) + } + + return false +} + +func (t *Table) initRelation(field *Field, rel string) { + switch rel { + case "belongs-to": + t.addRelation(t.belongsToRelation(field)) + case "has-one": + t.addRelation(t.hasOneRelation(field)) + case "has-many": + t.addRelation(t.hasManyRelation(field)) + default: + panic(fmt.Errorf("bun: unknown relation=%s on field=%s", rel, field.GoName)) + } +} + +func (t *Table) addRelation(rel *Relation) { + if t.Relations == nil { + t.Relations = make(map[string]*Relation) + } + _, ok := t.Relations[rel.Field.GoName] + if ok { + panic(fmt.Errorf("%s already has %s", t, rel)) + } + t.Relations[rel.Field.GoName] = rel +} + +func (t *Table) belongsToRelation(field *Field) *Relation { + joinTable := t.dialect.Tables().Ref(field.IndirectType) + if err := joinTable.CheckPKs(); err != nil { + panic(err) + } + + rel := &Relation{ + Type: HasOneRelation, + Field: field, + JoinTable: joinTable, + } + + if join, ok := field.Tag.Options["join"]; ok { + baseColumns, joinColumns := parseRelationJoin(join) + for i, baseColumn := range baseColumns { + joinColumn := joinColumns[i] + + if f := t.fieldWithLock(baseColumn); f != nil { + rel.BaseFields = append(rel.BaseFields, f) + } else { + panic(fmt.Errorf( + "bun: %s belongs-to %s: %s must have column %s", + t.TypeName, field.GoName, t.TypeName, baseColumn, + )) + } + + if f := joinTable.fieldWithLock(joinColumn); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + } else { + panic(fmt.Errorf( + "bun: %s belongs-to %s: %s must have column %s", + t.TypeName, field.GoName, t.TypeName, baseColumn, + )) + } + } + return rel + } + + rel.JoinFields = joinTable.PKs + fkPrefix := internal.Underscore(field.GoName) + "_" + for _, joinPK := range joinTable.PKs { + fkName := fkPrefix + joinPK.Name + if fk := t.fieldWithLock(fkName); fk != nil { + rel.BaseFields = append(rel.BaseFields, fk) + continue + } + + if fk := t.fieldWithLock(joinPK.Name); fk != nil { + rel.BaseFields = append(rel.BaseFields, fk) + continue + } + + panic(fmt.Errorf( + "bun: %s belongs-to %s: %s must have column %s "+ + "(to override, use join:base_column=join_column tag on %s field)", + t.TypeName, field.GoName, t.TypeName, fkName, field.GoName, + )) + } + return rel +} + +func (t *Table) hasOneRelation(field *Field) *Relation { + if err := t.CheckPKs(); err != nil { + panic(err) + } + + joinTable := t.dialect.Tables().Ref(field.IndirectType) + rel := &Relation{ + Type: BelongsToRelation, + Field: field, + JoinTable: joinTable, + } + + if join, ok := field.Tag.Options["join"]; ok { + baseColumns, joinColumns := parseRelationJoin(join) + for i, baseColumn := range baseColumns { + if f := t.fieldWithLock(baseColumn); f != nil { + rel.BaseFields = append(rel.BaseFields, f) + } else { + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s", + field.GoName, t.TypeName, joinTable.TypeName, baseColumn, + )) + } + + joinColumn := joinColumns[i] + if f := joinTable.fieldWithLock(joinColumn); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + } else { + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s", + field.GoName, t.TypeName, joinTable.TypeName, baseColumn, + )) + } + } + return rel + } + + rel.BaseFields = t.PKs + fkPrefix := internal.Underscore(t.ModelName) + "_" + for _, pk := range t.PKs { + fkName := fkPrefix + pk.Name + if f := joinTable.fieldWithLock(fkName); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + continue + } + + if f := joinTable.fieldWithLock(pk.Name); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + continue + } + + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s "+ + "(to override, use join:base_column=join_column tag on %s field)", + field.GoName, t.TypeName, joinTable.TypeName, fkName, field.GoName, + )) + } + return rel +} + +func (t *Table) hasManyRelation(field *Field) *Relation { + if err := t.CheckPKs(); err != nil { + panic(err) + } + if field.IndirectType.Kind() != reflect.Slice { + panic(fmt.Errorf( + "bun: %s.%s has-many relation requires slice, got %q", + t.TypeName, field.GoName, field.IndirectType.Kind(), + )) + } + + joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem())) + polymorphicValue, isPolymorphic := field.Tag.Options["polymorphic"] + rel := &Relation{ + Type: HasManyRelation, + Field: field, + JoinTable: joinTable, + } + var polymorphicColumn string + + if join, ok := field.Tag.Options["join"]; ok { + baseColumns, joinColumns := parseRelationJoin(join) + for i, baseColumn := range baseColumns { + joinColumn := joinColumns[i] + + if isPolymorphic && baseColumn == "type" { + polymorphicColumn = joinColumn + continue + } + + if f := t.fieldWithLock(baseColumn); f != nil { + rel.BaseFields = append(rel.BaseFields, f) + } else { + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s", + t.TypeName, field.GoName, t.TypeName, baseColumn, + )) + } + + if f := joinTable.fieldWithLock(joinColumn); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + } else { + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s", + t.TypeName, field.GoName, t.TypeName, baseColumn, + )) + } + } + } else { + rel.BaseFields = t.PKs + fkPrefix := internal.Underscore(t.ModelName) + "_" + if isPolymorphic { + polymorphicColumn = fkPrefix + "type" + } + + for _, pk := range t.PKs { + joinColumn := fkPrefix + pk.Name + if fk := joinTable.fieldWithLock(joinColumn); fk != nil { + rel.JoinFields = append(rel.JoinFields, fk) + continue + } + + if fk := joinTable.fieldWithLock(pk.Name); fk != nil { + rel.JoinFields = append(rel.JoinFields, fk) + continue + } + + panic(fmt.Errorf( + "bun: %s has-many %s: %s must have column %s "+ + "(to override, use join:base_column=join_column tag on the field %s)", + t.TypeName, field.GoName, joinTable.TypeName, joinColumn, field.GoName, + )) + } + } + + if isPolymorphic { + rel.PolymorphicField = joinTable.fieldWithLock(polymorphicColumn) + if rel.PolymorphicField == nil { + panic(fmt.Errorf( + "bun: %s has-many %s: %s must have polymorphic column %s", + t.TypeName, field.GoName, joinTable.TypeName, polymorphicColumn, + )) + } + + if polymorphicValue == "" { + polymorphicValue = t.ModelName + } + rel.PolymorphicValue = polymorphicValue + } + + return rel +} + +func (t *Table) m2mRelation(field *Field) *Relation { + if field.IndirectType.Kind() != reflect.Slice { + panic(fmt.Errorf( + "bun: %s.%s m2m relation requires slice, got %q", + t.TypeName, field.GoName, field.IndirectType.Kind(), + )) + } + joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem())) + + if err := t.CheckPKs(); err != nil { + panic(err) + } + if err := joinTable.CheckPKs(); err != nil { + panic(err) + } + + m2mTableName, ok := field.Tag.Options["m2m"] + if !ok { + panic(fmt.Errorf("bun: %s must have m2m tag option", field.GoName)) + } + + m2mTable := t.dialect.Tables().ByName(m2mTableName) + if m2mTable == nil { + panic(fmt.Errorf( + "bun: can't find m2m %s table (use db.RegisterModel)", + m2mTableName, + )) + } + + rel := &Relation{ + Type: ManyToManyRelation, + Field: field, + JoinTable: joinTable, + M2MTable: m2mTable, + } + var leftColumn, rightColumn string + + if join, ok := field.Tag.Options["join"]; ok { + left, right := parseRelationJoin(join) + leftColumn = left[0] + rightColumn = right[0] + } else { + leftColumn = t.TypeName + rightColumn = joinTable.TypeName + } + + leftField := m2mTable.fieldByGoName(leftColumn) + if leftField == nil { + panic(fmt.Errorf( + "bun: %s many-to-many %s: %s must have field %s "+ + "(to override, use tag join:LeftField=RightField on field %s.%s", + t.TypeName, field.GoName, m2mTable.TypeName, leftColumn, t.TypeName, field.GoName, + )) + } + + rightField := m2mTable.fieldByGoName(rightColumn) + if rightField == nil { + panic(fmt.Errorf( + "bun: %s many-to-many %s: %s must have field %s "+ + "(to override, use tag join:LeftField=RightField on field %s.%s", + t.TypeName, field.GoName, m2mTable.TypeName, rightColumn, t.TypeName, field.GoName, + )) + } + + leftRel := m2mTable.belongsToRelation(leftField) + rel.BaseFields = leftRel.JoinFields + rel.M2MBaseFields = leftRel.BaseFields + + rightRel := m2mTable.belongsToRelation(rightField) + rel.JoinFields = rightRel.JoinFields + rel.M2MJoinFields = rightRel.BaseFields + + return rel +} + +func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) { + if path == nil { + path = map[reflect.Type]struct{}{ + t.Type: {}, + } + } + + if _, ok := path[field.IndirectType]; ok { + return + } + path[field.IndirectType] = struct{}{} + + joinTable := t.dialect.Tables().Ref(field.IndirectType) + for _, f := range joinTable.allFields { + f = f.Clone() + f.GoName = field.GoName + "_" + f.GoName + f.Name = field.Name + "__" + f.Name + f.SQLName = t.quoteIdent(f.Name) + f.Index = appendNew(field.Index, f.Index...) + + t.fieldsMapMu.Lock() + if _, ok := t.FieldMap[f.Name]; !ok { + t.FieldMap[f.Name] = f + } + t.fieldsMapMu.Unlock() + + if f.IndirectType.Kind() != reflect.Struct { + continue + } + + if _, ok := path[f.IndirectType]; !ok { + t.inlineFields(f, path) + } + } +} + +//------------------------------------------------------------------------------ + +func (t *Table) Dialect() Dialect { return t.dialect } + +//------------------------------------------------------------------------------ + +func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) } +func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) } + +//------------------------------------------------------------------------------ + +func (t *Table) AppendNamedArg( + fmter Formatter, b []byte, name string, strct reflect.Value, +) ([]byte, bool) { + if field, ok := t.FieldMap[name]; ok { + return fmter.appendArg(b, field.Value(strct).Interface()), true + } + return b, false +} + +func (t *Table) quoteTableName(s string) Safe { + // Don't quote if table name contains placeholder (?) or parentheses. + if strings.IndexByte(s, '?') >= 0 || + strings.IndexByte(s, '(') >= 0 || + strings.IndexByte(s, ')') >= 0 { + return Safe(s) + } + return t.quoteIdent(s) +} + +func (t *Table) quoteIdent(s string) Safe { + return Safe(NewFormatter(t.dialect).AppendIdent(nil, s)) +} + +func appendNew(dst []int, src ...int) []int { + cp := make([]int, len(dst)+len(src)) + copy(cp, dst) + copy(cp[len(dst):], src) + return cp +} + +func isKnownTableOption(name string) bool { + switch name { + case "alias", "select": + return true + } + return false +} + +func isKnownFieldOption(name string) bool { + switch name { + case "alias", + "type", + "array", + "hstore", + "composite", + "json_use_number", + "msgpack", + "notnull", + "nullzero", + "allowzero", + "default", + "unique", + "soft_delete", + + "pk", + "autoincrement", + "rel", + "join", + "m2m", + "polymorphic": + return true + } + return false +} + +func removeField(fields []*Field, field *Field) []*Field { + for i, f := range fields { + if f == field { + return append(fields[:i], fields[i+1:]...) + } + } + return fields +} + +func parseRelationJoin(join string) ([]string, []string) { + ss := strings.Split(join, ",") + baseColumns := make([]string, len(ss)) + joinColumns := make([]string, len(ss)) + for i, s := range ss { + ss := strings.Split(strings.TrimSpace(s), "=") + if len(ss) != 2 { + panic(fmt.Errorf("can't parse relation join: %q", join)) + } + baseColumns[i] = ss[0] + joinColumns[i] = ss[1] + } + return baseColumns, joinColumns +} + +//------------------------------------------------------------------------------ + +func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error { + typ := field.StructField.Type + + switch typ { + case timeType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*time.Time) + *ptr = time.Now() + return nil + } + case nullTimeType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*sql.NullTime) + *ptr = sql.NullTime{Time: time.Now()} + return nil + } + case nullIntType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*sql.NullInt64) + *ptr = sql.NullInt64{Int64: time.Now().UnixNano()} + return nil + } + } + + switch field.IndirectType.Kind() { + case reflect.Int64: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*int64) + *ptr = time.Now().UnixNano() + return nil + } + case reflect.Ptr: + typ = typ.Elem() + default: + return softDeleteFieldUpdaterFallback(field) + } + + switch typ { //nolint:gocritic + case timeType: + return func(fv reflect.Value) error { + now := time.Now() + fv.Set(reflect.ValueOf(&now)) + return nil + } + } + + switch typ.Kind() { //nolint:gocritic + case reflect.Int64: + return func(fv reflect.Value) error { + utime := time.Now().UnixNano() + fv.Set(reflect.ValueOf(&utime)) + return nil + } + } + + return softDeleteFieldUpdaterFallback(field) +} + +func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value) error { + return func(fv reflect.Value) error { + return field.ScanWithCheck(fv, time.Now()) + } +} diff --git a/vendor/github.com/uptrace/bun/schema/tables.go b/vendor/github.com/uptrace/bun/schema/tables.go new file mode 100644 index 000000000..d82d08f59 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/tables.go @@ -0,0 +1,148 @@ +package schema + +import ( + "fmt" + "reflect" + "sync" +) + +type tableInProgress struct { + table *Table + + init1Once sync.Once + init2Once sync.Once +} + +func newTableInProgress(table *Table) *tableInProgress { + return &tableInProgress{ + table: table, + } +} + +func (inp *tableInProgress) init1() bool { + var inited bool + inp.init1Once.Do(func() { + inp.table.init1() + inited = true + }) + return inited +} + +func (inp *tableInProgress) init2() bool { + var inited bool + inp.init2Once.Do(func() { + inp.table.init2() + inited = true + }) + return inited +} + +type Tables struct { + dialect Dialect + tables sync.Map + + mu sync.RWMutex + inProgress map[reflect.Type]*tableInProgress +} + +func NewTables(dialect Dialect) *Tables { + return &Tables{ + dialect: dialect, + inProgress: make(map[reflect.Type]*tableInProgress), + } +} + +func (t *Tables) Register(models ...interface{}) { + for _, model := range models { + _ = t.Get(reflect.TypeOf(model).Elem()) + } +} + +func (t *Tables) Get(typ reflect.Type) *Table { + return t.table(typ, false) +} + +func (t *Tables) Ref(typ reflect.Type) *Table { + return t.table(typ, true) +} + +func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table { + if typ.Kind() != reflect.Struct { + panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct)) + } + + if v, ok := t.tables.Load(typ); ok { + return v.(*Table) + } + + t.mu.Lock() + + if v, ok := t.tables.Load(typ); ok { + t.mu.Unlock() + return v.(*Table) + } + + var table *Table + + inProgress := t.inProgress[typ] + if inProgress == nil { + table = newTable(t.dialect, typ) + inProgress = newTableInProgress(table) + t.inProgress[typ] = inProgress + } else { + table = inProgress.table + } + + t.mu.Unlock() + + inProgress.init1() + if allowInProgress { + return table + } + + if inProgress.init2() { + t.mu.Lock() + delete(t.inProgress, typ) + t.tables.Store(typ, table) + t.mu.Unlock() + } + + t.dialect.OnTable(table) + + for _, field := range table.FieldMap { + if field.UserSQLType == "" { + field.UserSQLType = field.DiscoveredSQLType + } + if field.CreateTableSQLType == "" { + field.CreateTableSQLType = field.UserSQLType + } + } + + return table +} + +func (t *Tables) ByModel(name string) *Table { + var found *Table + t.tables.Range(func(key, value interface{}) bool { + t := value.(*Table) + if t.TypeName == name { + found = t + return false + } + return true + }) + return found +} + +func (t *Tables) ByName(name string) *Table { + var found *Table + t.tables.Range(func(key, value interface{}) bool { + t := value.(*Table) + if t.Name == name { + found = t + return false + } + return true + }) + return found +} diff --git a/vendor/github.com/uptrace/bun/schema/util.go b/vendor/github.com/uptrace/bun/schema/util.go new file mode 100644 index 000000000..6d474e4cc --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/util.go @@ -0,0 +1,53 @@ +package schema + +import "reflect" + +func indirectType(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, ok bool) { + if len(index) == 1 { + return v.Field(index[0]), true + } + + for i, idx := range index { + if i > 0 { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return v, false + } + v = v.Elem() + } + } + v = v.Field(idx) + } + return v, true +} + +func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value { + if len(index) == 1 { + return v.Field(index[0]) + } + + for i, idx := range index { + if i > 0 { + v = indirectNil(v) + } + v = v.Field(idx) + } + return v +} + +func indirectNil(v reflect.Value) reflect.Value { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + return v +} diff --git a/vendor/github.com/uptrace/bun/schema/zerochecker.go b/vendor/github.com/uptrace/bun/schema/zerochecker.go new file mode 100644 index 000000000..95efeee6b --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/zerochecker.go @@ -0,0 +1,126 @@ +package schema + +import ( + "database/sql/driver" + "reflect" +) + +var isZeroerType = reflect.TypeOf((*isZeroer)(nil)).Elem() + +type isZeroer interface { + IsZero() bool +} + +type IsZeroerFunc func(reflect.Value) bool + +func FieldZeroChecker(field *Field) IsZeroerFunc { + return zeroChecker(field.IndirectType) +} + +func zeroChecker(typ reflect.Type) IsZeroerFunc { + if typ.Implements(isZeroerType) { + return isZeroInterface + } + + kind := typ.Kind() + + if kind != reflect.Ptr { + ptr := reflect.PtrTo(typ) + if ptr.Implements(isZeroerType) { + return addrChecker(isZeroInterface) + } + } + + switch kind { + case reflect.Array: + if typ.Elem().Kind() == reflect.Uint8 { + return isZeroBytes + } + return isZeroLen + case reflect.String: + return isZeroLen + case reflect.Bool: + return isZeroBool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return isZeroInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return isZeroUint + case reflect.Float32, reflect.Float64: + return isZeroFloat + case reflect.Interface, reflect.Ptr, reflect.Slice, reflect.Map: + return isNil + } + + if typ.Implements(driverValuerType) { + return isZeroDriverValue + } + + return notZero +} + +func addrChecker(fn IsZeroerFunc) IsZeroerFunc { + return func(v reflect.Value) bool { + if !v.CanAddr() { + return false + } + return fn(v.Addr()) + } +} + +func isZeroInterface(v reflect.Value) bool { + if v.Kind() == reflect.Ptr && v.IsNil() { + return true + } + return v.Interface().(isZeroer).IsZero() +} + +func isZeroDriverValue(v reflect.Value) bool { + if v.Kind() == reflect.Ptr { + return v.IsNil() + } + + valuer := v.Interface().(driver.Valuer) + value, err := valuer.Value() + if err != nil { + return false + } + return value == nil +} + +func isZeroLen(v reflect.Value) bool { + return v.Len() == 0 +} + +func isNil(v reflect.Value) bool { + return v.IsNil() +} + +func isZeroBool(v reflect.Value) bool { + return !v.Bool() +} + +func isZeroInt(v reflect.Value) bool { + return v.Int() == 0 +} + +func isZeroUint(v reflect.Value) bool { + return v.Uint() == 0 +} + +func isZeroFloat(v reflect.Value) bool { + return v.Float() == 0 +} + +func isZeroBytes(v reflect.Value) bool { + b := v.Slice(0, v.Len()).Bytes() + for _, c := range b { + if c != 0 { + return false + } + } + return true +} + +func notZero(v reflect.Value) bool { + return false +} diff --git a/vendor/github.com/uptrace/bun/util.go b/vendor/github.com/uptrace/bun/util.go new file mode 100644 index 000000000..ce56be805 --- /dev/null +++ b/vendor/github.com/uptrace/bun/util.go @@ -0,0 +1,114 @@ +package bun + +import "reflect" + +func indirect(v reflect.Value) reflect.Value { + switch v.Kind() { + case reflect.Interface: + return indirect(v.Elem()) + case reflect.Ptr: + return v.Elem() + default: + return v + } +} + +func walk(v reflect.Value, index []int, fn func(reflect.Value)) { + v = reflect.Indirect(v) + switch v.Kind() { + case reflect.Slice: + sliceLen := v.Len() + for i := 0; i < sliceLen; i++ { + visitField(v.Index(i), index, fn) + } + default: + visitField(v, index, fn) + } +} + +func visitField(v reflect.Value, index []int, fn func(reflect.Value)) { + v = reflect.Indirect(v) + if len(index) > 0 { + v = v.Field(index[0]) + if v.Kind() == reflect.Ptr && v.IsNil() { + return + } + walk(v, index[1:], fn) + } else { + fn(v) + } +} + +func typeByIndex(t reflect.Type, index []int) reflect.Type { + for _, x := range index { + switch t.Kind() { + case reflect.Ptr: + t = t.Elem() + case reflect.Slice: + t = indirectType(t.Elem()) + } + t = t.Field(x).Type + } + return indirectType(t) +} + +func indirectType(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +func sliceElemType(v reflect.Value) reflect.Type { + elemType := v.Type().Elem() + if elemType.Kind() == reflect.Interface && v.Len() > 0 { + return indirect(v.Index(0).Elem()).Type() + } + return indirectType(elemType) +} + +func makeSliceNextElemFunc(v reflect.Value) func() reflect.Value { + if v.Kind() == reflect.Array { + var pos int + return func() reflect.Value { + v := v.Index(pos) + pos++ + return v + } + } + + sliceType := v.Type() + elemType := sliceType.Elem() + + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + return func() reflect.Value { + if v.Len() < v.Cap() { + v.Set(v.Slice(0, v.Len()+1)) + elem := v.Index(v.Len() - 1) + if elem.IsNil() { + elem.Set(reflect.New(elemType)) + } + return elem.Elem() + } + + elem := reflect.New(elemType) + v.Set(reflect.Append(v, elem)) + return elem.Elem() + } + } + + zero := reflect.Zero(elemType) + return func() reflect.Value { + l := v.Len() + c := v.Cap() + + if l < c { + v.Set(v.Slice(0, l+1)) + return v.Index(l) + } + + v.Set(reflect.Append(v, zero)) + return v.Index(l) + } +} diff --git a/vendor/github.com/uptrace/bun/version.go b/vendor/github.com/uptrace/bun/version.go new file mode 100644 index 000000000..1baf9a39c --- /dev/null +++ b/vendor/github.com/uptrace/bun/version.go @@ -0,0 +1,6 @@ +package bun + +// Version is the current release version. +func Version() string { + return "0.4.3" +} |