diff options
Diffstat (limited to 'vendor/github.com/go-pg/pg/v10')
101 files changed, 17892 insertions, 0 deletions
diff --git a/vendor/github.com/go-pg/pg/v10/.golangci.yml b/vendor/github.com/go-pg/pg/v10/.golangci.yml new file mode 100644 index 000000000..e2b5ce924 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/.golangci.yml @@ -0,0 +1,18 @@ +run: + concurrency: 8 + deadline: 5m + tests: false +linters: + enable-all: true + disable: + - gochecknoglobals + - gocognit + - gomnd + - wsl + - funlen + - godox + - goerr113 + - exhaustive + - nestif + - gofumpt + - goconst diff --git a/vendor/github.com/go-pg/pg/v10/.prettierrc b/vendor/github.com/go-pg/pg/v10/.prettierrc new file mode 100644 index 000000000..8b7f044ad --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/.prettierrc @@ -0,0 +1,4 @@ +semi: false +singleQuote: true +proseWrap: always +printWidth: 100 diff --git a/vendor/github.com/go-pg/pg/v10/.travis.yml b/vendor/github.com/go-pg/pg/v10/.travis.yml new file mode 100644 index 000000000..6db22a449 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/.travis.yml @@ -0,0 +1,21 @@ +dist: xenial +language: go + +addons: + postgresql: '9.6' + +go: + - 1.14.x + - 1.15.x + - tip + +matrix: + allow_failures: + - go: tip + +go_import_path: github.com/go-pg/pg + +before_install: + - psql -U postgres -c "CREATE EXTENSION hstore" + - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- + -b $(go env GOPATH)/bin v1.28.3 diff --git a/vendor/github.com/go-pg/pg/v10/CHANGELOG.md b/vendor/github.com/go-pg/pg/v10/CHANGELOG.md new file mode 100644 index 000000000..6a8288033 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/CHANGELOG.md @@ -0,0 +1,204 @@ +# Changelog + +> :heart: +> [**Uptrace.dev** - All-in-one tool to optimize performance and monitor errors & logs](https://uptrace.dev) + +**Important**. Please check [Bun](https://bun.uptrace.dev/guide/pg-migration.html) - the next +iteration of go-pg built on top of `sql.DB`. + +## v10.10 + +- Removed extra OpenTelemetry spans from go-pg core. Now go-pg instrumentation only adds a single + span with a SQL query (instead of 4 spans). There are multiple reasons behind this decision: + + - Traces become smaller and less noisy. + - [Bun](https://github.com/uptrace/bun) can't support the same level of instrumentation and it is + nice to keep the projects synced. + - It may be costly to process those 3 extra spans for each query. + + Eventually we hope to replace the information that we no longer collect with OpenTelemetry + Metrics. + +## v10.9 + +- To make updating easier, extra modules now have the same version as go-pg does. That means that + you need to update your imports: + +``` +github.com/go-pg/pg/extra/pgdebug -> github.com/go-pg/pg/extra/pgdebug/v10 +github.com/go-pg/pg/extra/pgotel -> github.com/go-pg/pg/extra/pgotel/v10 +github.com/go-pg/pg/extra/pgsegment -> github.com/go-pg/pg/extra/pgsegment/v10 +``` + +- Exported `pg.Query` which should be used instead of `orm.Query`. +- Added `pg.DBI` which is a DB interface implemented by `pg.DB` and `pg.Tx`. + +## v10 + +### Resources + +- Docs at https://pg.uptrace.dev/ powered by [mkdocs](https://github.com/squidfunk/mkdocs-material). +- [RealWorld example application](https://github.com/uptrace/go-realworld-example-app). +- [Discord](https://discord.gg/rWtp5Aj). + +### Features + +- `Select`, `Insert`, and `Update` support `map[string]interface{}`. `Select` also supports + `[]map[string]interface{}`. + +```go +var mm []map[string]interface{} +err := db.Model((*User)(nil)).Limit(10).Select(&mm) +``` + +- Columns that start with `_` are ignored if there is no destination field. +- Optional [faster json encoding](https://github.com/go-pg/pgext). +- Added [pgext.OpenTelemetryHook](https://github.com/go-pg/pgext) that adds + [OpenTelemetry instrumentation](https://pg.uptrace.dev/tracing/). +- Added [pgext.DebugHook](https://github.com/go-pg/pgext) that logs failed queries. +- Added `db.Ping` to check if database is healthy. + +### Changes + +- ORM relations are reworked and now require `rel` tag option (but existing code will continue + working until v11). Supported options: + - `pg:"rel:has-one"` - has one relation. + - `pg:"rel:belongs-to"` - belongs to relation. + - `pg:"rel:has-many"` - has many relation. + - `pg:"many2many:book_genres"` - many to many relation. +- Changed `pg.QueryHook` to return temp byte slice to reduce memory usage. +- `,msgpack` struct tag marshals data in MessagePack format using + https://github.com/vmihailenco/msgpack +- Empty slices and maps are no longer marshaled as `NULL`. Nil slices and maps are still marshaled + as `NULL`. +- Changed `UpdateNotZero` to include zero fields with `pg:",use_zero"` tag. Consider using + `Model(*map[string]interface{})` for inserts and updates. +- `joinFK` is deprecated in favor of `join_fk`. +- `partitionBy` is deprecated in favor of `partition_by`. +- ORM shortcuts are removed: + - `db.Select(model)` becomes `db.Model(model).WherePK().Select()`. + - `db.Insert(model)` becomes `db.Model(model).Insert()`. + - `db.Update(model)` becomes `db.Model(model).WherePK().Update()`. + - `db.Delete(model)` becomes `db.Model(model).WherePK().Delete()`. +- Deprecated types and funcs are removed. +- `WhereStruct` is removed. + +## v9 + +- `pg:",notnull"` is reworked. Now it means SQL `NOT NULL` constraint and nothing more. +- Added `pg:",use_zero"` to prevent go-pg from converting Go zero values to SQL `NULL`. +- UpdateNotNull is renamed to UpdateNotZero. As previously it omits zero Go values, but it does not + take in account if field is nullable or not. +- ORM supports DistinctOn. +- Hooks accept and return context. +- Client respects Context.Deadline when setting net.Conn deadline. +- Client listens on Context.Done while waiting for a connection from the pool and returns an error + when context is cancelled. +- `Query.Column` does not accept relation name any more. Use `Query.Relation` instead which returns + an error if relation does not exist. +- urlvalues package is removed in favor of https://github.com/go-pg/urlstruct. You can also use + struct based filters via `Query.WhereStruct`. +- `NewModel` and `AddModel` methods of `HooklessModel` interface were renamed to `NextColumnScanner` + and `AddColumnScanner` respectively. +- `types.F` and `pg.F` are deprecated in favor of `pg.Ident`. +- `types.Q` is deprecated in favor of `pg.Safe`. +- `pg.Q` is deprecated in favor of `pg.SafeQuery`. +- `TableName` field is deprecated in favor of `tableName`. +- Always use `pg:"..."` struct field tag instead of `sql:"..."`. +- `pg:",override"` is deprecated in favor of `pg:",inherit"`. + +## v8 + +- Added `QueryContext`, `ExecContext`, and `ModelContext` which accept `context.Context`. Queries + are cancelled when context is cancelled. +- Model hooks are changed to accept `context.Context` as first argument. +- Fixed array and hstore parsers to handle multiple single quotes (#1235). + +## v7 + +- DB.OnQueryProcessed is replaced with DB.AddQueryHook. +- Added WhereStruct. +- orm.Pager is moved to urlvalues.Pager. Pager.FromURLValues returns an error if page or limit + params can't be parsed. + +## v6.16 + +- Read buffer is re-worked. Default read buffer is increased to 65kb. + +## v6.15 + +- Added Options.MinIdleConns. +- Options.MaxAge renamed to Options.MaxConnAge. +- PoolStats.FreeConns is renamed to PoolStats.IdleConns. +- New hook BeforeSelectQuery. +- `,override` is renamed to `,inherit`. +- Dialer.KeepAlive is set to 5 minutes by default. +- Added support "scram-sha-256" authentication. + +## v6.14 + +- Fields ignored with `sql:"-"` tag are no longer considered by ORM relation detector. + +## v6.12 + +- `Insert`, `Update`, and `Delete` can return `pg.ErrNoRows` and `pg.ErrMultiRows` when `Returning` + is used and model expects single row. + +## v6.11 + +- `db.Model(&strct).Update()` and `db.Model(&strct).Delete()` no longer adds WHERE condition based + on primary key when there are no conditions. Instead you should use `db.Update(&strct)` or + `db.Model(&strct).WherePK().Update()`. + +## v6.10 + +- `?Columns` is renamed to `?TableColumns`. `?Columns` is changed to produce column names without + table alias. + +## v6.9 + +- `pg:"fk"` tag now accepts SQL names instead of Go names, e.g. `pg:"fk:ParentId"` becomes + `pg:"fk:parent_id"`. Old code should continue working in most cases, but it is strongly advised to + start using new convention. +- uint and uint64 SQL type is changed from decimal to bigint according to the lesser of two evils + principle. Use `sql:"type:decimal"` to get old behavior. + +## v6.8 + +- `CreateTable` no longer adds ON DELETE hook by default. To get old behavior users should add + `sql:"on_delete:CASCADE"` tag on foreign key field. + +## v6 + +- `types.Result` is renamed to `orm.Result`. +- Added `OnQueryProcessed` event that can be used to log / report queries timing. Query logger is + removed. +- `orm.URLValues` is renamed to `orm.URLFilters`. It no longer adds ORDER clause. +- `orm.Pager` is renamed to `orm.Pagination`. +- Support for net.IP and net.IPNet. +- Support for context.Context. +- Bulk/multi updates. +- Query.WhereGroup for enclosing conditions in parentheses. + +## v5 + +- All fields are nullable by default. `,null` tag is replaced with `,notnull`. +- `Result.Affected` renamed to `Result.RowsAffected`. +- Added `Result.RowsReturned`. +- `Create` renamed to `Insert`, `BeforeCreate` to `BeforeInsert`, `AfterCreate` to `AfterInsert`. +- Indexed placeholders support, e.g. `db.Exec("SELECT ?0 + ?0", 1)`. +- Named placeholders are evaluated when query is executed. +- Added Update and Delete hooks. +- Order reworked to quote column names. OrderExpr added to bypass Order quoting restrictions. +- Group reworked to quote column names. GroupExpr added to bypass Group quoting restrictions. + +## v4 + +- `Options.Host` and `Options.Port` merged into `Options.Addr`. +- Added `Options.MaxRetries`. Now queries are not retried by default. +- `LoadInto` renamed to `Scan`, `ColumnLoader` renamed to `ColumnScanner`, LoadColumn renamed to + ScanColumn, `NewRecord() interface{}` changed to `NewModel() ColumnScanner`, + `AppendQuery(dst []byte) []byte` changed to `AppendValue(dst []byte, quote bool) ([]byte, error)`. +- Structs, maps and slices are marshalled to JSON by default. +- Added support for scanning slices, .e.g. scanning `[]int`. +- Added object relational mapping. diff --git a/vendor/github.com/go-pg/pg/v10/LICENSE b/vendor/github.com/go-pg/pg/v10/LICENSE new file mode 100644 index 000000000..7751509b8 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/LICENSE @@ -0,0 +1,24 @@ +Copyright (c) 2013 github.com/go-pg/pg Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/go-pg/pg/v10/Makefile b/vendor/github.com/go-pg/pg/v10/Makefile new file mode 100644 index 000000000..bacdbadae --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/Makefile @@ -0,0 +1,27 @@ +all: + TZ= go test ./... + TZ= go test ./... -short -race + TZ= go test ./... -run=NONE -bench=. -benchmem + env GOOS=linux GOARCH=386 go test ./... + go vet + golangci-lint run + +.PHONY: cleanTest +cleanTest: + docker rm -fv pg || true + +.PHONY: pre-test +pre-test: cleanTest + docker run -d --name pg -p 5432:5432 -e POSTGRES_HOST_AUTH_METHOD=trust postgres:9.6 + sleep 10 + docker exec pg psql -U postgres -c "CREATE EXTENSION hstore" + +.PHONY: test +test: pre-test + TZ= PGSSLMODE=disable go test ./... -v + +tag: + git tag $(VERSION) + git tag extra/pgdebug/$(VERSION) + git tag extra/pgotel/$(VERSION) + git tag extra/pgsegment/$(VERSION) diff --git a/vendor/github.com/go-pg/pg/v10/README.md b/vendor/github.com/go-pg/pg/v10/README.md new file mode 100644 index 000000000..a624e0b8e --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/README.md @@ -0,0 +1,240 @@ +<p align="center"> + <a href="https://uptrace.dev/?utm_source=gh-pg&utm_campaign=gh-pg-banner1"> + <img src="https://raw.githubusercontent.com/uptrace/roadmap/master/banner1.png"> + </a> +</p> + +# PostgreSQL client and ORM for Golang + +[](https://travis-ci.org/go-pg/pg) +[](https://pkg.go.dev/github.com/go-pg/pg/v10) +[](https://pg.uptrace.dev/) +[](https://discord.gg/rWtp5Aj) + +**Important**. Please check [Bun](https://bun.uptrace.dev/guide/pg-migration.html) - the next +iteration of go-pg built on top of `sql.DB`. + +- Join [Discord](https://discord.gg/rWtp5Aj) to ask questions. +- [Documentation](https://pg.uptrace.dev) +- [Reference](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc) +- [Examples](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#pkg-examples) +- Example projects: + - [treemux](https://github.com/uptrace/go-treemux-realworld-example-app) + - [gin](https://github.com/gogjango/gjango) + - [go-kit](https://github.com/Tsovak/rest-api-demo) + - [aah framework](https://github.com/kieusonlam/golamapi) +- [GraphQL Tutorial on YouTube](https://www.youtube.com/playlist?list=PLzQWIQOqeUSNwXcneWYJHUREAIucJ5UZn). + +## Ecosystem + +- Migrations by [vmihailenco](https://github.com/go-pg/migrations) and + [robinjoseph08](https://github.com/robinjoseph08/go-pg-migrations). +- [Genna - cli tool for generating go-pg models](https://github.com/dizzyfool/genna). +- [bigint](https://github.com/d-fal/bigint) - big.Int type for go-pg. +- [urlstruct](https://github.com/go-pg/urlstruct) to decode `url.Values` into structs. +- [Sharding](https://github.com/go-pg/sharding). +- [go-pg-monitor](https://github.com/hypnoglow/go-pg-monitor) - Prometheus metrics based on go-pg + client stats. + +## Features + +- Basic types: integers, floats, string, bool, time.Time, net.IP, net.IPNet. +- sql.NullBool, sql.NullString, sql.NullInt64, sql.NullFloat64 and + [pg.NullTime](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#NullTime). +- [sql.Scanner](http://golang.org/pkg/database/sql/#Scanner) and + [sql/driver.Valuer](http://golang.org/pkg/database/sql/driver/#Valuer) interfaces. +- Structs, maps and arrays are marshalled as JSON by default. +- PostgreSQL multidimensional Arrays using + [array tag](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-PostgresArrayStructTag) + and [Array wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Array). +- Hstore using + [hstore tag](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-HstoreStructTag) + and [Hstore wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Hstore). +- [Composite types](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-CompositeType). +- All struct fields are nullable by default and zero values (empty string, 0, zero time, empty map + or slice, nil ptr) are marshalled as SQL `NULL`. `pg:",notnull"` is used to add SQL `NOT NULL` + constraint and `pg:",use_zero"` to allow Go zero values. +- [Transactions](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Begin). +- [Prepared statements](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Prepare). +- [Notifications](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Listener) using + `LISTEN` and `NOTIFY`. +- [Copying data](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-CopyFrom) using + `COPY FROM` and `COPY TO`. +- [Timeouts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#Options) and canceling queries using + context.Context. +- Automatic connection pooling with + [circuit breaker](https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern) support. +- Queries retry on network errors. +- Working with models using + [ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model) and + [SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Query). +- Scanning variables using + [ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectSomeColumnsIntoVars) + and [SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Scan). +- [SelectOrInsert](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-InsertSelectOrInsert) + using on-conflict. +- [INSERT ... ON CONFLICT DO UPDATE](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-InsertOnConflictDoUpdate) + using ORM. +- Bulk/batch + [inserts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkInsert), + [updates](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkUpdate), and + [deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkDelete). +- Common table expressions using + [WITH](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectWith) and + [WrapWith](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectWrapWith). +- [CountEstimate](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-CountEstimate) + using `EXPLAIN` to get + [estimated number of matching rows](https://wiki.postgresql.org/wiki/Count_estimate). +- ORM supports + [has one](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-HasOne), + [belongs to](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BelongsTo), + [has many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-HasMany), and + [many to many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-ManyToMany) + with composite/multi-column primary keys. +- [Soft deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SoftDelete). +- [Creating tables from structs](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-CreateTable). +- [ForEach](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-ForEach) that calls + a function for each row returned by the query without loading all rows into the memory. + +## Installation + +go-pg supports 2 last Go versions and requires a Go version with +[modules](https://github.com/golang/go/wiki/Modules) support. So make sure to initialize a Go +module: + +```shell +go mod init github.com/my/repo +``` + +And then install go-pg (note _v10_ in the import; omitting it is a popular mistake): + +```shell +go get github.com/go-pg/pg/v10 +``` + +## Quickstart + +```go +package pg_test + +import ( + "fmt" + + "github.com/go-pg/pg/v10" + "github.com/go-pg/pg/v10/orm" +) + +type User struct { + Id int64 + Name string + Emails []string +} + +func (u User) String() string { + return fmt.Sprintf("User<%d %s %v>", u.Id, u.Name, u.Emails) +} + +type Story struct { + Id int64 + Title string + AuthorId int64 + Author *User `pg:"rel:has-one"` +} + +func (s Story) String() string { + return fmt.Sprintf("Story<%d %s %s>", s.Id, s.Title, s.Author) +} + +func ExampleDB_Model() { + db := pg.Connect(&pg.Options{ + User: "postgres", + }) + defer db.Close() + + err := createSchema(db) + if err != nil { + panic(err) + } + + user1 := &User{ + Name: "admin", + Emails: []string{"admin1@admin", "admin2@admin"}, + } + _, err = db.Model(user1).Insert() + if err != nil { + panic(err) + } + + _, err = db.Model(&User{ + Name: "root", + Emails: []string{"root1@root", "root2@root"}, + }).Insert() + if err != nil { + panic(err) + } + + story1 := &Story{ + Title: "Cool story", + AuthorId: user1.Id, + } + _, err = db.Model(story1).Insert() + if err != nil { + panic(err) + } + + // Select user by primary key. + user := &User{Id: user1.Id} + err = db.Model(user).WherePK().Select() + if err != nil { + panic(err) + } + + // Select all users. + var users []User + err = db.Model(&users).Select() + if err != nil { + panic(err) + } + + // Select story and associated author in one query. + story := new(Story) + err = db.Model(story). + Relation("Author"). + Where("story.id = ?", story1.Id). + Select() + if err != nil { + panic(err) + } + + fmt.Println(user) + fmt.Println(users) + fmt.Println(story) + // Output: User<1 admin [admin1@admin admin2@admin]> + // [User<1 admin [admin1@admin admin2@admin]> User<2 root [root1@root root2@root]>] + // Story<1 Cool story User<1 admin [admin1@admin admin2@admin]>> +} + +// createSchema creates database schema for User and Story models. +func createSchema(db *pg.DB) error { + models := []interface{}{ + (*User)(nil), + (*Story)(nil), + } + + for _, model := range models { + err := db.Model(model).CreateTable(&orm.CreateTableOptions{ + Temp: true, + }) + if err != nil { + return err + } + } + return nil +} +``` + +## See also + +- [Fast and flexible HTTP router](https://github.com/vmihailenco/treemux) +- [Golang msgpack](https://github.com/vmihailenco/msgpack) +- [Golang message task queue](https://github.com/vmihailenco/taskq) diff --git a/vendor/github.com/go-pg/pg/v10/base.go b/vendor/github.com/go-pg/pg/v10/base.go new file mode 100644 index 000000000..d13997464 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/base.go @@ -0,0 +1,618 @@ +package pg + +import ( + "context" + "io" + "time" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/internal/pool" + "github.com/go-pg/pg/v10/orm" + "github.com/go-pg/pg/v10/types" +) + +type baseDB struct { + db orm.DB + opt *Options + pool pool.Pooler + + fmter *orm.Formatter + queryHooks []QueryHook +} + +// PoolStats contains the stats of a connection pool. +type PoolStats pool.Stats + +// PoolStats returns connection pool stats. +func (db *baseDB) PoolStats() *PoolStats { + stats := db.pool.Stats() + return (*PoolStats)(stats) +} + +func (db *baseDB) clone() *baseDB { + return &baseDB{ + db: db.db, + opt: db.opt, + pool: db.pool, + + fmter: db.fmter, + queryHooks: copyQueryHooks(db.queryHooks), + } +} + +func (db *baseDB) withPool(p pool.Pooler) *baseDB { + cp := db.clone() + cp.pool = p + return cp +} + +func (db *baseDB) WithTimeout(d time.Duration) *baseDB { + newopt := *db.opt + newopt.ReadTimeout = d + newopt.WriteTimeout = d + + cp := db.clone() + cp.opt = &newopt + return cp +} + +func (db *baseDB) WithParam(param string, value interface{}) *baseDB { + cp := db.clone() + cp.fmter = db.fmter.WithParam(param, value) + return cp +} + +// Param returns value for the param. +func (db *baseDB) Param(param string) interface{} { + return db.fmter.Param(param) +} + +func (db *baseDB) retryBackoff(retry int) time.Duration { + return internal.RetryBackoff(retry, db.opt.MinRetryBackoff, db.opt.MaxRetryBackoff) +} + +func (db *baseDB) getConn(ctx context.Context) (*pool.Conn, error) { + cn, err := db.pool.Get(ctx) + if err != nil { + return nil, err + } + + if cn.Inited { + return cn, nil + } + + if err := db.initConn(ctx, cn); err != nil { + db.pool.Remove(ctx, cn, err) + // It is safe to reset StickyConnPool if conn can't be initialized. + if p, ok := db.pool.(*pool.StickyConnPool); ok { + _ = p.Reset(ctx) + } + if err := internal.Unwrap(err); err != nil { + return nil, err + } + return nil, err + } + + return cn, nil +} + +func (db *baseDB) initConn(ctx context.Context, cn *pool.Conn) error { + if cn.Inited { + return nil + } + cn.Inited = true + + if db.opt.TLSConfig != nil { + err := db.enableSSL(ctx, cn, db.opt.TLSConfig) + if err != nil { + return err + } + } + + err := db.startup(ctx, cn, db.opt.User, db.opt.Password, db.opt.Database, db.opt.ApplicationName) + if err != nil { + return err + } + + if db.opt.OnConnect != nil { + p := pool.NewSingleConnPool(db.pool, cn) + return db.opt.OnConnect(ctx, newConn(ctx, db.withPool(p))) + } + + return nil +} + +func (db *baseDB) releaseConn(ctx context.Context, cn *pool.Conn, err error) { + if isBadConn(err, false) { + db.pool.Remove(ctx, cn, err) + } else { + db.pool.Put(ctx, cn) + } +} + +func (db *baseDB) withConn( + ctx context.Context, fn func(context.Context, *pool.Conn) error, +) error { + cn, err := db.getConn(ctx) + if err != nil { + return err + } + + var fnDone chan struct{} + if ctx != nil && ctx.Done() != nil { + fnDone = make(chan struct{}) + go func() { + select { + case <-fnDone: // fn has finished, skip cancel + case <-ctx.Done(): + err := db.cancelRequest(cn.ProcessID, cn.SecretKey) + if err != nil { + internal.Logger.Printf(ctx, "cancelRequest failed: %s", err) + } + // Signal end of conn use. + fnDone <- struct{}{} + } + }() + } + + defer func() { + if fnDone == nil { + db.releaseConn(ctx, cn, err) + return + } + + select { + case <-fnDone: // wait for cancel to finish request + // Looks like the canceled connection must be always removed from the pool. + db.pool.Remove(ctx, cn, err) + case fnDone <- struct{}{}: // signal fn finish, skip cancel goroutine + db.releaseConn(ctx, cn, err) + } + }() + + err = fn(ctx, cn) + return err +} + +func (db *baseDB) shouldRetry(err error) bool { + switch err { + case io.EOF, io.ErrUnexpectedEOF: + return true + case nil, context.Canceled, context.DeadlineExceeded: + return false + } + + if pgerr, ok := err.(Error); ok { + switch pgerr.Field('C') { + case "40001", // serialization_failure + "53300", // too_many_connections + "55000": // attempted to delete invisible tuple + return true + case "57014": // statement_timeout + return db.opt.RetryStatementTimeout + default: + return false + } + } + + if _, ok := err.(timeoutError); ok { + return true + } + + return false +} + +// Close closes the database client, releasing any open resources. +// +// It is rare to Close a DB, as the DB handle is meant to be +// long-lived and shared between many goroutines. +func (db *baseDB) Close() error { + return db.pool.Close() +} + +// Exec executes a query ignoring returned rows. The params are for any +// placeholders in the query. +func (db *baseDB) Exec(query interface{}, params ...interface{}) (res Result, err error) { + return db.exec(db.db.Context(), query, params...) +} + +func (db *baseDB) ExecContext(c context.Context, query interface{}, params ...interface{}) (Result, error) { + return db.exec(c, query, params...) +} + +func (db *baseDB) exec(ctx context.Context, query interface{}, params ...interface{}) (Result, error) { + wb := pool.GetWriteBuffer() + defer pool.PutWriteBuffer(wb) + + if err := writeQueryMsg(wb, db.fmter, query, params...); err != nil { + return nil, err + } + + ctx, evt, err := db.beforeQuery(ctx, db.db, nil, query, params, wb.Query()) + if err != nil { + return nil, err + } + + var res Result + var lastErr error + for attempt := 0; attempt <= db.opt.MaxRetries; attempt++ { + if attempt > 0 { + if err := internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil { + return nil, err + } + } + + lastErr = db.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { + res, err = db.simpleQuery(ctx, cn, wb) + return err + }) + if !db.shouldRetry(lastErr) { + break + } + } + + if err := db.afterQuery(ctx, evt, res, lastErr); err != nil { + return nil, err + } + return res, lastErr +} + +// ExecOne acts like Exec, but query must affect only one row. It +// returns ErrNoRows error when query returns zero rows or +// ErrMultiRows when query returns multiple rows. +func (db *baseDB) ExecOne(query interface{}, params ...interface{}) (Result, error) { + return db.execOne(db.db.Context(), query, params...) +} + +func (db *baseDB) ExecOneContext(ctx context.Context, query interface{}, params ...interface{}) (Result, error) { + return db.execOne(ctx, query, params...) +} + +func (db *baseDB) execOne(c context.Context, query interface{}, params ...interface{}) (Result, error) { + res, err := db.ExecContext(c, query, params...) + if err != nil { + return nil, err + } + + if err := internal.AssertOneRow(res.RowsAffected()); err != nil { + return nil, err + } + return res, nil +} + +// Query executes a query that returns rows, typically a SELECT. +// The params are for any placeholders in the query. +func (db *baseDB) Query(model, query interface{}, params ...interface{}) (res Result, err error) { + return db.query(db.db.Context(), model, query, params...) +} + +func (db *baseDB) QueryContext(c context.Context, model, query interface{}, params ...interface{}) (Result, error) { + return db.query(c, model, query, params...) +} + +func (db *baseDB) query(ctx context.Context, model, query interface{}, params ...interface{}) (Result, error) { + wb := pool.GetWriteBuffer() + defer pool.PutWriteBuffer(wb) + + if err := writeQueryMsg(wb, db.fmter, query, params...); err != nil { + return nil, err + } + + ctx, evt, err := db.beforeQuery(ctx, db.db, model, query, params, wb.Query()) + if err != nil { + return nil, err + } + + var res Result + var lastErr error + for attempt := 0; attempt <= db.opt.MaxRetries; attempt++ { + if attempt > 0 { + if err := internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil { + return nil, err + } + } + + lastErr = db.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { + res, err = db.simpleQueryData(ctx, cn, model, wb) + return err + }) + if !db.shouldRetry(lastErr) { + break + } + } + + if err := db.afterQuery(ctx, evt, res, lastErr); err != nil { + return nil, err + } + return res, lastErr +} + +// QueryOne acts like Query, but query must return only one row. It +// returns ErrNoRows error when query returns zero rows or +// ErrMultiRows when query returns multiple rows. +func (db *baseDB) QueryOne(model, query interface{}, params ...interface{}) (Result, error) { + return db.queryOne(db.db.Context(), model, query, params...) +} + +func (db *baseDB) QueryOneContext( + ctx context.Context, model, query interface{}, params ...interface{}, +) (Result, error) { + return db.queryOne(ctx, model, query, params...) +} + +func (db *baseDB) queryOne(ctx context.Context, model, query interface{}, params ...interface{}) (Result, error) { + res, err := db.QueryContext(ctx, model, query, params...) + if err != nil { + return nil, err + } + + if err := internal.AssertOneRow(res.RowsAffected()); err != nil { + return nil, err + } + return res, nil +} + +// CopyFrom copies data from reader to a table. +func (db *baseDB) CopyFrom(r io.Reader, query interface{}, params ...interface{}) (res Result, err error) { + c := db.db.Context() + err = db.withConn(c, func(c context.Context, cn *pool.Conn) error { + res, err = db.copyFrom(c, cn, r, query, params...) + return err + }) + return res, err +} + +// TODO: don't get/put conn in the pool. +func (db *baseDB) copyFrom( + ctx context.Context, cn *pool.Conn, r io.Reader, query interface{}, params ...interface{}, +) (res Result, err error) { + var evt *QueryEvent + + wb := pool.GetWriteBuffer() + defer pool.PutWriteBuffer(wb) + + if err := writeQueryMsg(wb, db.fmter, query, params...); err != nil { + return nil, err + } + + var model interface{} + if len(params) > 0 { + model, _ = params[len(params)-1].(orm.TableModel) + } + + ctx, evt, err = db.beforeQuery(ctx, db.db, model, query, params, wb.Query()) + if err != nil { + return nil, err + } + + // Note that afterQuery uses the err. + defer func() { + if afterQueryErr := db.afterQuery(ctx, evt, res, err); afterQueryErr != nil { + err = afterQueryErr + } + }() + + err = cn.WithWriter(ctx, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + return writeQueryMsg(wb, db.fmter, query, params...) + }) + if err != nil { + return nil, err + } + + err = cn.WithReader(ctx, db.opt.ReadTimeout, readCopyInResponse) + if err != nil { + return nil, err + } + + for { + err = cn.WithWriter(ctx, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + return writeCopyData(wb, r) + }) + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + } + + err = cn.WithWriter(ctx, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + writeCopyDone(wb) + return nil + }) + if err != nil { + return nil, err + } + + err = cn.WithReader(ctx, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { + res, err = readReadyForQuery(rd) + return err + }) + if err != nil { + return nil, err + } + + return res, nil +} + +// CopyTo copies data from a table to writer. +func (db *baseDB) CopyTo(w io.Writer, query interface{}, params ...interface{}) (res Result, err error) { + c := db.db.Context() + err = db.withConn(c, func(c context.Context, cn *pool.Conn) error { + res, err = db.copyTo(c, cn, w, query, params...) + return err + }) + return res, err +} + +func (db *baseDB) copyTo( + ctx context.Context, cn *pool.Conn, w io.Writer, query interface{}, params ...interface{}, +) (res Result, err error) { + var evt *QueryEvent + + wb := pool.GetWriteBuffer() + defer pool.PutWriteBuffer(wb) + + if err := writeQueryMsg(wb, db.fmter, query, params...); err != nil { + return nil, err + } + + var model interface{} + if len(params) > 0 { + model, _ = params[len(params)-1].(orm.TableModel) + } + + ctx, evt, err = db.beforeQuery(ctx, db.db, model, query, params, wb.Query()) + if err != nil { + return nil, err + } + + // Note that afterQuery uses the err. + defer func() { + if afterQueryErr := db.afterQuery(ctx, evt, res, err); afterQueryErr != nil { + err = afterQueryErr + } + }() + + err = cn.WithWriter(ctx, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + return writeQueryMsg(wb, db.fmter, query, params...) + }) + if err != nil { + return nil, err + } + + err = cn.WithReader(ctx, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { + err := readCopyOutResponse(rd) + if err != nil { + return err + } + + res, err = readCopyData(rd, w) + return err + }) + if err != nil { + return nil, err + } + + return res, nil +} + +// Ping verifies a connection to the database is still alive, +// establishing a connection if necessary. +func (db *baseDB) Ping(ctx context.Context) error { + _, err := db.ExecContext(ctx, "SELECT 1") + return err +} + +// Model returns new query for the model. +func (db *baseDB) Model(model ...interface{}) *Query { + return orm.NewQuery(db.db, model...) +} + +func (db *baseDB) ModelContext(c context.Context, model ...interface{}) *Query { + return orm.NewQueryContext(c, db.db, model...) +} + +func (db *baseDB) Formatter() orm.QueryFormatter { + return db.fmter +} + +func (db *baseDB) cancelRequest(processID, secretKey int32) error { + c := context.TODO() + + cn, err := db.pool.NewConn(c) + if err != nil { + return err + } + defer func() { + _ = db.pool.CloseConn(cn) + }() + + return cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + writeCancelRequestMsg(wb, processID, secretKey) + return nil + }) +} + +func (db *baseDB) simpleQuery( + c context.Context, cn *pool.Conn, wb *pool.WriteBuffer, +) (*result, error) { + if err := cn.WriteBuffer(c, db.opt.WriteTimeout, wb); err != nil { + return nil, err + } + + var res *result + if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { + var err error + res, err = readSimpleQuery(rd) + return err + }); err != nil { + return nil, err + } + + return res, nil +} + +func (db *baseDB) simpleQueryData( + c context.Context, cn *pool.Conn, model interface{}, wb *pool.WriteBuffer, +) (*result, error) { + if err := cn.WriteBuffer(c, db.opt.WriteTimeout, wb); err != nil { + return nil, err + } + + var res *result + if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { + var err error + res, err = readSimpleQueryData(c, rd, model) + return err + }); err != nil { + return nil, err + } + + return res, nil +} + +// Prepare creates a prepared statement for later queries or +// executions. Multiple queries or executions may be run concurrently +// from the returned statement. +func (db *baseDB) Prepare(q string) (*Stmt, error) { + return prepareStmt(db.withPool(pool.NewStickyConnPool(db.pool)), q) +} + +func (db *baseDB) prepare( + c context.Context, cn *pool.Conn, q string, +) (string, []types.ColumnInfo, error) { + name := cn.NextID() + err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + writeParseDescribeSyncMsg(wb, name, q) + return nil + }) + if err != nil { + return "", nil, err + } + + var columns []types.ColumnInfo + err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { + columns, err = readParseDescribeSync(rd) + return err + }) + if err != nil { + return "", nil, err + } + + return name, columns, nil +} + +func (db *baseDB) closeStmt(c context.Context, cn *pool.Conn, name string) error { + err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + writeCloseMsg(wb, name) + writeFlushMsg(wb) + return nil + }) + if err != nil { + return err + } + + err = cn.WithReader(c, db.opt.ReadTimeout, readCloseCompleteMsg) + return err +} diff --git a/vendor/github.com/go-pg/pg/v10/db.go b/vendor/github.com/go-pg/pg/v10/db.go new file mode 100644 index 000000000..27664783b --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/db.go @@ -0,0 +1,142 @@ +package pg + +import ( + "context" + "fmt" + "time" + + "github.com/go-pg/pg/v10/internal/pool" + "github.com/go-pg/pg/v10/orm" +) + +// Connect connects to a database using provided options. +// +// The returned DB is safe for concurrent use by multiple goroutines +// and maintains its own connection pool. +func Connect(opt *Options) *DB { + opt.init() + return newDB( + context.Background(), + &baseDB{ + opt: opt, + pool: newConnPool(opt), + fmter: orm.NewFormatter(), + }, + ) +} + +func newDB(ctx context.Context, baseDB *baseDB) *DB { + db := &DB{ + baseDB: baseDB.clone(), + ctx: ctx, + } + db.baseDB.db = db + return db +} + +// DB is a database handle representing a pool of zero or more +// underlying connections. It's safe for concurrent use by multiple +// goroutines. +type DB struct { + *baseDB + ctx context.Context +} + +var _ orm.DB = (*DB)(nil) + +func (db *DB) String() string { + return fmt.Sprintf("DB<Addr=%q%s>", db.opt.Addr, db.fmter) +} + +// Options returns read-only Options that were used to connect to the DB. +func (db *DB) Options() *Options { + return db.opt +} + +// Context returns DB context. +func (db *DB) Context() context.Context { + return db.ctx +} + +// WithContext returns a copy of the DB that uses the ctx. +func (db *DB) WithContext(ctx context.Context) *DB { + return newDB(ctx, db.baseDB) +} + +// WithTimeout returns a copy of the DB that uses d as the read/write timeout. +func (db *DB) WithTimeout(d time.Duration) *DB { + return newDB(db.ctx, db.baseDB.WithTimeout(d)) +} + +// WithParam returns a copy of the DB that replaces the param with the value +// in queries. +func (db *DB) WithParam(param string, value interface{}) *DB { + return newDB(db.ctx, db.baseDB.WithParam(param, value)) +} + +// Listen listens for notifications sent with NOTIFY command. +func (db *DB) Listen(ctx context.Context, channels ...string) *Listener { + ln := &Listener{ + db: db, + } + ln.init() + _ = ln.Listen(ctx, channels...) + return ln +} + +// Conn represents a single database connection rather than a pool of database +// connections. Prefer running queries from DB unless there is a specific +// need for a continuous single database connection. +// +// A Conn must call Close to return the connection to the database pool +// and may do so concurrently with a running query. +// +// After a call to Close, all operations on the connection fail. +type Conn struct { + *baseDB + ctx context.Context +} + +var _ orm.DB = (*Conn)(nil) + +// Conn returns a single connection from the connection pool. +// Queries run on the same Conn will be run in the same database session. +// +// Every Conn must be returned to the database pool after use by +// calling Conn.Close. +func (db *DB) Conn() *Conn { + return newConn(db.ctx, db.baseDB.withPool(pool.NewStickyConnPool(db.pool))) +} + +func newConn(ctx context.Context, baseDB *baseDB) *Conn { + conn := &Conn{ + baseDB: baseDB, + ctx: ctx, + } + conn.baseDB.db = conn + return conn +} + +// Context returns DB context. +func (db *Conn) Context() context.Context { + if db.ctx != nil { + return db.ctx + } + return context.Background() +} + +// WithContext returns a copy of the DB that uses the ctx. +func (db *Conn) WithContext(ctx context.Context) *Conn { + return newConn(ctx, db.baseDB) +} + +// WithTimeout returns a copy of the DB that uses d as the read/write timeout. +func (db *Conn) WithTimeout(d time.Duration) *Conn { + return newConn(db.ctx, db.baseDB.WithTimeout(d)) +} + +// WithParam returns a copy of the DB that replaces the param with the value +// in queries. +func (db *Conn) WithParam(param string, value interface{}) *Conn { + return newConn(db.ctx, db.baseDB.WithParam(param, value)) +} diff --git a/vendor/github.com/go-pg/pg/v10/doc.go b/vendor/github.com/go-pg/pg/v10/doc.go new file mode 100644 index 000000000..9a077a8c1 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/doc.go @@ -0,0 +1,4 @@ +/* +pg provides PostgreSQL client. +*/ +package pg diff --git a/vendor/github.com/go-pg/pg/v10/error.go b/vendor/github.com/go-pg/pg/v10/error.go new file mode 100644 index 000000000..d8113a010 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/error.go @@ -0,0 +1,69 @@ +package pg + +import ( + "net" + + "github.com/go-pg/pg/v10/internal" +) + +// ErrNoRows is returned by QueryOne and ExecOne when query returned zero rows +// but at least one row is expected. +var ErrNoRows = internal.ErrNoRows + +// ErrMultiRows is returned by QueryOne and ExecOne when query returned +// multiple rows but exactly one row is expected. +var ErrMultiRows = internal.ErrMultiRows + +// Error represents an error returned by PostgreSQL server +// using PostgreSQL ErrorResponse protocol. +// +// https://www.postgresql.org/docs/10/static/protocol-message-formats.html +type Error interface { + error + + // Field returns a string value associated with an error field. + // + // https://www.postgresql.org/docs/10/static/protocol-error-fields.html + Field(field byte) string + + // IntegrityViolation reports whether an error is a part of + // Integrity Constraint Violation class of errors. + // + // https://www.postgresql.org/docs/10/static/errcodes-appendix.html + IntegrityViolation() bool +} + +var _ Error = (*internal.PGError)(nil) + +func isBadConn(err error, allowTimeout bool) bool { + if err == nil { + return false + } + if _, ok := err.(internal.Error); ok { + return false + } + if pgErr, ok := err.(Error); ok { + switch pgErr.Field('V') { + case "FATAL", "PANIC": + return true + } + switch pgErr.Field('C') { + case "25P02", // current transaction is aborted + "57014": // canceling statement due to user request + return true + } + return false + } + if allowTimeout { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return !netErr.Temporary() + } + } + return true +} + +//------------------------------------------------------------------------------ + +type timeoutError interface { + Timeout() bool +} diff --git a/vendor/github.com/go-pg/pg/v10/go.mod b/vendor/github.com/go-pg/pg/v10/go.mod new file mode 100644 index 000000000..aa867f309 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/go.mod @@ -0,0 +1,24 @@ +module github.com/go-pg/pg/v10 + +go 1.11 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-pg/zerochecker v0.2.0 + github.com/golang/protobuf v1.4.3 // indirect + github.com/google/go-cmp v0.5.5 // indirect + github.com/jinzhu/inflection v1.0.0 + github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect + github.com/onsi/ginkgo v1.14.2 + github.com/onsi/gomega v1.10.3 + github.com/stretchr/testify v1.7.0 + github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc + github.com/vmihailenco/bufpool v0.1.11 + github.com/vmihailenco/msgpack/v5 v5.3.1 + github.com/vmihailenco/tagparser v0.1.2 + golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b // indirect + golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect + google.golang.org/protobuf v1.25.0 // indirect + gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f + mellium.im/sasl v0.2.1 +) diff --git a/vendor/github.com/go-pg/pg/v10/go.sum b/vendor/github.com/go-pg/pg/v10/go.sum new file mode 100644 index 000000000..7d2d87c0b --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/go.sum @@ -0,0 +1,154 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/go-pg/zerochecker v0.2.0 h1:pp7f72c3DobMWOb2ErtZsnrPaSvHd2W4o9//8HtF4mU= +github.com/go-pg/zerochecker v0.2.0/go.mod h1:NJZ4wKL0NmTtz0GKCoJ8kym6Xn/EQzXRl2OnAe7MmDo= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.14.2 h1:8mVmC9kjFFmA8H4pKMUhcblgifdkOIXPvbhN1T36q1M= +github.com/onsi/ginkgo v1.14.2/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.10.3 h1:gph6h/qe9GSUw1NhH1gp+qb+h8rXD8Cy60Z32Qw3ELA= +github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= +github.com/vmihailenco/bufpool v0.1.11 h1:gOq2WmBrq0i2yW5QJ16ykccQ4wH9UyEsgLm6czKAd94= +github.com/vmihailenco/bufpool v0.1.11/go.mod h1:AFf/MOy3l2CFTKbxwt0mp2MwnqjNEs5H/UxrkA5jxTQ= +github.com/vmihailenco/msgpack/v5 v5.3.1 h1:0i85a4dsZh8mC//wmyyTEzidDLPQfQAxZIOLtafGbFY= +github.com/vmihailenco/msgpack/v5 v5.3.1/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/tagparser v0.1.2 h1:gnjoVuB/kljJ5wICEEOpx98oXMWPLj22G67Vbd1qPqc= +github.com/vmihailenco/tagparser v0.1.2/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b h1:7mWr3k41Qtv8XlltBkDkl8LoP3mpSgBW8BUoxtEdbXg= +golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 h1:iGu644GcxtEcrInvDsQRCwJjtCIOlT2V7IRt6ah2Whw= +golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +mellium.im/sasl v0.2.1 h1:nspKSRg7/SyO0cRGY71OkfHab8tf9kCts6a6oTDut0w= +mellium.im/sasl v0.2.1/go.mod h1:ROaEDLQNuf9vjKqE1SrAfnsobm2YKXT1gnN1uDp1PjQ= diff --git a/vendor/github.com/go-pg/pg/v10/hook.go b/vendor/github.com/go-pg/pg/v10/hook.go new file mode 100644 index 000000000..a95dc20bc --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/hook.go @@ -0,0 +1,139 @@ +package pg + +import ( + "context" + "fmt" + "time" + + "github.com/go-pg/pg/v10/orm" +) + +type ( + BeforeScanHook = orm.BeforeScanHook + AfterScanHook = orm.AfterScanHook + AfterSelectHook = orm.AfterSelectHook + BeforeInsertHook = orm.BeforeInsertHook + AfterInsertHook = orm.AfterInsertHook + BeforeUpdateHook = orm.BeforeUpdateHook + AfterUpdateHook = orm.AfterUpdateHook + BeforeDeleteHook = orm.BeforeDeleteHook + AfterDeleteHook = orm.AfterDeleteHook +) + +//------------------------------------------------------------------------------ + +type dummyFormatter struct{} + +func (dummyFormatter) FormatQuery(b []byte, query string, params ...interface{}) []byte { + return append(b, query...) +} + +// QueryEvent ... +type QueryEvent struct { + StartTime time.Time + DB orm.DB + Model interface{} + Query interface{} + Params []interface{} + fmtedQuery []byte + Result Result + Err error + + Stash map[interface{}]interface{} +} + +// QueryHook ... +type QueryHook interface { + BeforeQuery(context.Context, *QueryEvent) (context.Context, error) + AfterQuery(context.Context, *QueryEvent) error +} + +// UnformattedQuery returns the unformatted query of a query event. +// The query is only valid until the query Result is returned to the user. +func (e *QueryEvent) UnformattedQuery() ([]byte, error) { + return queryString(e.Query) +} + +func queryString(query interface{}) ([]byte, error) { + switch query := query.(type) { + case orm.TemplateAppender: + return query.AppendTemplate(nil) + case string: + return dummyFormatter{}.FormatQuery(nil, query), nil + default: + return nil, fmt.Errorf("pg: can't append %T", query) + } +} + +// FormattedQuery returns the formatted query of a query event. +// The query is only valid until the query Result is returned to the user. +func (e *QueryEvent) FormattedQuery() ([]byte, error) { + return e.fmtedQuery, nil +} + +// AddQueryHook adds a hook into query processing. +func (db *baseDB) AddQueryHook(hook QueryHook) { + db.queryHooks = append(db.queryHooks, hook) +} + +func (db *baseDB) beforeQuery( + ctx context.Context, + ormDB orm.DB, + model, query interface{}, + params []interface{}, + fmtedQuery []byte, +) (context.Context, *QueryEvent, error) { + if len(db.queryHooks) == 0 { + return ctx, nil, nil + } + + event := &QueryEvent{ + StartTime: time.Now(), + DB: ormDB, + Model: model, + Query: query, + Params: params, + fmtedQuery: fmtedQuery, + } + + for i, hook := range db.queryHooks { + var err error + ctx, err = hook.BeforeQuery(ctx, event) + if err != nil { + if err := db.afterQueryFromIndex(ctx, event, i); err != nil { + return ctx, nil, err + } + return ctx, nil, err + } + } + + return ctx, event, nil +} + +func (db *baseDB) afterQuery( + ctx context.Context, + event *QueryEvent, + res Result, + err error, +) error { + if event == nil { + return nil + } + + event.Err = err + event.Result = res + return db.afterQueryFromIndex(ctx, event, len(db.queryHooks)-1) +} + +func (db *baseDB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIndex int) error { + for ; hookIndex >= 0; hookIndex-- { + if err := db.queryHooks[hookIndex].AfterQuery(ctx, event); err != nil { + return err + } + } + return nil +} + +func copyQueryHooks(s []QueryHook) []QueryHook { + return s[:len(s):len(s)] +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/context.go b/vendor/github.com/go-pg/pg/v10/internal/context.go new file mode 100644 index 000000000..06d20c152 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/context.go @@ -0,0 +1,26 @@ +package internal + +import ( + "context" + "time" +) + +type UndoneContext struct { + context.Context +} + +func UndoContext(ctx context.Context) UndoneContext { + return UndoneContext{Context: ctx} +} + +func (UndoneContext) Deadline() (deadline time.Time, ok bool) { + return time.Time{}, false +} + +func (UndoneContext) Done() <-chan struct{} { + return nil +} + +func (UndoneContext) Err() error { + return nil +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/error.go b/vendor/github.com/go-pg/pg/v10/internal/error.go new file mode 100644 index 000000000..ae6524aeb --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/error.go @@ -0,0 +1,61 @@ +package internal + +import ( + "fmt" +) + +var ( + ErrNoRows = Errorf("pg: no rows in result set") + ErrMultiRows = Errorf("pg: multiple rows in result set") +) + +type Error struct { + s string +} + +func Errorf(s string, args ...interface{}) Error { + return Error{s: fmt.Sprintf(s, args...)} +} + +func (err Error) Error() string { + return err.s +} + +type PGError struct { + m map[byte]string +} + +func NewPGError(m map[byte]string) PGError { + return PGError{ + m: m, + } +} + +func (err PGError) Field(k byte) string { + return err.m[k] +} + +func (err PGError) IntegrityViolation() bool { + switch err.Field('C') { + case "23000", "23001", "23502", "23503", "23505", "23514", "23P01": + return true + default: + return false + } +} + +func (err PGError) Error() string { + return fmt.Sprintf("%s #%s %s", + err.Field('S'), err.Field('C'), err.Field('M')) +} + +func AssertOneRow(l int) error { + switch { + case l == 0: + return ErrNoRows + case l > 1: + return ErrMultiRows + default: + return nil + } +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/internal.go b/vendor/github.com/go-pg/pg/v10/internal/internal.go new file mode 100644 index 000000000..bda5028c6 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/internal.go @@ -0,0 +1,27 @@ +/* +internal is a private internal package. +*/ +package internal + +import ( + "math/rand" + "time" +) + +func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration { + if retry < 0 { + panic("not reached") + } + if minBackoff == 0 { + return 0 + } + + d := minBackoff << uint(retry) + d = minBackoff + time.Duration(rand.Int63n(int64(d))) + + if d > maxBackoff || d < minBackoff { + d = maxBackoff + } + + return d +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/log.go b/vendor/github.com/go-pg/pg/v10/internal/log.go new file mode 100644 index 000000000..7ea547b10 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/log.go @@ -0,0 +1,28 @@ +package internal + +import ( + "context" + "fmt" + "log" + "os" +) + +var Warn = log.New(os.Stderr, "WARN: pg: ", log.LstdFlags) + +var Deprecated = log.New(os.Stderr, "DEPRECATED: pg: ", log.LstdFlags) + +type Logging interface { + Printf(ctx context.Context, format string, v ...interface{}) +} + +type logger struct { + log *log.Logger +} + +func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) { + _ = l.log.Output(2, fmt.Sprintf(format, v...)) +} + +var Logger Logging = &logger{ + log: log.New(os.Stderr, "pg: ", log.LstdFlags|log.Lshortfile), +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/parser/parser.go b/vendor/github.com/go-pg/pg/v10/internal/parser/parser.go new file mode 100644 index 000000000..f2db676c9 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/parser/parser.go @@ -0,0 +1,141 @@ +package parser + +import ( + "bytes" + "strconv" + + "github.com/go-pg/pg/v10/internal" +) + +type Parser struct { + b []byte + i int +} + +func New(b []byte) *Parser { + return &Parser{ + b: b, + } +} + +func NewString(s string) *Parser { + return New(internal.StringToBytes(s)) +} + +func (p *Parser) Valid() bool { + return p.i < len(p.b) +} + +func (p *Parser) Bytes() []byte { + return p.b[p.i:] +} + +func (p *Parser) Read() byte { + if p.Valid() { + c := p.b[p.i] + p.Advance() + return c + } + return 0 +} + +func (p *Parser) Peek() byte { + if p.Valid() { + return p.b[p.i] + } + return 0 +} + +func (p *Parser) Advance() { + p.i++ +} + +func (p *Parser) Skip(skip byte) bool { + if p.Peek() == skip { + p.Advance() + return true + } + return false +} + +func (p *Parser) SkipBytes(skip []byte) bool { + if len(skip) > len(p.b[p.i:]) { + return false + } + if !bytes.Equal(p.b[p.i:p.i+len(skip)], skip) { + return false + } + p.i += len(skip) + return true +} + +func (p *Parser) ReadSep(sep byte) ([]byte, bool) { + ind := bytes.IndexByte(p.b[p.i:], sep) + if ind == -1 { + b := p.b[p.i:] + p.i = len(p.b) + return b, false + } + + b := p.b[p.i : p.i+ind] + p.i += ind + 1 + return b, true +} + +func (p *Parser) ReadIdentifier() (string, bool) { + if p.i < len(p.b) && p.b[p.i] == '(' { + s := p.i + 1 + if ind := bytes.IndexByte(p.b[s:], ')'); ind != -1 { + b := p.b[s : s+ind] + p.i = s + ind + 1 + return internal.BytesToString(b), false + } + } + + ind := len(p.b) - p.i + var alpha bool + for i, c := range p.b[p.i:] { + if isNum(c) { + continue + } + if isAlpha(c) || (i > 0 && alpha && c == '_') { + alpha = true + continue + } + ind = i + break + } + if ind == 0 { + return "", false + } + b := p.b[p.i : p.i+ind] + p.i += ind + return internal.BytesToString(b), !alpha +} + +func (p *Parser) ReadNumber() int { + ind := len(p.b) - p.i + for i, c := range p.b[p.i:] { + if !isNum(c) { + ind = i + break + } + } + if ind == 0 { + return 0 + } + n, err := strconv.Atoi(string(p.b[p.i : p.i+ind])) + if err != nil { + panic(err) + } + p.i += ind + return n +} + +func isNum(c byte) bool { + return c >= '0' && c <= '9' +} + +func isAlpha(c byte) bool { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/parser/streaming_parser.go b/vendor/github.com/go-pg/pg/v10/internal/parser/streaming_parser.go new file mode 100644 index 000000000..723c12b16 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/parser/streaming_parser.go @@ -0,0 +1,65 @@ +package parser + +import ( + "fmt" + + "github.com/go-pg/pg/v10/internal/pool" +) + +type StreamingParser struct { + pool.Reader +} + +func NewStreamingParser(rd pool.Reader) StreamingParser { + return StreamingParser{ + Reader: rd, + } +} + +func (p StreamingParser) SkipByte(skip byte) error { + c, err := p.ReadByte() + if err != nil { + return err + } + if c == skip { + return nil + } + _ = p.UnreadByte() + return fmt.Errorf("got %q, wanted %q", c, skip) +} + +func (p StreamingParser) ReadSubstring(b []byte) ([]byte, error) { + c, err := p.ReadByte() + if err != nil { + return b, err + } + + for { + if c == '"' { + return b, nil + } + + next, err := p.ReadByte() + if err != nil { + return b, err + } + + if c == '\\' { + switch next { + case '\\', '"': + b = append(b, next) + c, err = p.ReadByte() + if err != nil { + return nil, err + } + default: + b = append(b, '\\') + c = next + } + continue + } + + b = append(b, c) + c = next + } +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/pool/conn.go b/vendor/github.com/go-pg/pg/v10/internal/pool/conn.go new file mode 100644 index 000000000..91045245b --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/pool/conn.go @@ -0,0 +1,158 @@ +package pool + +import ( + "context" + "net" + "strconv" + "sync/atomic" + "time" +) + +var noDeadline = time.Time{} + +type Conn struct { + netConn net.Conn + rd *ReaderContext + + ProcessID int32 + SecretKey int32 + lastID int64 + + createdAt time.Time + usedAt uint32 // atomic + pooled bool + Inited bool +} + +func NewConn(netConn net.Conn) *Conn { + cn := &Conn{ + createdAt: time.Now(), + } + cn.SetNetConn(netConn) + cn.SetUsedAt(time.Now()) + return cn +} + +func (cn *Conn) UsedAt() time.Time { + unix := atomic.LoadUint32(&cn.usedAt) + return time.Unix(int64(unix), 0) +} + +func (cn *Conn) SetUsedAt(tm time.Time) { + atomic.StoreUint32(&cn.usedAt, uint32(tm.Unix())) +} + +func (cn *Conn) RemoteAddr() net.Addr { + return cn.netConn.RemoteAddr() +} + +func (cn *Conn) SetNetConn(netConn net.Conn) { + cn.netConn = netConn + if cn.rd != nil { + cn.rd.Reset(netConn) + } +} + +func (cn *Conn) LockReader() { + if cn.rd != nil { + panic("not reached") + } + cn.rd = NewReaderContext() + cn.rd.Reset(cn.netConn) +} + +func (cn *Conn) NetConn() net.Conn { + return cn.netConn +} + +func (cn *Conn) NextID() string { + cn.lastID++ + return strconv.FormatInt(cn.lastID, 10) +} + +func (cn *Conn) WithReader( + ctx context.Context, timeout time.Duration, fn func(rd *ReaderContext) error, +) error { + if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil { + return err + } + + rd := cn.rd + if rd == nil { + rd = GetReaderContext() + defer PutReaderContext(rd) + + rd.Reset(cn.netConn) + } + + rd.bytesRead = 0 + + if err := fn(rd); err != nil { + return err + } + + return nil +} + +func (cn *Conn) WithWriter( + ctx context.Context, timeout time.Duration, fn func(wb *WriteBuffer) error, +) error { + wb := GetWriteBuffer() + defer PutWriteBuffer(wb) + + if err := fn(wb); err != nil { + return err + } + + return cn.writeBuffer(ctx, timeout, wb) +} + +func (cn *Conn) WriteBuffer(ctx context.Context, timeout time.Duration, wb *WriteBuffer) error { + return cn.writeBuffer(ctx, timeout, wb) +} + +func (cn *Conn) writeBuffer( + ctx context.Context, + timeout time.Duration, + wb *WriteBuffer, +) error { + if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil { + return err + } + if _, err := cn.netConn.Write(wb.Bytes); err != nil { + return err + } + return nil +} + +func (cn *Conn) Close() error { + return cn.netConn.Close() +} + +func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { + tm := time.Now() + cn.SetUsedAt(tm) + + if timeout > 0 { + tm = tm.Add(timeout) + } + + if ctx != nil { + deadline, ok := ctx.Deadline() + if ok { + if timeout == 0 { + return deadline + } + if deadline.Before(tm) { + return deadline + } + return tm + } + } + + if timeout > 0 { + return tm + } + + return noDeadline +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/pool/pool.go b/vendor/github.com/go-pg/pg/v10/internal/pool/pool.go new file mode 100644 index 000000000..59f2c72d0 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/pool/pool.go @@ -0,0 +1,506 @@ +package pool + +import ( + "context" + "errors" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/go-pg/pg/v10/internal" +) + +var ( + ErrClosed = errors.New("pg: database is closed") + ErrPoolTimeout = errors.New("pg: connection pool timeout") +) + +var timers = sync.Pool{ + New: func() interface{} { + t := time.NewTimer(time.Hour) + t.Stop() + return t + }, +} + +// Stats contains pool state information and accumulated stats. +type Stats struct { + Hits uint32 // number of times free connection was found in the pool + Misses uint32 // number of times free connection was NOT found in the pool + Timeouts uint32 // number of times a wait timeout occurred + + TotalConns uint32 // number of total connections in the pool + IdleConns uint32 // number of idle connections in the pool + StaleConns uint32 // number of stale connections removed from the pool +} + +type Pooler interface { + NewConn(context.Context) (*Conn, error) + CloseConn(*Conn) error + + Get(context.Context) (*Conn, error) + Put(context.Context, *Conn) + Remove(context.Context, *Conn, error) + + Len() int + IdleLen() int + Stats() *Stats + + Close() error +} + +type Options struct { + Dialer func(context.Context) (net.Conn, error) + OnClose func(*Conn) error + + PoolSize int + MinIdleConns int + MaxConnAge time.Duration + PoolTimeout time.Duration + IdleTimeout time.Duration + IdleCheckFrequency time.Duration +} + +type ConnPool struct { + opt *Options + + dialErrorsNum uint32 // atomic + + _closed uint32 // atomic + + lastDialErrorMu sync.RWMutex + lastDialError error + + queue chan struct{} + + stats Stats + + connsMu sync.Mutex + conns []*Conn + idleConns []*Conn + + poolSize int + idleConnsLen int +} + +var _ Pooler = (*ConnPool)(nil) + +func NewConnPool(opt *Options) *ConnPool { + p := &ConnPool{ + opt: opt, + + queue: make(chan struct{}, opt.PoolSize), + conns: make([]*Conn, 0, opt.PoolSize), + idleConns: make([]*Conn, 0, opt.PoolSize), + } + + p.connsMu.Lock() + p.checkMinIdleConns() + p.connsMu.Unlock() + + if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 { + go p.reaper(opt.IdleCheckFrequency) + } + + return p +} + +func (p *ConnPool) checkMinIdleConns() { + if p.opt.MinIdleConns == 0 { + return + } + for p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns { + p.poolSize++ + p.idleConnsLen++ + go func() { + err := p.addIdleConn() + if err != nil { + p.connsMu.Lock() + p.poolSize-- + p.idleConnsLen-- + p.connsMu.Unlock() + } + }() + } +} + +func (p *ConnPool) addIdleConn() error { + cn, err := p.dialConn(context.TODO(), true) + if err != nil { + return err + } + + p.connsMu.Lock() + p.conns = append(p.conns, cn) + p.idleConns = append(p.idleConns, cn) + p.connsMu.Unlock() + return nil +} + +func (p *ConnPool) NewConn(c context.Context) (*Conn, error) { + return p.newConn(c, false) +} + +func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) { + cn, err := p.dialConn(c, pooled) + if err != nil { + return nil, err + } + + p.connsMu.Lock() + + p.conns = append(p.conns, cn) + if pooled { + // If pool is full remove the cn on next Put. + if p.poolSize >= p.opt.PoolSize { + cn.pooled = false + } else { + p.poolSize++ + } + } + + p.connsMu.Unlock() + return cn, nil +} + +func (p *ConnPool) dialConn(c context.Context, pooled bool) (*Conn, error) { + if p.closed() { + return nil, ErrClosed + } + + if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) { + return nil, p.getLastDialError() + } + + netConn, err := p.opt.Dialer(c) + if err != nil { + p.setLastDialError(err) + if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) { + go p.tryDial() + } + return nil, err + } + + cn := NewConn(netConn) + cn.pooled = pooled + return cn, nil +} + +func (p *ConnPool) tryDial() { + for { + if p.closed() { + return + } + + conn, err := p.opt.Dialer(context.TODO()) + if err != nil { + p.setLastDialError(err) + time.Sleep(time.Second) + continue + } + + atomic.StoreUint32(&p.dialErrorsNum, 0) + _ = conn.Close() + return + } +} + +func (p *ConnPool) setLastDialError(err error) { + p.lastDialErrorMu.Lock() + p.lastDialError = err + p.lastDialErrorMu.Unlock() +} + +func (p *ConnPool) getLastDialError() error { + p.lastDialErrorMu.RLock() + err := p.lastDialError + p.lastDialErrorMu.RUnlock() + return err +} + +// Get returns existed connection from the pool or creates a new one. +func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { + if p.closed() { + return nil, ErrClosed + } + + err := p.waitTurn(ctx) + if err != nil { + return nil, err + } + + for { + p.connsMu.Lock() + cn := p.popIdle() + p.connsMu.Unlock() + + if cn == nil { + break + } + + if p.isStaleConn(cn) { + _ = p.CloseConn(cn) + continue + } + + atomic.AddUint32(&p.stats.Hits, 1) + return cn, nil + } + + atomic.AddUint32(&p.stats.Misses, 1) + + newcn, err := p.newConn(ctx, true) + if err != nil { + p.freeTurn() + return nil, err + } + + return newcn, nil +} + +func (p *ConnPool) getTurn() { + p.queue <- struct{}{} +} + +func (p *ConnPool) waitTurn(c context.Context) error { + select { + case <-c.Done(): + return c.Err() + default: + } + + select { + case p.queue <- struct{}{}: + return nil + default: + } + + timer := timers.Get().(*time.Timer) + timer.Reset(p.opt.PoolTimeout) + + select { + case <-c.Done(): + if !timer.Stop() { + <-timer.C + } + timers.Put(timer) + return c.Err() + case p.queue <- struct{}{}: + if !timer.Stop() { + <-timer.C + } + timers.Put(timer) + return nil + case <-timer.C: + timers.Put(timer) + atomic.AddUint32(&p.stats.Timeouts, 1) + return ErrPoolTimeout + } +} + +func (p *ConnPool) freeTurn() { + <-p.queue +} + +func (p *ConnPool) popIdle() *Conn { + if len(p.idleConns) == 0 { + return nil + } + + idx := len(p.idleConns) - 1 + cn := p.idleConns[idx] + p.idleConns = p.idleConns[:idx] + p.idleConnsLen-- + p.checkMinIdleConns() + return cn +} + +func (p *ConnPool) Put(ctx context.Context, cn *Conn) { + if !cn.pooled { + p.Remove(ctx, cn, nil) + return + } + + p.connsMu.Lock() + p.idleConns = append(p.idleConns, cn) + p.idleConnsLen++ + p.connsMu.Unlock() + p.freeTurn() +} + +func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) { + p.removeConnWithLock(cn) + p.freeTurn() + _ = p.closeConn(cn) +} + +func (p *ConnPool) CloseConn(cn *Conn) error { + p.removeConnWithLock(cn) + return p.closeConn(cn) +} + +func (p *ConnPool) removeConnWithLock(cn *Conn) { + p.connsMu.Lock() + p.removeConn(cn) + p.connsMu.Unlock() +} + +func (p *ConnPool) removeConn(cn *Conn) { + for i, c := range p.conns { + if c == cn { + p.conns = append(p.conns[:i], p.conns[i+1:]...) + if cn.pooled { + p.poolSize-- + p.checkMinIdleConns() + } + return + } + } +} + +func (p *ConnPool) closeConn(cn *Conn) error { + if p.opt.OnClose != nil { + _ = p.opt.OnClose(cn) + } + return cn.Close() +} + +// Len returns total number of connections. +func (p *ConnPool) Len() int { + p.connsMu.Lock() + n := len(p.conns) + p.connsMu.Unlock() + return n +} + +// IdleLen returns number of idle connections. +func (p *ConnPool) IdleLen() int { + p.connsMu.Lock() + n := p.idleConnsLen + p.connsMu.Unlock() + return n +} + +func (p *ConnPool) Stats() *Stats { + idleLen := p.IdleLen() + return &Stats{ + Hits: atomic.LoadUint32(&p.stats.Hits), + Misses: atomic.LoadUint32(&p.stats.Misses), + Timeouts: atomic.LoadUint32(&p.stats.Timeouts), + + TotalConns: uint32(p.Len()), + IdleConns: uint32(idleLen), + StaleConns: atomic.LoadUint32(&p.stats.StaleConns), + } +} + +func (p *ConnPool) closed() bool { + return atomic.LoadUint32(&p._closed) == 1 +} + +func (p *ConnPool) Filter(fn func(*Conn) bool) error { + var firstErr error + p.connsMu.Lock() + for _, cn := range p.conns { + if fn(cn) { + if err := p.closeConn(cn); err != nil && firstErr == nil { + firstErr = err + } + } + } + p.connsMu.Unlock() + return firstErr +} + +func (p *ConnPool) Close() error { + if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) { + return ErrClosed + } + + var firstErr error + p.connsMu.Lock() + for _, cn := range p.conns { + if err := p.closeConn(cn); err != nil && firstErr == nil { + firstErr = err + } + } + p.conns = nil + p.poolSize = 0 + p.idleConns = nil + p.idleConnsLen = 0 + p.connsMu.Unlock() + + return firstErr +} + +func (p *ConnPool) reaper(frequency time.Duration) { + ticker := time.NewTicker(frequency) + defer ticker.Stop() + + for range ticker.C { + if p.closed() { + break + } + n, err := p.ReapStaleConns() + if err != nil { + internal.Logger.Printf(context.TODO(), "ReapStaleConns failed: %s", err) + continue + } + atomic.AddUint32(&p.stats.StaleConns, uint32(n)) + } +} + +func (p *ConnPool) ReapStaleConns() (int, error) { + var n int + for { + p.getTurn() + + p.connsMu.Lock() + cn := p.reapStaleConn() + p.connsMu.Unlock() + + p.freeTurn() + + if cn != nil { + _ = p.closeConn(cn) + n++ + } else { + break + } + } + return n, nil +} + +func (p *ConnPool) reapStaleConn() *Conn { + if len(p.idleConns) == 0 { + return nil + } + + cn := p.idleConns[0] + if !p.isStaleConn(cn) { + return nil + } + + p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...) + p.idleConnsLen-- + p.removeConn(cn) + + return cn +} + +func (p *ConnPool) isStaleConn(cn *Conn) bool { + if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 { + return false + } + + now := time.Now() + if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout { + return true + } + if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge { + return true + } + + return false +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/pool/pool_single.go b/vendor/github.com/go-pg/pg/v10/internal/pool/pool_single.go new file mode 100644 index 000000000..5a3fde191 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/pool/pool_single.go @@ -0,0 +1,58 @@ +package pool + +import "context" + +type SingleConnPool struct { + pool Pooler + cn *Conn + stickyErr error +} + +var _ Pooler = (*SingleConnPool)(nil) + +func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool { + return &SingleConnPool{ + pool: pool, + cn: cn, + } +} + +func (p *SingleConnPool) NewConn(ctx context.Context) (*Conn, error) { + return p.pool.NewConn(ctx) +} + +func (p *SingleConnPool) CloseConn(cn *Conn) error { + return p.pool.CloseConn(cn) +} + +func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) { + if p.stickyErr != nil { + return nil, p.stickyErr + } + return p.cn, nil +} + +func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {} + +func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) { + p.cn = nil + p.stickyErr = reason +} + +func (p *SingleConnPool) Close() error { + p.cn = nil + p.stickyErr = ErrClosed + return nil +} + +func (p *SingleConnPool) Len() int { + return 0 +} + +func (p *SingleConnPool) IdleLen() int { + return 0 +} + +func (p *SingleConnPool) Stats() *Stats { + return &Stats{} +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/pool/pool_sticky.go b/vendor/github.com/go-pg/pg/v10/internal/pool/pool_sticky.go new file mode 100644 index 000000000..0415b5e87 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/pool/pool_sticky.go @@ -0,0 +1,202 @@ +package pool + +import ( + "context" + "errors" + "fmt" + "sync/atomic" +) + +const ( + stateDefault = 0 + stateInited = 1 + stateClosed = 2 +) + +type BadConnError struct { + wrapped error +} + +var _ error = (*BadConnError)(nil) + +func (e BadConnError) Error() string { + s := "pg: Conn is in a bad state" + if e.wrapped != nil { + s += ": " + e.wrapped.Error() + } + return s +} + +func (e BadConnError) Unwrap() error { + return e.wrapped +} + +//------------------------------------------------------------------------------ + +type StickyConnPool struct { + pool Pooler + shared int32 // atomic + + state uint32 // atomic + ch chan *Conn + + _badConnError atomic.Value +} + +var _ Pooler = (*StickyConnPool)(nil) + +func NewStickyConnPool(pool Pooler) *StickyConnPool { + p, ok := pool.(*StickyConnPool) + if !ok { + p = &StickyConnPool{ + pool: pool, + ch: make(chan *Conn, 1), + } + } + atomic.AddInt32(&p.shared, 1) + return p +} + +func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) { + return p.pool.NewConn(ctx) +} + +func (p *StickyConnPool) CloseConn(cn *Conn) error { + return p.pool.CloseConn(cn) +} + +func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) { + // In worst case this races with Close which is not a very common operation. + for i := 0; i < 1000; i++ { + switch atomic.LoadUint32(&p.state) { + case stateDefault: + cn, err := p.pool.Get(ctx) + if err != nil { + return nil, err + } + if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { + return cn, nil + } + p.pool.Remove(ctx, cn, ErrClosed) + case stateInited: + if err := p.badConnError(); err != nil { + return nil, err + } + cn, ok := <-p.ch + if !ok { + return nil, ErrClosed + } + return cn, nil + case stateClosed: + return nil, ErrClosed + default: + panic("not reached") + } + } + return nil, fmt.Errorf("pg: StickyConnPool.Get: infinite loop") +} + +func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) { + defer func() { + if recover() != nil { + p.freeConn(ctx, cn) + } + }() + p.ch <- cn +} + +func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) { + if err := p.badConnError(); err != nil { + p.pool.Remove(ctx, cn, err) + } else { + p.pool.Put(ctx, cn) + } +} + +func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) { + defer func() { + if recover() != nil { + p.pool.Remove(ctx, cn, ErrClosed) + } + }() + p._badConnError.Store(BadConnError{wrapped: reason}) + p.ch <- cn +} + +func (p *StickyConnPool) Close() error { + if shared := atomic.AddInt32(&p.shared, -1); shared > 0 { + return nil + } + + for i := 0; i < 1000; i++ { + state := atomic.LoadUint32(&p.state) + if state == stateClosed { + return ErrClosed + } + if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) { + close(p.ch) + cn, ok := <-p.ch + if ok { + p.freeConn(context.TODO(), cn) + } + return nil + } + } + + return errors.New("pg: StickyConnPool.Close: infinite loop") +} + +func (p *StickyConnPool) Reset(ctx context.Context) error { + if p.badConnError() == nil { + return nil + } + + select { + case cn, ok := <-p.ch: + if !ok { + return ErrClosed + } + p.pool.Remove(ctx, cn, ErrClosed) + p._badConnError.Store(BadConnError{wrapped: nil}) + default: + return errors.New("pg: StickyConnPool does not have a Conn") + } + + if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) { + state := atomic.LoadUint32(&p.state) + return fmt.Errorf("pg: invalid StickyConnPool state: %d", state) + } + + return nil +} + +func (p *StickyConnPool) badConnError() error { + if v := p._badConnError.Load(); v != nil { + err := v.(BadConnError) + if err.wrapped != nil { + return err + } + } + return nil +} + +func (p *StickyConnPool) Len() int { + switch atomic.LoadUint32(&p.state) { + case stateDefault: + return 0 + case stateInited: + return 1 + case stateClosed: + return 0 + default: + panic("not reached") + } +} + +func (p *StickyConnPool) IdleLen() int { + return len(p.ch) +} + +func (p *StickyConnPool) Stats() *Stats { + return &Stats{} +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/pool/reader.go b/vendor/github.com/go-pg/pg/v10/internal/pool/reader.go new file mode 100644 index 000000000..b5d00807d --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/pool/reader.go @@ -0,0 +1,80 @@ +package pool + +import ( + "sync" +) + +type Reader interface { + Buffered() int + + Bytes() []byte + Read([]byte) (int, error) + ReadByte() (byte, error) + UnreadByte() error + ReadSlice(byte) ([]byte, error) + Discard(int) (int, error) + + // ReadBytes(fn func(byte) bool) ([]byte, error) + // ReadN(int) ([]byte, error) + ReadFull() ([]byte, error) + ReadFullTemp() ([]byte, error) +} + +type ColumnInfo struct { + Index int16 + DataType int32 + Name string +} + +type ColumnAlloc struct { + columns []ColumnInfo +} + +func NewColumnAlloc() *ColumnAlloc { + return new(ColumnAlloc) +} + +func (c *ColumnAlloc) Reset() { + c.columns = c.columns[:0] +} + +func (c *ColumnAlloc) New(index int16, name []byte) *ColumnInfo { + c.columns = append(c.columns, ColumnInfo{ + Index: index, + Name: string(name), + }) + return &c.columns[len(c.columns)-1] +} + +func (c *ColumnAlloc) Columns() []ColumnInfo { + return c.columns +} + +type ReaderContext struct { + *BufReader + ColumnAlloc *ColumnAlloc +} + +func NewReaderContext() *ReaderContext { + const bufSize = 1 << 20 // 1mb + return &ReaderContext{ + BufReader: NewBufReader(bufSize), + ColumnAlloc: NewColumnAlloc(), + } +} + +var readerPool = sync.Pool{ + New: func() interface{} { + return NewReaderContext() + }, +} + +func GetReaderContext() *ReaderContext { + rd := readerPool.Get().(*ReaderContext) + return rd +} + +func PutReaderContext(rd *ReaderContext) { + rd.ColumnAlloc.Reset() + readerPool.Put(rd) +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/pool/reader_buf.go b/vendor/github.com/go-pg/pg/v10/internal/pool/reader_buf.go new file mode 100644 index 000000000..3172e8b05 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/pool/reader_buf.go @@ -0,0 +1,431 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pool + +import ( + "bufio" + "bytes" + "io" +) + +type BufReader struct { + rd io.Reader // reader provided by the client + + buf []byte + r, w int // buf read and write positions + lastByte int + bytesRead int64 + err error + + available int // bytes available for reading + brd BytesReader // reusable bytes reader +} + +func NewBufReader(bufSize int) *BufReader { + return &BufReader{ + buf: make([]byte, bufSize), + available: -1, + } +} + +func (b *BufReader) BytesReader(n int) *BytesReader { + if n == -1 { + n = 0 + } + buf := b.buf[b.r : b.r+n] + b.r += n + b.brd.Reset(buf) + return &b.brd +} + +func (b *BufReader) SetAvailable(n int) { + b.available = n +} + +func (b *BufReader) Available() int { + return b.available +} + +func (b *BufReader) changeAvailable(n int) { + if b.available != -1 { + b.available += n + } +} + +func (b *BufReader) Reset(rd io.Reader) { + b.rd = rd + b.r, b.w = 0, 0 + b.err = nil +} + +// Buffered returns the number of bytes that can be read from the current buffer. +func (b *BufReader) Buffered() int { + buffered := b.w - b.r + if b.available == -1 || buffered <= b.available { + return buffered + } + return b.available +} + +func (b *BufReader) Bytes() []byte { + if b.available == -1 { + return b.buf[b.r:b.w] + } + w := b.r + b.available + if w > b.w { + w = b.w + } + return b.buf[b.r:w] +} + +func (b *BufReader) flush() []byte { + if b.available == -1 { + buf := b.buf[b.r:b.w] + b.r = b.w + return buf + } + + w := b.r + b.available + if w > b.w { + w = b.w + } + buf := b.buf[b.r:w] + b.r = w + b.changeAvailable(-len(buf)) + return buf +} + +// fill reads a new chunk into the buffer. +func (b *BufReader) fill() { + // Slide existing data to beginning. + if b.r > 0 { + copy(b.buf, b.buf[b.r:b.w]) + b.w -= b.r + b.r = 0 + } + + if b.w >= len(b.buf) { + panic("bufio: tried to fill full buffer") + } + if b.available == 0 { + b.err = io.EOF + return + } + + // Read new data: try a limited number of times. + const maxConsecutiveEmptyReads = 100 + for i := maxConsecutiveEmptyReads; i > 0; i-- { + n, err := b.read(b.buf[b.w:]) + b.w += n + if err != nil { + b.err = err + return + } + if n > 0 { + return + } + } + b.err = io.ErrNoProgress +} + +func (b *BufReader) readErr() error { + err := b.err + b.err = nil + return err +} + +func (b *BufReader) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, b.readErr() + } + + if b.available != -1 { + if b.available == 0 { + return 0, io.EOF + } + if len(p) > b.available { + p = p[:b.available] + } + } + + if b.r == b.w { + if b.err != nil { + return 0, b.readErr() + } + + if len(p) >= len(b.buf) { + // Large read, empty buffer. + // Read directly into p to avoid copy. + n, err = b.read(p) + if n > 0 { + b.changeAvailable(-n) + b.lastByte = int(p[n-1]) + } + return n, err + } + + // One read. + // Do not use b.fill, which will loop. + b.r = 0 + b.w = 0 + n, b.err = b.read(b.buf) + if n == 0 { + return 0, b.readErr() + } + b.w += n + } + + // copy as much as we can + n = copy(p, b.Bytes()) + b.r += n + b.changeAvailable(-n) + b.lastByte = int(b.buf[b.r-1]) + return n, nil +} + +// ReadSlice reads until the first occurrence of delim in the input, +// returning a slice pointing at the bytes in the buffer. +// The bytes stop being valid at the next read. +// If ReadSlice encounters an error before finding a delimiter, +// it returns all the data in the buffer and the error itself (often io.EOF). +// ReadSlice fails with error ErrBufferFull if the buffer fills without a delim. +// Because the data returned from ReadSlice will be overwritten +// by the next I/O operation, most clients should use +// ReadBytes or ReadString instead. +// ReadSlice returns err != nil if and only if line does not end in delim. +func (b *BufReader) ReadSlice(delim byte) (line []byte, err error) { + for { + // Search buffer. + if i := bytes.IndexByte(b.Bytes(), delim); i >= 0 { + i++ + line = b.buf[b.r : b.r+i] + b.r += i + b.changeAvailable(-i) + break + } + + // Pending error? + if b.err != nil { + line = b.flush() + err = b.readErr() + break + } + + buffered := b.Buffered() + + // Out of available. + if b.available != -1 && buffered >= b.available { + line = b.flush() + err = io.EOF + break + } + + // Buffer full? + if buffered >= len(b.buf) { + line = b.flush() + err = bufio.ErrBufferFull + break + } + + b.fill() // buffer is not full + } + + // Handle last byte, if any. + if i := len(line) - 1; i >= 0 { + b.lastByte = int(line[i]) + } + + return line, err +} + +func (b *BufReader) ReadBytes(fn func(byte) bool) (line []byte, err error) { + for { + for i, c := range b.Bytes() { + if !fn(c) { + i-- + line = b.buf[b.r : b.r+i] //nolint + b.r += i + b.changeAvailable(-i) + break + } + } + + // Pending error? + if b.err != nil { + line = b.flush() + err = b.readErr() + break + } + + buffered := b.Buffered() + + // Out of available. + if b.available != -1 && buffered >= b.available { + line = b.flush() + err = io.EOF + break + } + + // Buffer full? + if buffered >= len(b.buf) { + line = b.flush() + err = bufio.ErrBufferFull + break + } + + b.fill() // buffer is not full + } + + // Handle last byte, if any. + if i := len(line) - 1; i >= 0 { + b.lastByte = int(line[i]) + } + + return line, err +} + +func (b *BufReader) ReadByte() (byte, error) { + if b.available == 0 { + return 0, io.EOF + } + for b.r == b.w { + if b.err != nil { + return 0, b.readErr() + } + b.fill() // buffer is empty + } + c := b.buf[b.r] + b.r++ + b.lastByte = int(c) + b.changeAvailable(-1) + return c, nil +} + +func (b *BufReader) UnreadByte() error { + if b.lastByte < 0 || b.r == 0 && b.w > 0 { + return bufio.ErrInvalidUnreadByte + } + // b.r > 0 || b.w == 0 + if b.r > 0 { + b.r-- + } else { + // b.r == 0 && b.w == 0 + b.w = 1 + } + b.buf[b.r] = byte(b.lastByte) + b.lastByte = -1 + b.changeAvailable(+1) + return nil +} + +// Discard skips the next n bytes, returning the number of bytes discarded. +// +// If Discard skips fewer than n bytes, it also returns an error. +// If 0 <= n <= b.Buffered(), Discard is guaranteed to succeed without +// reading from the underlying io.BufReader. +func (b *BufReader) Discard(n int) (discarded int, err error) { + if n < 0 { + return 0, bufio.ErrNegativeCount + } + if n == 0 { + return + } + remain := n + for { + skip := b.Buffered() + if skip == 0 { + b.fill() + skip = b.Buffered() + } + if skip > remain { + skip = remain + } + b.r += skip + b.changeAvailable(-skip) + remain -= skip + if remain == 0 { + return n, nil + } + if b.err != nil { + return n - remain, b.readErr() + } + } +} + +func (b *BufReader) ReadN(n int) (line []byte, err error) { + if n < 0 { + return nil, bufio.ErrNegativeCount + } + if n == 0 { + return + } + + nn := n + if b.available != -1 && nn > b.available { + nn = b.available + } + + for { + buffered := b.Buffered() + + if buffered >= nn { + line = b.buf[b.r : b.r+nn] + b.r += nn + b.changeAvailable(-nn) + if n > nn { + err = io.EOF + } + break + } + + // Pending error? + if b.err != nil { + line = b.flush() + err = b.readErr() + break + } + + // Buffer full? + if buffered >= len(b.buf) { + line = b.flush() + err = bufio.ErrBufferFull + break + } + + b.fill() // buffer is not full + } + + // Handle last byte, if any. + if i := len(line) - 1; i >= 0 { + b.lastByte = int(line[i]) + } + + return line, err +} + +func (b *BufReader) ReadFull() ([]byte, error) { + if b.available == -1 { + panic("not reached") + } + buf := make([]byte, b.available) + _, err := io.ReadFull(b, buf) + return buf, err +} + +func (b *BufReader) ReadFullTemp() ([]byte, error) { + if b.available == -1 { + panic("not reached") + } + if b.available <= len(b.buf) { + return b.ReadN(b.available) + } + return b.ReadFull() +} + +func (b *BufReader) read(buf []byte) (int, error) { + n, err := b.rd.Read(buf) + b.bytesRead += int64(n) + return n, err +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/pool/reader_bytes.go b/vendor/github.com/go-pg/pg/v10/internal/pool/reader_bytes.go new file mode 100644 index 000000000..93646b1da --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/pool/reader_bytes.go @@ -0,0 +1,121 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pool + +import ( + "bytes" + "errors" + "io" +) + +type BytesReader struct { + s []byte + i int +} + +func NewBytesReader(b []byte) *BytesReader { + return &BytesReader{ + s: b, + } +} + +func (r *BytesReader) Reset(b []byte) { + r.s = b + r.i = 0 +} + +func (r *BytesReader) Buffered() int { + return len(r.s) - r.i +} + +func (r *BytesReader) Bytes() []byte { + return r.s[r.i:] +} + +func (r *BytesReader) Read(b []byte) (n int, err error) { + if r.i >= len(r.s) { + return 0, io.EOF + } + n = copy(b, r.s[r.i:]) + r.i += n + return +} + +func (r *BytesReader) ReadByte() (byte, error) { + if r.i >= len(r.s) { + return 0, io.EOF + } + b := r.s[r.i] + r.i++ + return b, nil +} + +func (r *BytesReader) UnreadByte() error { + if r.i <= 0 { + return errors.New("UnreadByte: at beginning of slice") + } + r.i-- + return nil +} + +func (r *BytesReader) ReadSlice(delim byte) ([]byte, error) { + if i := bytes.IndexByte(r.s[r.i:], delim); i >= 0 { + i++ + line := r.s[r.i : r.i+i] + r.i += i + return line, nil + } + + line := r.s[r.i:] + r.i = len(r.s) + return line, io.EOF +} + +func (r *BytesReader) ReadBytes(fn func(byte) bool) ([]byte, error) { + for i, c := range r.s[r.i:] { + if !fn(c) { + i++ + line := r.s[r.i : r.i+i] + r.i += i + return line, nil + } + } + + line := r.s[r.i:] + r.i = len(r.s) + return line, io.EOF +} + +func (r *BytesReader) Discard(n int) (int, error) { + b, err := r.ReadN(n) + return len(b), err +} + +func (r *BytesReader) ReadN(n int) ([]byte, error) { + nn := n + if nn > len(r.s) { + nn = len(r.s) + } + + b := r.s[r.i : r.i+nn] + r.i += nn + if n > nn { + return b, io.EOF + } + return b, nil +} + +func (r *BytesReader) ReadFull() ([]byte, error) { + b := make([]byte, len(r.s)-r.i) + copy(b, r.s[r.i:]) + r.i = len(r.s) + return b, nil +} + +func (r *BytesReader) ReadFullTemp() ([]byte, error) { + b := r.s[r.i:] + r.i = len(r.s) + return b, nil +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/pool/write_buffer.go b/vendor/github.com/go-pg/pg/v10/internal/pool/write_buffer.go new file mode 100644 index 000000000..6981d3f4c --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/pool/write_buffer.go @@ -0,0 +1,114 @@ +package pool + +import ( + "encoding/binary" + "io" + "sync" +) + +const defaultBufSize = 65 << 10 // 65kb + +var wbPool = sync.Pool{ + New: func() interface{} { + return NewWriteBuffer() + }, +} + +func GetWriteBuffer() *WriteBuffer { + wb := wbPool.Get().(*WriteBuffer) + return wb +} + +func PutWriteBuffer(wb *WriteBuffer) { + wb.Reset() + wbPool.Put(wb) +} + +type WriteBuffer struct { + Bytes []byte + + msgStart int + paramStart int +} + +func NewWriteBuffer() *WriteBuffer { + return &WriteBuffer{ + Bytes: make([]byte, 0, defaultBufSize), + } +} + +func (buf *WriteBuffer) Reset() { + buf.Bytes = buf.Bytes[:0] +} + +func (buf *WriteBuffer) StartMessage(c byte) { + if c == 0 { + buf.msgStart = len(buf.Bytes) + buf.Bytes = append(buf.Bytes, 0, 0, 0, 0) + } else { + buf.msgStart = len(buf.Bytes) + 1 + buf.Bytes = append(buf.Bytes, c, 0, 0, 0, 0) + } +} + +func (buf *WriteBuffer) FinishMessage() { + binary.BigEndian.PutUint32( + buf.Bytes[buf.msgStart:], uint32(len(buf.Bytes)-buf.msgStart)) +} + +func (buf *WriteBuffer) Query() []byte { + return buf.Bytes[buf.msgStart+4 : len(buf.Bytes)-1] +} + +func (buf *WriteBuffer) StartParam() { + buf.paramStart = len(buf.Bytes) + buf.Bytes = append(buf.Bytes, 0, 0, 0, 0) +} + +func (buf *WriteBuffer) FinishParam() { + binary.BigEndian.PutUint32( + buf.Bytes[buf.paramStart:], uint32(len(buf.Bytes)-buf.paramStart-4)) +} + +var nullParamLength = int32(-1) + +func (buf *WriteBuffer) FinishNullParam() { + binary.BigEndian.PutUint32( + buf.Bytes[buf.paramStart:], uint32(nullParamLength)) +} + +func (buf *WriteBuffer) Write(b []byte) (int, error) { + buf.Bytes = append(buf.Bytes, b...) + return len(b), nil +} + +func (buf *WriteBuffer) WriteInt16(num int16) { + buf.Bytes = append(buf.Bytes, 0, 0) + binary.BigEndian.PutUint16(buf.Bytes[len(buf.Bytes)-2:], uint16(num)) +} + +func (buf *WriteBuffer) WriteInt32(num int32) { + buf.Bytes = append(buf.Bytes, 0, 0, 0, 0) + binary.BigEndian.PutUint32(buf.Bytes[len(buf.Bytes)-4:], uint32(num)) +} + +func (buf *WriteBuffer) WriteString(s string) { + buf.Bytes = append(buf.Bytes, s...) + buf.Bytes = append(buf.Bytes, 0) +} + +func (buf *WriteBuffer) WriteBytes(b []byte) { + buf.Bytes = append(buf.Bytes, b...) + buf.Bytes = append(buf.Bytes, 0) +} + +func (buf *WriteBuffer) WriteByte(c byte) error { + buf.Bytes = append(buf.Bytes, c) + return nil +} + +func (buf *WriteBuffer) ReadFrom(r io.Reader) (int64, error) { + n, err := r.Read(buf.Bytes[len(buf.Bytes):cap(buf.Bytes)]) + buf.Bytes = buf.Bytes[:len(buf.Bytes)+n] + return int64(n), err +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/safe.go b/vendor/github.com/go-pg/pg/v10/internal/safe.go new file mode 100644 index 000000000..870fe541f --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/safe.go @@ -0,0 +1,11 @@ +// +build appengine + +package internal + +func BytesToString(b []byte) string { + return string(b) +} + +func StringToBytes(s string) []byte { + return []byte(s) +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/strconv.go b/vendor/github.com/go-pg/pg/v10/internal/strconv.go new file mode 100644 index 000000000..9e42ffb03 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/strconv.go @@ -0,0 +1,19 @@ +package internal + +import "strconv" + +func Atoi(b []byte) (int, error) { + return strconv.Atoi(BytesToString(b)) +} + +func ParseInt(b []byte, base int, bitSize int) (int64, error) { + return strconv.ParseInt(BytesToString(b), base, bitSize) +} + +func ParseUint(b []byte, base int, bitSize int) (uint64, error) { + return strconv.ParseUint(BytesToString(b), base, bitSize) +} + +func ParseFloat(b []byte, bitSize int) (float64, error) { + return strconv.ParseFloat(BytesToString(b), bitSize) +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/underscore.go b/vendor/github.com/go-pg/pg/v10/internal/underscore.go new file mode 100644 index 000000000..e71c11705 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/underscore.go @@ -0,0 +1,93 @@ +package internal + +func IsUpper(c byte) bool { + return c >= 'A' && c <= 'Z' +} + +func IsLower(c byte) bool { + return c >= 'a' && c <= 'z' +} + +func ToUpper(c byte) byte { + return c - 32 +} + +func ToLower(c byte) byte { + return c + 32 +} + +// Underscore converts "CamelCasedString" to "camel_cased_string". +func Underscore(s string) string { + r := make([]byte, 0, len(s)+5) + for i := 0; i < len(s); i++ { + c := s[i] + if IsUpper(c) { + if i > 0 && i+1 < len(s) && (IsLower(s[i-1]) || IsLower(s[i+1])) { + r = append(r, '_', ToLower(c)) + } else { + r = append(r, ToLower(c)) + } + } else { + r = append(r, c) + } + } + return string(r) +} + +func CamelCased(s string) string { + r := make([]byte, 0, len(s)) + upperNext := true + for i := 0; i < len(s); i++ { + c := s[i] + if c == '_' { + upperNext = true + continue + } + if upperNext { + if IsLower(c) { + c = ToUpper(c) + } + upperNext = false + } + r = append(r, c) + } + return string(r) +} + +func ToExported(s string) string { + if len(s) == 0 { + return s + } + if c := s[0]; IsLower(c) { + b := []byte(s) + b[0] = ToUpper(c) + return string(b) + } + return s +} + +func UpperString(s string) string { + if isUpperString(s) { + return s + } + + b := make([]byte, len(s)) + for i := range b { + c := s[i] + if IsLower(c) { + c = ToUpper(c) + } + b[i] = c + } + return string(b) +} + +func isUpperString(s string) bool { + for i := 0; i < len(s); i++ { + c := s[i] + if IsLower(c) { + return false + } + } + return true +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/unsafe.go b/vendor/github.com/go-pg/pg/v10/internal/unsafe.go new file mode 100644 index 000000000..f8bc18d91 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/unsafe.go @@ -0,0 +1,22 @@ +// +build !appengine + +package internal + +import ( + "unsafe" +) + +// BytesToString converts byte slice to string. +func BytesToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// StringToBytes converts string to byte slice. +func StringToBytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer( + &struct { + string + Cap int + }{s, len(s)}, + )) +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/util.go b/vendor/github.com/go-pg/pg/v10/internal/util.go new file mode 100644 index 000000000..80ad1dd9a --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/internal/util.go @@ -0,0 +1,71 @@ +package internal + +import ( + "context" + "reflect" + "time" +) + +func Sleep(ctx context.Context, dur time.Duration) error { + t := time.NewTimer(dur) + defer t.Stop() + + select { + case <-t.C: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func MakeSliceNextElemFunc(v reflect.Value) func() reflect.Value { + if v.Kind() == reflect.Array { + var pos int + return func() reflect.Value { + v := v.Index(pos) + pos++ + return v + } + } + + elemType := v.Type().Elem() + + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + return func() reflect.Value { + if v.Len() < v.Cap() { + v.Set(v.Slice(0, v.Len()+1)) + elem := v.Index(v.Len() - 1) + if elem.IsNil() { + elem.Set(reflect.New(elemType)) + } + return elem.Elem() + } + + elem := reflect.New(elemType) + v.Set(reflect.Append(v, elem)) + return elem.Elem() + } + } + + zero := reflect.Zero(elemType) + return func() reflect.Value { + if v.Len() < v.Cap() { + v.Set(v.Slice(0, v.Len()+1)) + return v.Index(v.Len() - 1) + } + + v.Set(reflect.Append(v, zero)) + return v.Index(v.Len() - 1) + } +} + +func Unwrap(err error) error { + u, ok := err.(interface { + Unwrap() error + }) + if !ok { + return nil + } + return u.Unwrap() +} diff --git a/vendor/github.com/go-pg/pg/v10/listener.go b/vendor/github.com/go-pg/pg/v10/listener.go new file mode 100644 index 000000000..d37be08d4 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/listener.go @@ -0,0 +1,414 @@ +package pg + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/internal/pool" + "github.com/go-pg/pg/v10/types" +) + +const gopgChannel = "gopg:ping" + +var ( + errListenerClosed = errors.New("pg: listener is closed") + errPingTimeout = errors.New("pg: ping timeout") +) + +// Notification which is received with LISTEN command. +type Notification struct { + Channel string + Payload string +} + +// Listener listens for notifications sent with NOTIFY command. +// It's NOT safe for concurrent use by multiple goroutines +// except the Channel API. +type Listener struct { + db *DB + + channels []string + + mu sync.Mutex + cn *pool.Conn + exit chan struct{} + closed bool + + chOnce sync.Once + ch chan Notification + pingCh chan struct{} +} + +func (ln *Listener) String() string { + ln.mu.Lock() + defer ln.mu.Unlock() + + return fmt.Sprintf("Listener(%s)", strings.Join(ln.channels, ", ")) +} + +func (ln *Listener) init() { + ln.exit = make(chan struct{}) +} + +func (ln *Listener) connWithLock(ctx context.Context) (*pool.Conn, error) { + ln.mu.Lock() + cn, err := ln.conn(ctx) + ln.mu.Unlock() + + switch err { + case nil: + return cn, nil + case errListenerClosed: + return nil, err + case pool.ErrClosed: + _ = ln.Close() + return nil, errListenerClosed + default: + internal.Logger.Printf(ctx, "pg: Listen failed: %s", err) + return nil, err + } +} + +func (ln *Listener) conn(ctx context.Context) (*pool.Conn, error) { + if ln.closed { + return nil, errListenerClosed + } + + if ln.cn != nil { + return ln.cn, nil + } + + cn, err := ln.db.pool.NewConn(ctx) + if err != nil { + return nil, err + } + + if err := ln.db.initConn(ctx, cn); err != nil { + _ = ln.db.pool.CloseConn(cn) + return nil, err + } + + cn.LockReader() + + if len(ln.channels) > 0 { + err := ln.listen(ctx, cn, ln.channels...) + if err != nil { + _ = ln.db.pool.CloseConn(cn) + return nil, err + } + } + + ln.cn = cn + return cn, nil +} + +func (ln *Listener) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) { + ln.mu.Lock() + if ln.cn == cn { + if isBadConn(err, allowTimeout) { + ln.reconnect(ctx, err) + } + } + ln.mu.Unlock() +} + +func (ln *Listener) reconnect(ctx context.Context, reason error) { + _ = ln.closeTheCn(reason) + _, _ = ln.conn(ctx) +} + +func (ln *Listener) closeTheCn(reason error) error { + if ln.cn == nil { + return nil + } + if !ln.closed { + internal.Logger.Printf(ln.db.ctx, "pg: discarding bad listener connection: %s", reason) + } + + err := ln.db.pool.CloseConn(ln.cn) + ln.cn = nil + return err +} + +// Close closes the listener, releasing any open resources. +func (ln *Listener) Close() error { + ln.mu.Lock() + defer ln.mu.Unlock() + + if ln.closed { + return errListenerClosed + } + ln.closed = true + close(ln.exit) + + return ln.closeTheCn(errListenerClosed) +} + +// Listen starts listening for notifications on channels. +func (ln *Listener) Listen(ctx context.Context, channels ...string) error { + // Always append channels so DB.Listen works correctly. + ln.mu.Lock() + ln.channels = appendIfNotExists(ln.channels, channels...) + ln.mu.Unlock() + + cn, err := ln.connWithLock(ctx) + if err != nil { + return err + } + + if err := ln.listen(ctx, cn, channels...); err != nil { + ln.releaseConn(ctx, cn, err, false) + return err + } + + return nil +} + +func (ln *Listener) listen(ctx context.Context, cn *pool.Conn, channels ...string) error { + err := cn.WithWriter(ctx, ln.db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + for _, channel := range channels { + if err := writeQueryMsg(wb, ln.db.fmter, "LISTEN ?", pgChan(channel)); err != nil { + return err + } + } + return nil + }) + return err +} + +// Unlisten stops listening for notifications on channels. +func (ln *Listener) Unlisten(ctx context.Context, channels ...string) error { + ln.mu.Lock() + ln.channels = removeIfExists(ln.channels, channels...) + ln.mu.Unlock() + + cn, err := ln.conn(ctx) + if err != nil { + return err + } + + if err := ln.unlisten(ctx, cn, channels...); err != nil { + ln.releaseConn(ctx, cn, err, false) + return err + } + + return nil +} + +func (ln *Listener) unlisten(ctx context.Context, cn *pool.Conn, channels ...string) error { + err := cn.WithWriter(ctx, ln.db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + for _, channel := range channels { + if err := writeQueryMsg(wb, ln.db.fmter, "UNLISTEN ?", pgChan(channel)); err != nil { + return err + } + } + return nil + }) + return err +} + +// Receive indefinitely waits for a notification. This is low-level API +// and in most cases Channel should be used instead. +func (ln *Listener) Receive(ctx context.Context) (channel string, payload string, err error) { + return ln.ReceiveTimeout(ctx, 0) +} + +// ReceiveTimeout waits for a notification until timeout is reached. +// This is low-level API and in most cases Channel should be used instead. +func (ln *Listener) ReceiveTimeout( + ctx context.Context, timeout time.Duration, +) (channel, payload string, err error) { + cn, err := ln.connWithLock(ctx) + if err != nil { + return "", "", err + } + + err = cn.WithReader(ctx, timeout, func(rd *pool.ReaderContext) error { + channel, payload, err = readNotification(rd) + return err + }) + if err != nil { + ln.releaseConn(ctx, cn, err, timeout > 0) + return "", "", err + } + + return channel, payload, nil +} + +// Channel returns a channel for concurrently receiving notifications. +// It periodically sends Ping notification to test connection health. +// +// The channel is closed with Listener. Receive* APIs can not be used +// after channel is created. +func (ln *Listener) Channel() <-chan Notification { + return ln.channel(100) +} + +// ChannelSize is like Channel, but creates a Go channel +// with specified buffer size. +func (ln *Listener) ChannelSize(size int) <-chan Notification { + return ln.channel(size) +} + +func (ln *Listener) channel(size int) <-chan Notification { + ln.chOnce.Do(func() { + ln.initChannel(size) + }) + if cap(ln.ch) != size { + err := fmt.Errorf("pg: Listener.Channel is called with different buffer size") + panic(err) + } + return ln.ch +} + +func (ln *Listener) initChannel(size int) { + const pingTimeout = time.Second + const chanSendTimeout = time.Minute + + ctx := ln.db.ctx + _ = ln.Listen(ctx, gopgChannel) + + ln.ch = make(chan Notification, size) + ln.pingCh = make(chan struct{}, 1) + + go func() { + timer := time.NewTimer(time.Minute) + timer.Stop() + + var errCount int + for { + channel, payload, err := ln.Receive(ctx) + if err != nil { + if err == errListenerClosed { + close(ln.ch) + return + } + + if errCount > 0 { + time.Sleep(500 * time.Millisecond) + } + errCount++ + + continue + } + + errCount = 0 + + // Any notification is as good as a ping. + select { + case ln.pingCh <- struct{}{}: + default: + } + + switch channel { + case gopgChannel: + // ignore + default: + timer.Reset(chanSendTimeout) + select { + case ln.ch <- Notification{channel, payload}: + if !timer.Stop() { + <-timer.C + } + case <-timer.C: + internal.Logger.Printf( + ctx, + "pg: %s channel is full for %s (notification is dropped)", + ln, + chanSendTimeout, + ) + } + } + } + }() + + go func() { + timer := time.NewTimer(time.Minute) + timer.Stop() + + healthy := true + for { + timer.Reset(pingTimeout) + select { + case <-ln.pingCh: + healthy = true + if !timer.Stop() { + <-timer.C + } + case <-timer.C: + pingErr := ln.ping() + if healthy { + healthy = false + } else { + if pingErr == nil { + pingErr = errPingTimeout + } + ln.mu.Lock() + ln.reconnect(ctx, pingErr) + ln.mu.Unlock() + } + case <-ln.exit: + return + } + } + }() +} + +func (ln *Listener) ping() error { + _, err := ln.db.Exec("NOTIFY ?", pgChan(gopgChannel)) + return err +} + +func appendIfNotExists(ss []string, es ...string) []string { +loop: + for _, e := range es { + for _, s := range ss { + if s == e { + continue loop + } + } + ss = append(ss, e) + } + return ss +} + +func removeIfExists(ss []string, es ...string) []string { + for _, e := range es { + for i, s := range ss { + if s == e { + last := len(ss) - 1 + ss[i] = ss[last] + ss = ss[:last] + break + } + } + } + return ss +} + +type pgChan string + +var _ types.ValueAppender = pgChan("") + +func (ch pgChan) AppendValue(b []byte, quote int) ([]byte, error) { + if quote == 0 { + return append(b, ch...), nil + } + + b = append(b, '"') + for _, c := range []byte(ch) { + if c == '"' { + b = append(b, '"', '"') + } else { + b = append(b, c) + } + } + b = append(b, '"') + + return b, nil +} diff --git a/vendor/github.com/go-pg/pg/v10/messages.go b/vendor/github.com/go-pg/pg/v10/messages.go new file mode 100644 index 000000000..7fb84ba0d --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/messages.go @@ -0,0 +1,1390 @@ +package pg + +import ( + "bufio" + "context" + "crypto/md5" //nolint + "crypto/tls" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "strings" + + "mellium.im/sasl" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/internal/pool" + "github.com/go-pg/pg/v10/orm" + "github.com/go-pg/pg/v10/types" +) + +// https://www.postgresql.org/docs/current/protocol-message-formats.html +const ( + commandCompleteMsg = 'C' + errorResponseMsg = 'E' + noticeResponseMsg = 'N' + parameterStatusMsg = 'S' + authenticationOKMsg = 'R' + backendKeyDataMsg = 'K' + noDataMsg = 'n' + passwordMessageMsg = 'p' + terminateMsg = 'X' + + saslInitialResponseMsg = 'p' + authenticationSASLContinueMsg = 'R' + saslResponseMsg = 'p' + authenticationSASLFinalMsg = 'R' + + authenticationOK = 0 + authenticationCleartextPassword = 3 + authenticationMD5Password = 5 + authenticationSASL = 10 + + notificationResponseMsg = 'A' + + describeMsg = 'D' + parameterDescriptionMsg = 't' + + queryMsg = 'Q' + readyForQueryMsg = 'Z' + emptyQueryResponseMsg = 'I' + rowDescriptionMsg = 'T' + dataRowMsg = 'D' + + parseMsg = 'P' + parseCompleteMsg = '1' + + bindMsg = 'B' + bindCompleteMsg = '2' + + executeMsg = 'E' + + syncMsg = 'S' + flushMsg = 'H' + + closeMsg = 'C' + closeCompleteMsg = '3' + + copyInResponseMsg = 'G' + copyOutResponseMsg = 'H' + copyDataMsg = 'd' + copyDoneMsg = 'c' +) + +var errEmptyQuery = internal.Errorf("pg: query is empty") + +func (db *baseDB) startup( + c context.Context, cn *pool.Conn, user, password, database, appName string, +) error { + err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + writeStartupMsg(wb, user, database, appName) + return nil + }) + if err != nil { + return err + } + + return cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { + for { + typ, msgLen, err := readMessageType(rd) + if err != nil { + return err + } + + switch typ { + case backendKeyDataMsg: + processID, err := readInt32(rd) + if err != nil { + return err + } + secretKey, err := readInt32(rd) + if err != nil { + return err + } + cn.ProcessID = processID + cn.SecretKey = secretKey + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return err + } + case authenticationOKMsg: + err := db.auth(c, cn, rd, user, password) + if err != nil { + return err + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + return err + case noticeResponseMsg: + // If we encounter a notice message from the server then we want to try to log it as it might be + // important for the client. If something goes wrong with this we want to fail. At the time of writing + // this the client will fail just encountering a notice during startup. So failing if a bad notice is + // sent is probably better than not failing, especially if we can try to log at least some data from the + // notice. + if err := db.logStartupNotice(rd); err != nil { + return err + } + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + return e + default: + return fmt.Errorf("pg: unknown startup message response: %q", typ) + } + } + }) +} + +func (db *baseDB) enableSSL(c context.Context, cn *pool.Conn, tlsConf *tls.Config) error { + err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + writeSSLMsg(wb) + return nil + }) + if err != nil { + return err + } + + err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { + c, err := rd.ReadByte() + if err != nil { + return err + } + if c != 'S' { + return errors.New("pg: SSL is not enabled on the server") + } + return nil + }) + if err != nil { + return err + } + + cn.SetNetConn(tls.Client(cn.NetConn(), tlsConf)) + return nil +} + +func (db *baseDB) auth( + c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string, +) error { + num, err := readInt32(rd) + if err != nil { + return err + } + + switch num { + case authenticationOK: + return nil + case authenticationCleartextPassword: + return db.authCleartext(c, cn, rd, password) + case authenticationMD5Password: + return db.authMD5(c, cn, rd, user, password) + case authenticationSASL: + return db.authSASL(c, cn, rd, user, password) + default: + return fmt.Errorf("pg: unknown authentication message response: %q", num) + } +} + +// logStartupNotice will handle notice messages during the startup process. It will parse them and log them for the +// client. Notices are not common and only happen if there is something the client should be aware of. So logging should +// not be a problem. +// Notice messages can be seen in startup: https://www.postgresql.org/docs/13/protocol-flow.html +// Information on the notice message format: https://www.postgresql.org/docs/13/protocol-message-formats.html +// Note: This is true for earlier versions of PostgreSQL as well, I've just included the latest versions of the docs. +func (db *baseDB) logStartupNotice( + rd *pool.ReaderContext, +) error { + message := make([]string, 0) + // Notice messages are null byte delimited key-value pairs. Where the keys are one byte. + for { + // Read the key byte. + fieldType, err := rd.ReadByte() + if err != nil { + return err + } + + // If they key byte (the type of field this data is) is 0 then that means we have reached the end of the notice. + // We can break our loop here and throw our message data into the logger. + if fieldType == 0 { + break + } + + // Read until the next null byte to get the data for this field. This does include the null byte at the end of + // fieldValue so we will trim it off down below. + fieldValue, err := readString(rd) + if err != nil { + return err + } + + // Just throw the field type as a string and its value into an array. + // Field types can be seen here: https://www.postgresql.org/docs/13/protocol-error-fields.html + // TODO This is a rare occurrence as is, would it be worth adding something to indicate what the field names + // are? Or is PostgreSQL documentation enough for a user at this point? + message = append(message, fmt.Sprintf("%s: %s", string(fieldType), fieldValue)) + } + + // Tell the client what PostgreSQL told us. Warning because its probably something the client should at the very + // least adjust. + internal.Warn.Printf("notice during startup: %s", strings.Join(message, ", ")) + + return nil +} + +func (db *baseDB) authCleartext( + c context.Context, cn *pool.Conn, rd *pool.ReaderContext, password string, +) error { + err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + writePasswordMsg(wb, password) + return nil + }) + if err != nil { + return err + } + return readAuthOK(rd) +} + +func (db *baseDB) authMD5( + c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string, +) error { + b, err := rd.ReadN(4) + if err != nil { + return err + } + + secret := "md5" + md5s(md5s(password+user)+string(b)) + err = cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + writePasswordMsg(wb, secret) + return nil + }) + if err != nil { + return err + } + + return readAuthOK(rd) +} + +func readAuthOK(rd *pool.ReaderContext) error { + c, _, err := readMessageType(rd) + if err != nil { + return err + } + + switch c { + case authenticationOKMsg: + c0, err := readInt32(rd) + if err != nil { + return err + } + if c0 != 0 { + return fmt.Errorf("pg: unexpected authentication code: %q", c0) + } + return nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + return e + default: + return fmt.Errorf("pg: unknown password message response: %q", c) + } +} + +func (db *baseDB) authSASL( + c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string, +) error { + s, err := readString(rd) + if err != nil { + return err + } + if s != "SCRAM-SHA-256" { + return fmt.Errorf("pg: SASL: got %q, wanted %q", s, "SCRAM-SHA-256") + } + + c0, err := rd.ReadByte() + if err != nil { + return err + } + if c0 != 0 { + return fmt.Errorf("pg: SASL: got %q, wanted %q", c0, 0) + } + + creds := sasl.Credentials(func() (Username, Password, Identity []byte) { + return []byte(user), []byte(password), nil + }) + client := sasl.NewClient(sasl.ScramSha256, creds) + + _, resp, err := client.Step(nil) + if err != nil { + return err + } + + err = cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + wb.StartMessage(saslInitialResponseMsg) + wb.WriteString("SCRAM-SHA-256") + wb.WriteInt32(int32(len(resp))) + _, err := wb.Write(resp) + if err != nil { + return err + } + wb.FinishMessage() + return nil + }) + if err != nil { + return err + } + + typ, n, err := readMessageType(rd) + if err != nil { + return err + } + + switch typ { + case authenticationSASLContinueMsg: + c11, err := readInt32(rd) + if err != nil { + return err + } + if c11 != 11 { + return fmt.Errorf("pg: SASL: got %q, wanted %q", typ, 11) + } + + b, err := rd.ReadN(n - 4) + if err != nil { + return err + } + + _, resp, err = client.Step(b) + if err != nil { + return err + } + + err = cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + wb.StartMessage(saslResponseMsg) + _, err := wb.Write(resp) + if err != nil { + return err + } + wb.FinishMessage() + return nil + }) + if err != nil { + return err + } + + return readAuthSASLFinal(rd, client) + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + return e + default: + return fmt.Errorf( + "pg: SASL: got %q, wanted %q", typ, authenticationSASLContinueMsg) + } +} + +func readAuthSASLFinal(rd *pool.ReaderContext, client *sasl.Negotiator) error { + c, n, err := readMessageType(rd) + if err != nil { + return err + } + + switch c { + case authenticationSASLFinalMsg: + c12, err := readInt32(rd) + if err != nil { + return err + } + if c12 != 12 { + return fmt.Errorf("pg: SASL: got %q, wanted %q", c, 12) + } + + b, err := rd.ReadN(n - 4) + if err != nil { + return err + } + + _, _, err = client.Step(b) + if err != nil { + return err + } + + if client.State() != sasl.ValidServerResponse { + return fmt.Errorf("pg: SASL: state=%q, wanted %q", + client.State(), sasl.ValidServerResponse) + } + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + return e + default: + return fmt.Errorf( + "pg: SASL: got %q, wanted %q", c, authenticationSASLFinalMsg) + } + + return readAuthOK(rd) +} + +func md5s(s string) string { + //nolint + h := md5.Sum([]byte(s)) + return hex.EncodeToString(h[:]) +} + +func writeStartupMsg(buf *pool.WriteBuffer, user, database, appName string) { + buf.StartMessage(0) + buf.WriteInt32(196608) + buf.WriteString("user") + buf.WriteString(user) + buf.WriteString("database") + buf.WriteString(database) + if appName != "" { + buf.WriteString("application_name") + buf.WriteString(appName) + } + buf.WriteString("") + buf.FinishMessage() +} + +func writeSSLMsg(buf *pool.WriteBuffer) { + buf.StartMessage(0) + buf.WriteInt32(80877103) + buf.FinishMessage() +} + +func writePasswordMsg(buf *pool.WriteBuffer, password string) { + buf.StartMessage(passwordMessageMsg) + buf.WriteString(password) + buf.FinishMessage() +} + +func writeFlushMsg(buf *pool.WriteBuffer) { + buf.StartMessage(flushMsg) + buf.FinishMessage() +} + +func writeCancelRequestMsg(buf *pool.WriteBuffer, processID, secretKey int32) { + buf.StartMessage(0) + buf.WriteInt32(80877102) + buf.WriteInt32(processID) + buf.WriteInt32(secretKey) + buf.FinishMessage() +} + +func writeQueryMsg( + buf *pool.WriteBuffer, + fmter orm.QueryFormatter, + query interface{}, + params ...interface{}, +) error { + buf.StartMessage(queryMsg) + bytes, err := appendQuery(fmter, buf.Bytes, query, params...) + if err != nil { + return err + } + buf.Bytes = bytes + err = buf.WriteByte(0x0) + if err != nil { + return err + } + buf.FinishMessage() + return nil +} + +func appendQuery(fmter orm.QueryFormatter, dst []byte, query interface{}, params ...interface{}) ([]byte, error) { + switch query := query.(type) { + case orm.QueryAppender: + if v, ok := fmter.(*orm.Formatter); ok { + fmter = v.WithModel(query) + } + return query.AppendQuery(fmter, dst) + case string: + if len(params) > 0 { + model, ok := params[len(params)-1].(orm.TableModel) + if ok { + if v, ok := fmter.(*orm.Formatter); ok { + fmter = v.WithTableModel(model) + params = params[:len(params)-1] + } + } + } + return fmter.FormatQuery(dst, query, params...), nil + default: + return nil, fmt.Errorf("pg: can't append %T", query) + } +} + +func writeSyncMsg(buf *pool.WriteBuffer) { + buf.StartMessage(syncMsg) + buf.FinishMessage() +} + +func writeParseDescribeSyncMsg(buf *pool.WriteBuffer, name, q string) { + buf.StartMessage(parseMsg) + buf.WriteString(name) + buf.WriteString(q) + buf.WriteInt16(0) + buf.FinishMessage() + + buf.StartMessage(describeMsg) + buf.WriteByte('S') //nolint + buf.WriteString(name) + buf.FinishMessage() + + writeSyncMsg(buf) +} + +func readParseDescribeSync(rd *pool.ReaderContext) ([]types.ColumnInfo, error) { + var columns []types.ColumnInfo + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + switch c { + case parseCompleteMsg: + _, err = rd.ReadN(msgLen) + if err != nil { + return nil, err + } + case rowDescriptionMsg: // Response to the DESCRIBE message. + columns, err = readRowDescription(rd, pool.NewColumnAlloc()) + if err != nil { + return nil, err + } + case parameterDescriptionMsg: // Response to the DESCRIBE message. + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + case noDataMsg: // Response to the DESCRIBE message. + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return columns, err + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + if firstErr == nil { + firstErr = e + } + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return nil, err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pg: readParseDescribeSync: unexpected message %q", c) + } + } +} + +// Writes BIND, EXECUTE and SYNC messages. +func writeBindExecuteMsg(buf *pool.WriteBuffer, name string, params ...interface{}) error { + buf.StartMessage(bindMsg) + buf.WriteString("") + buf.WriteString(name) + buf.WriteInt16(0) + buf.WriteInt16(int16(len(params))) + for _, param := range params { + buf.StartParam() + bytes := types.Append(buf.Bytes, param, 0) + if bytes != nil { + buf.Bytes = bytes + buf.FinishParam() + } else { + buf.FinishNullParam() + } + } + buf.WriteInt16(0) + buf.FinishMessage() + + buf.StartMessage(executeMsg) + buf.WriteString("") + buf.WriteInt32(0) + buf.FinishMessage() + + writeSyncMsg(buf) + + return nil +} + +func writeCloseMsg(buf *pool.WriteBuffer, name string) { + buf.StartMessage(closeMsg) + buf.WriteByte('S') //nolint + buf.WriteString(name) + buf.FinishMessage() +} + +func readCloseCompleteMsg(rd *pool.ReaderContext) error { + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return err + } + switch c { + case closeCompleteMsg: + _, err := rd.ReadN(msgLen) + return err + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + return e + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return err + } + default: + return fmt.Errorf("pg: readCloseCompleteMsg: unexpected message %q", c) + } + } +} + +func readSimpleQuery(rd *pool.ReaderContext) (*result, error) { + var res result + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case commandCompleteMsg: + b, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if err := res.parse(b); err != nil && firstErr == nil { + firstErr = err + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return &res, nil + case rowDescriptionMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + case dataRowMsg: + if _, err := rd.Discard(msgLen); err != nil { + return nil, err + } + res.returned++ + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + if firstErr == nil { + firstErr = e + } + case emptyQueryResponseMsg: + if firstErr == nil { + firstErr = errEmptyQuery + } + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return nil, err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pg: readSimpleQuery: unexpected message %q", c) + } + } +} + +func readExtQuery(rd *pool.ReaderContext) (*result, error) { + var res result + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case bindCompleteMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + case dataRowMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + res.returned++ + case commandCompleteMsg: // Response to the EXECUTE message. + b, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if err := res.parse(b); err != nil && firstErr == nil { + firstErr = err + } + case readyForQueryMsg: // Response to the SYNC message. + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return &res, nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + if firstErr == nil { + firstErr = e + } + case emptyQueryResponseMsg: + if firstErr == nil { + firstErr = errEmptyQuery + } + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return nil, err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pg: readExtQuery: unexpected message %q", c) + } + } +} + +func readRowDescription( + rd *pool.ReaderContext, columnAlloc *pool.ColumnAlloc, +) ([]types.ColumnInfo, error) { + numCol, err := readInt16(rd) + if err != nil { + return nil, err + } + + for i := 0; i < int(numCol); i++ { + b, err := rd.ReadSlice(0) + if err != nil { + return nil, err + } + + col := columnAlloc.New(int16(i), b[:len(b)-1]) + + if _, err := rd.ReadN(6); err != nil { + return nil, err + } + + dataType, err := readInt32(rd) + if err != nil { + return nil, err + } + col.DataType = dataType + + if _, err := rd.ReadN(8); err != nil { + return nil, err + } + } + + return columnAlloc.Columns(), nil +} + +func readDataRow( + ctx context.Context, + rd *pool.ReaderContext, + columns []types.ColumnInfo, + scanner orm.ColumnScanner, +) error { + numCol, err := readInt16(rd) + if err != nil { + return err + } + + if h, ok := scanner.(orm.BeforeScanHook); ok { + if err := h.BeforeScan(ctx); err != nil { + return err + } + } + + var firstErr error + + for colIdx := int16(0); colIdx < numCol; colIdx++ { + n, err := readInt32(rd) + if err != nil { + return err + } + + var colRd types.Reader + if int(n) <= rd.Buffered() { + colRd = rd.BytesReader(int(n)) + } else { + rd.SetAvailable(int(n)) + colRd = rd + } + + column := columns[colIdx] + if err := scanner.ScanColumn(column, colRd, int(n)); err != nil && firstErr == nil { + firstErr = internal.Errorf(err.Error()) + } + + if rd == colRd { + if rd.Available() > 0 { + if _, err := rd.Discard(rd.Available()); err != nil && firstErr == nil { + firstErr = err + } + } + rd.SetAvailable(-1) + } + } + + if h, ok := scanner.(orm.AfterScanHook); ok { + if err := h.AfterScan(ctx); err != nil { + return err + } + } + + return firstErr +} + +func newModel(mod interface{}) (orm.Model, error) { + m, err := orm.NewModel(mod) + if err != nil { + return nil, err + } + return m, m.Init() +} + +func readSimpleQueryData( + ctx context.Context, rd *pool.ReaderContext, mod interface{}, +) (*result, error) { + var columns []types.ColumnInfo + var res result + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case rowDescriptionMsg: + columns, err = readRowDescription(rd, rd.ColumnAlloc) + if err != nil { + return nil, err + } + + if res.model == nil { + var err error + res.model, err = newModel(mod) + if err != nil { + if firstErr == nil { + firstErr = err + } + res.model = Discard + } + } + case dataRowMsg: + scanner := res.model.NextColumnScanner() + if err := readDataRow(ctx, rd, columns, scanner); err != nil { + if firstErr == nil { + firstErr = err + } + } else if err := res.model.AddColumnScanner(scanner); err != nil { + if firstErr == nil { + firstErr = err + } + } + + res.returned++ + case commandCompleteMsg: + b, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if err := res.parse(b); err != nil && firstErr == nil { + firstErr = err + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return &res, nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + if firstErr == nil { + firstErr = e + } + case emptyQueryResponseMsg: + if firstErr == nil { + firstErr = errEmptyQuery + } + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return nil, err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pg: readSimpleQueryData: unexpected message %q", c) + } + } +} + +func readExtQueryData( + ctx context.Context, rd *pool.ReaderContext, mod interface{}, columns []types.ColumnInfo, +) (*result, error) { + var res result + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case bindCompleteMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + case dataRowMsg: + if res.model == nil { + var err error + res.model, err = newModel(mod) + if err != nil { + if firstErr == nil { + firstErr = err + } + res.model = Discard + } + } + + scanner := res.model.NextColumnScanner() + if err := readDataRow(ctx, rd, columns, scanner); err != nil { + if firstErr == nil { + firstErr = err + } + } else if err := res.model.AddColumnScanner(scanner); err != nil { + if firstErr == nil { + firstErr = err + } + } + + res.returned++ + case commandCompleteMsg: // Response to the EXECUTE message. + b, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if err := res.parse(b); err != nil && firstErr == nil { + firstErr = err + } + case readyForQueryMsg: // Response to the SYNC message. + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return &res, nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + if firstErr == nil { + firstErr = e + } + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return nil, err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pg: readExtQueryData: unexpected message %q", c) + } + } +} + +func readCopyInResponse(rd *pool.ReaderContext) error { + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return err + } + + switch c { + case copyInResponseMsg: + _, err := rd.ReadN(msgLen) + return err + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + if firstErr == nil { + firstErr = e + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return err + } + return firstErr + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return err + } + default: + return fmt.Errorf("pg: readCopyInResponse: unexpected message %q", c) + } + } +} + +func readCopyOutResponse(rd *pool.ReaderContext) error { + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return err + } + + switch c { + case copyOutResponseMsg: + _, err := rd.ReadN(msgLen) + return err + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return err + } + if firstErr == nil { + firstErr = e + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return err + } + return firstErr + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return err + } + default: + return fmt.Errorf("pg: readCopyOutResponse: unexpected message %q", c) + } + } +} + +func readCopyData(rd *pool.ReaderContext, w io.Writer) (*result, error) { + var res result + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case copyDataMsg: + for msgLen > 0 { + b, err := rd.ReadN(msgLen) + if err != nil && err != bufio.ErrBufferFull { + return nil, err + } + + _, err = w.Write(b) + if err != nil { + return nil, err + } + + msgLen -= len(b) + } + case copyDoneMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + case commandCompleteMsg: + b, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if err := res.parse(b); err != nil && firstErr == nil { + firstErr = err + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return &res, nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + return nil, e + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return nil, err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pg: readCopyData: unexpected message %q", c) + } + } +} + +func writeCopyData(buf *pool.WriteBuffer, r io.Reader) error { + buf.StartMessage(copyDataMsg) + _, err := buf.ReadFrom(r) + buf.FinishMessage() + return err +} + +func writeCopyDone(buf *pool.WriteBuffer) { + buf.StartMessage(copyDoneMsg) + buf.FinishMessage() +} + +func readReadyForQuery(rd *pool.ReaderContext) (*result, error) { + var res result + var firstErr error + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return nil, err + } + + switch c { + case commandCompleteMsg: + b, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if err := res.parse(b); err != nil && firstErr == nil { + firstErr = err + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return nil, err + } + if firstErr != nil { + return nil, firstErr + } + return &res, nil + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return nil, err + } + if firstErr == nil { + firstErr = e + } + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return nil, err + } + case parameterStatusMsg: + if err := logParameterStatus(rd, msgLen); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("pg: readReadyForQueryOrError: unexpected message %q", c) + } + } +} + +func readNotification(rd *pool.ReaderContext) (channel, payload string, err error) { + for { + c, msgLen, err := readMessageType(rd) + if err != nil { + return "", "", err + } + + switch c { + case commandCompleteMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return "", "", err + } + case readyForQueryMsg: + _, err := rd.ReadN(msgLen) + if err != nil { + return "", "", err + } + case errorResponseMsg: + e, err := readError(rd) + if err != nil { + return "", "", err + } + return "", "", e + case noticeResponseMsg: + if err := logNotice(rd, msgLen); err != nil { + return "", "", err + } + case notificationResponseMsg: + _, err := readInt32(rd) + if err != nil { + return "", "", err + } + channel, err = readString(rd) + if err != nil { + return "", "", err + } + payload, err = readString(rd) + if err != nil { + return "", "", err + } + return channel, payload, nil + default: + return "", "", fmt.Errorf("pg: readNotification: unexpected message %q", c) + } + } +} + +var terminateMessage = []byte{terminateMsg, 0, 0, 0, 4} + +func terminateConn(cn *pool.Conn) error { + // Don't use cn.Buf because it is racy with user code. + _, err := cn.NetConn().Write(terminateMessage) + return err +} + +//------------------------------------------------------------------------------ + +func logNotice(rd *pool.ReaderContext, msgLen int) error { + _, err := rd.ReadN(msgLen) + return err +} + +func logParameterStatus(rd *pool.ReaderContext, msgLen int) error { + _, err := rd.ReadN(msgLen) + return err +} + +func readInt16(rd *pool.ReaderContext) (int16, error) { + b, err := rd.ReadN(2) + if err != nil { + return 0, err + } + return int16(binary.BigEndian.Uint16(b)), nil +} + +func readInt32(rd *pool.ReaderContext) (int32, error) { + b, err := rd.ReadN(4) + if err != nil { + return 0, err + } + return int32(binary.BigEndian.Uint32(b)), nil +} + +func readString(rd *pool.ReaderContext) (string, error) { + b, err := rd.ReadSlice(0) + if err != nil { + return "", err + } + return string(b[:len(b)-1]), nil +} + +func readError(rd *pool.ReaderContext) (error, error) { + m := make(map[byte]string) + for { + c, err := rd.ReadByte() + if err != nil { + return nil, err + } + if c == 0 { + break + } + s, err := readString(rd) + if err != nil { + return nil, err + } + m[c] = s + } + return internal.NewPGError(m), nil +} + +func readMessageType(rd *pool.ReaderContext) (byte, int, error) { + c, err := rd.ReadByte() + if err != nil { + return 0, 0, err + } + l, err := readInt32(rd) + if err != nil { + return 0, 0, err + } + return c, int(l) - 4, nil +} diff --git a/vendor/github.com/go-pg/pg/v10/options.go b/vendor/github.com/go-pg/pg/v10/options.go new file mode 100644 index 000000000..efd634fd2 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/options.go @@ -0,0 +1,277 @@ +package pg + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "net/url" + "os" + "runtime" + "strconv" + "strings" + "time" + + "github.com/go-pg/pg/v10/internal/pool" +) + +// Options contains database connection options. +type Options struct { + // Network type, either tcp or unix. + // Default is tcp. + Network string + // TCP host:port or Unix socket depending on Network. + Addr string + + // Dialer creates new network connection and has priority over + // Network and Addr options. + Dialer func(ctx context.Context, network, addr string) (net.Conn, error) + + // Hook that is called after new connection is established + // and user is authenticated. + OnConnect func(ctx context.Context, cn *Conn) error + + User string + Password string + Database string + + // ApplicationName is the application name. Used in logs on Pg side. + // Only available from pg-9.0. + ApplicationName string + + // TLS config for secure connections. + TLSConfig *tls.Config + + // Dial timeout for establishing new connections. + // Default is 5 seconds. + DialTimeout time.Duration + + // Timeout for socket reads. If reached, commands will fail + // with a timeout instead of blocking. + ReadTimeout time.Duration + // Timeout for socket writes. If reached, commands will fail + // with a timeout instead of blocking. + WriteTimeout time.Duration + + // Maximum number of retries before giving up. + // Default is to not retry failed queries. + MaxRetries int + // Whether to retry queries cancelled because of statement_timeout. + RetryStatementTimeout bool + // Minimum backoff between each retry. + // Default is 250 milliseconds; -1 disables backoff. + MinRetryBackoff time.Duration + // Maximum backoff between each retry. + // Default is 4 seconds; -1 disables backoff. + MaxRetryBackoff time.Duration + + // Maximum number of socket connections. + // Default is 10 connections per every CPU as reported by runtime.NumCPU. + PoolSize int + // Minimum number of idle connections which is useful when establishing + // new connection is slow. + MinIdleConns int + // Connection age at which client retires (closes) the connection. + // It is useful with proxies like PgBouncer and HAProxy. + // Default is to not close aged connections. + MaxConnAge time.Duration + // Time for which client waits for free connection if all + // connections are busy before returning an error. + // Default is 30 seconds if ReadTimeOut is not defined, otherwise, + // ReadTimeout + 1 second. + PoolTimeout time.Duration + // Amount of time after which client closes idle connections. + // Should be less than server's timeout. + // Default is 5 minutes. -1 disables idle timeout check. + IdleTimeout time.Duration + // Frequency of idle checks made by idle connections reaper. + // Default is 1 minute. -1 disables idle connections reaper, + // but idle connections are still discarded by the client + // if IdleTimeout is set. + IdleCheckFrequency time.Duration +} + +func (opt *Options) init() { + if opt.Network == "" { + opt.Network = "tcp" + } + + if opt.Addr == "" { + switch opt.Network { + case "tcp": + host := env("PGHOST", "localhost") + port := env("PGPORT", "5432") + opt.Addr = fmt.Sprintf("%s:%s", host, port) + case "unix": + opt.Addr = "/var/run/postgresql/.s.PGSQL.5432" + } + } + + if opt.DialTimeout == 0 { + opt.DialTimeout = 5 * time.Second + } + if opt.Dialer == nil { + opt.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) { + netDialer := &net.Dialer{ + Timeout: opt.DialTimeout, + KeepAlive: 5 * time.Minute, + } + return netDialer.DialContext(ctx, network, addr) + } + } + + if opt.User == "" { + opt.User = env("PGUSER", "postgres") + } + + if opt.Database == "" { + opt.Database = env("PGDATABASE", "postgres") + } + + if opt.PoolSize == 0 { + opt.PoolSize = 10 * runtime.NumCPU() + } + + if opt.PoolTimeout == 0 { + if opt.ReadTimeout != 0 { + opt.PoolTimeout = opt.ReadTimeout + time.Second + } else { + opt.PoolTimeout = 30 * time.Second + } + } + + if opt.IdleTimeout == 0 { + opt.IdleTimeout = 5 * time.Minute + } + if opt.IdleCheckFrequency == 0 { + opt.IdleCheckFrequency = time.Minute + } + + switch opt.MinRetryBackoff { + case -1: + opt.MinRetryBackoff = 0 + case 0: + opt.MinRetryBackoff = 250 * time.Millisecond + } + switch opt.MaxRetryBackoff { + case -1: + opt.MaxRetryBackoff = 0 + case 0: + opt.MaxRetryBackoff = 4 * time.Second + } +} + +func env(key, defValue string) string { + envValue := os.Getenv(key) + if envValue != "" { + return envValue + } + return defValue +} + +// ParseURL parses an URL into options that can be used to connect to PostgreSQL. +func ParseURL(sURL string) (*Options, error) { + parsedURL, err := url.Parse(sURL) + if err != nil { + return nil, err + } + + // scheme + if parsedURL.Scheme != "postgres" && parsedURL.Scheme != "postgresql" { + return nil, errors.New("pg: invalid scheme: " + parsedURL.Scheme) + } + + // host and port + options := &Options{ + Addr: parsedURL.Host, + } + if !strings.Contains(options.Addr, ":") { + options.Addr += ":5432" + } + + // username and password + if parsedURL.User != nil { + options.User = parsedURL.User.Username() + + if password, ok := parsedURL.User.Password(); ok { + options.Password = password + } + } + + if options.User == "" { + options.User = "postgres" + } + + // database + if len(strings.Trim(parsedURL.Path, "/")) > 0 { + options.Database = parsedURL.Path[1:] + } else { + return nil, errors.New("pg: database name not provided") + } + + // ssl mode + query, err := url.ParseQuery(parsedURL.RawQuery) + if err != nil { + return nil, err + } + + if sslMode, ok := query["sslmode"]; ok && len(sslMode) > 0 { + switch sslMode[0] { + case "verify-ca", "verify-full": + options.TLSConfig = &tls.Config{} + case "allow", "prefer", "require": + options.TLSConfig = &tls.Config{InsecureSkipVerify: true} //nolint + case "disable": + options.TLSConfig = nil + default: + return nil, fmt.Errorf("pg: sslmode '%v' is not supported", sslMode[0]) + } + } else { + options.TLSConfig = &tls.Config{InsecureSkipVerify: true} //nolint + } + + delete(query, "sslmode") + + if appName, ok := query["application_name"]; ok && len(appName) > 0 { + options.ApplicationName = appName[0] + } + + delete(query, "application_name") + + if connTimeout, ok := query["connect_timeout"]; ok && len(connTimeout) > 0 { + ct, err := strconv.Atoi(connTimeout[0]) + if err != nil { + return nil, fmt.Errorf("pg: cannot parse connect_timeout option as int") + } + options.DialTimeout = time.Second * time.Duration(ct) + } + + delete(query, "connect_timeout") + + if len(query) > 0 { + return nil, errors.New("pg: options other than 'sslmode', 'application_name' and 'connect_timeout' are not supported") + } + + return options, nil +} + +func (opt *Options) getDialer() func(context.Context) (net.Conn, error) { + return func(ctx context.Context) (net.Conn, error) { + return opt.Dialer(ctx, opt.Network, opt.Addr) + } +} + +func newConnPool(opt *Options) *pool.ConnPool { + return pool.NewConnPool(&pool.Options{ + Dialer: opt.getDialer(), + OnClose: terminateConn, + + PoolSize: opt.PoolSize, + MinIdleConns: opt.MinIdleConns, + MaxConnAge: opt.MaxConnAge, + PoolTimeout: opt.PoolTimeout, + IdleTimeout: opt.IdleTimeout, + IdleCheckFrequency: opt.IdleCheckFrequency, + }) +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/composite.go b/vendor/github.com/go-pg/pg/v10/orm/composite.go new file mode 100644 index 000000000..d2e48a8b3 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/composite.go @@ -0,0 +1,100 @@ +package orm + +import ( + "fmt" + "reflect" + + "github.com/go-pg/pg/v10/internal/pool" + "github.com/go-pg/pg/v10/types" +) + +func compositeScanner(typ reflect.Type) types.ScannerFunc { + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + var table *Table + return func(v reflect.Value, rd types.Reader, n int) error { + if n == -1 { + v.Set(reflect.Zero(v.Type())) + return nil + } + + if table == nil { + table = GetTable(typ) + } + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + + p := newCompositeParser(rd) + var elemReader *pool.BytesReader + + var firstErr error + for i := 0; ; i++ { + elem, err := p.NextElem() + if err != nil { + if err == errEndOfComposite { + break + } + return err + } + + if i >= len(table.Fields) { + if firstErr == nil { + firstErr = fmt.Errorf( + "pg: %s has %d fields, but composite requires at least %d values", + table, len(table.Fields), i) + } + continue + } + + if elemReader == nil { + elemReader = pool.NewBytesReader(elem) + } else { + elemReader.Reset(elem) + } + + field := table.Fields[i] + if elem == nil { + err = field.ScanValue(v, elemReader, -1) + } else { + err = field.ScanValue(v, elemReader, len(elem)) + } + if err != nil && firstErr == nil { + firstErr = err + } + } + + return firstErr + } +} + +func compositeAppender(typ reflect.Type) types.AppenderFunc { + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + var table *Table + return func(b []byte, v reflect.Value, quote int) []byte { + if table == nil { + table = GetTable(typ) + } + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + b = append(b, "ROW("...) + for i, f := range table.Fields { + if i > 0 { + b = append(b, ',') + } + b = f.AppendValue(b, v, quote) + } + b = append(b, ')') + return b + } +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/composite_create.go b/vendor/github.com/go-pg/pg/v10/orm/composite_create.go new file mode 100644 index 000000000..fd60a94e4 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/composite_create.go @@ -0,0 +1,89 @@ +package orm + +import ( + "strconv" +) + +type CreateCompositeOptions struct { + Varchar int // replaces PostgreSQL data type `text` with `varchar(n)` +} + +type CreateCompositeQuery struct { + q *Query + opt *CreateCompositeOptions +} + +var ( + _ QueryAppender = (*CreateCompositeQuery)(nil) + _ QueryCommand = (*CreateCompositeQuery)(nil) +) + +func NewCreateCompositeQuery(q *Query, opt *CreateCompositeOptions) *CreateCompositeQuery { + return &CreateCompositeQuery{ + q: q, + opt: opt, + } +} + +func (q *CreateCompositeQuery) String() string { + b, err := q.AppendQuery(defaultFmter, nil) + if err != nil { + panic(err) + } + return string(b) +} + +func (q *CreateCompositeQuery) Operation() QueryOp { + return CreateCompositeOp +} + +func (q *CreateCompositeQuery) Clone() QueryCommand { + return &CreateCompositeQuery{ + q: q.q.Clone(), + opt: q.opt, + } +} + +func (q *CreateCompositeQuery) Query() *Query { + return q.q +} + +func (q *CreateCompositeQuery) AppendTemplate(b []byte) ([]byte, error) { + return q.AppendQuery(dummyFormatter{}, b) +} + +func (q *CreateCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + if q.q.stickyErr != nil { + return nil, q.q.stickyErr + } + if q.q.tableModel == nil { + return nil, errModelNil + } + + table := q.q.tableModel.Table() + + b = append(b, "CREATE TYPE "...) + b = append(b, table.Alias...) + b = append(b, " AS ("...) + + for i, field := range table.Fields { + if i > 0 { + b = append(b, ", "...) + } + + b = append(b, field.Column...) + b = append(b, " "...) + if field.UserSQLType == "" && q.opt != nil && q.opt.Varchar > 0 && + field.SQLType == "text" { + b = append(b, "varchar("...) + b = strconv.AppendInt(b, int64(q.opt.Varchar), 10) + b = append(b, ")"...) + } else { + b = append(b, field.SQLType...) + } + } + + b = append(b, ")"...) + + return b, q.q.stickyErr +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/composite_drop.go b/vendor/github.com/go-pg/pg/v10/orm/composite_drop.go new file mode 100644 index 000000000..2a169b07a --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/composite_drop.go @@ -0,0 +1,70 @@ +package orm + +type DropCompositeOptions struct { + IfExists bool + Cascade bool +} + +type DropCompositeQuery struct { + q *Query + opt *DropCompositeOptions +} + +var ( + _ QueryAppender = (*DropCompositeQuery)(nil) + _ QueryCommand = (*DropCompositeQuery)(nil) +) + +func NewDropCompositeQuery(q *Query, opt *DropCompositeOptions) *DropCompositeQuery { + return &DropCompositeQuery{ + q: q, + opt: opt, + } +} + +func (q *DropCompositeQuery) String() string { + b, err := q.AppendQuery(defaultFmter, nil) + if err != nil { + panic(err) + } + return string(b) +} + +func (q *DropCompositeQuery) Operation() QueryOp { + return DropCompositeOp +} + +func (q *DropCompositeQuery) Clone() QueryCommand { + return &DropCompositeQuery{ + q: q.q.Clone(), + opt: q.opt, + } +} + +func (q *DropCompositeQuery) Query() *Query { + return q.q +} + +func (q *DropCompositeQuery) AppendTemplate(b []byte) ([]byte, error) { + return q.AppendQuery(dummyFormatter{}, b) +} + +func (q *DropCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + if q.q.stickyErr != nil { + return nil, q.q.stickyErr + } + if q.q.tableModel == nil { + return nil, errModelNil + } + + b = append(b, "DROP TYPE "...) + if q.opt != nil && q.opt.IfExists { + b = append(b, "IF EXISTS "...) + } + b = append(b, q.q.tableModel.Table().Alias...) + if q.opt != nil && q.opt.Cascade { + b = append(b, " CASCADE"...) + } + + return b, q.q.stickyErr +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/composite_parser.go b/vendor/github.com/go-pg/pg/v10/orm/composite_parser.go new file mode 100644 index 000000000..29e500444 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/composite_parser.go @@ -0,0 +1,140 @@ +package orm + +import ( + "bufio" + "errors" + "fmt" + "io" + + "github.com/go-pg/pg/v10/internal/parser" + "github.com/go-pg/pg/v10/types" +) + +var errEndOfComposite = errors.New("pg: end of composite") + +type compositeParser struct { + p parser.StreamingParser + + stickyErr error +} + +func newCompositeParserErr(err error) *compositeParser { + return &compositeParser{ + stickyErr: err, + } +} + +func newCompositeParser(rd types.Reader) *compositeParser { + p := parser.NewStreamingParser(rd) + err := p.SkipByte('(') + if err != nil { + return newCompositeParserErr(err) + } + return &compositeParser{ + p: p, + } +} + +func (p *compositeParser) NextElem() ([]byte, error) { + if p.stickyErr != nil { + return nil, p.stickyErr + } + + c, err := p.p.ReadByte() + if err != nil { + if err == io.EOF { + return nil, errEndOfComposite + } + return nil, err + } + + switch c { + case '"': + return p.readQuoted() + case ',': + return nil, nil + case ')': + return nil, errEndOfComposite + default: + _ = p.p.UnreadByte() + } + + var b []byte + for { + tmp, err := p.p.ReadSlice(',') + if err == nil { + if b == nil { + b = tmp + } else { + b = append(b, tmp...) + } + b = b[:len(b)-1] + break + } + b = append(b, tmp...) + if err == bufio.ErrBufferFull { + continue + } + if err == io.EOF { + if b[len(b)-1] == ')' { + b = b[:len(b)-1] + break + } + } + return nil, err + } + + if len(b) == 0 { // NULL + return nil, nil + } + return b, nil +} + +func (p *compositeParser) readQuoted() ([]byte, error) { + var b []byte + + c, err := p.p.ReadByte() + if err != nil { + return nil, err + } + + for { + next, err := p.p.ReadByte() + if err != nil { + return nil, err + } + + if c == '\\' || c == '\'' { + if next == c { + b = append(b, c) + c, err = p.p.ReadByte() + if err != nil { + return nil, err + } + } else { + b = append(b, c) + c = next + } + continue + } + + if c == '"' { + switch next { + case '"': + b = append(b, '"') + c, err = p.p.ReadByte() + if err != nil { + return nil, err + } + case ',', ')': + return b, nil + default: + return nil, fmt.Errorf("pg: got %q, wanted ',' or ')'", c) + } + continue + } + + b = append(b, c) + c = next + } +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/count_estimate.go b/vendor/github.com/go-pg/pg/v10/orm/count_estimate.go new file mode 100644 index 000000000..bfa664a72 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/count_estimate.go @@ -0,0 +1,90 @@ +package orm + +import ( + "fmt" + + "github.com/go-pg/pg/v10/internal" +) + +// Placeholder that is replaced with count(*). +const placeholder = `'_go_pg_placeholder'` + +// https://wiki.postgresql.org/wiki/Count_estimate +//nolint +var pgCountEstimateFunc = fmt.Sprintf(` +CREATE OR REPLACE FUNCTION _go_pg_count_estimate_v2(query text, threshold int) +RETURNS int AS $$ +DECLARE + rec record; + nrows int; +BEGIN + FOR rec IN EXECUTE 'EXPLAIN ' || query LOOP + nrows := substring(rec."QUERY PLAN" FROM ' rows=(\d+)'); + EXIT WHEN nrows IS NOT NULL; + END LOOP; + + -- Return the estimation if there are too many rows. + IF nrows > threshold THEN + RETURN nrows; + END IF; + + -- Otherwise execute real count query. + query := replace(query, 'SELECT '%s'', 'SELECT count(*)'); + EXECUTE query INTO nrows; + + IF nrows IS NULL THEN + nrows := 0; + END IF; + + RETURN nrows; +END; +$$ LANGUAGE plpgsql; +`, placeholder) + +// CountEstimate uses EXPLAIN to get estimated number of rows returned the query. +// If that number is bigger than the threshold it returns the estimation. +// Otherwise it executes another query using count aggregate function and +// returns the result. +// +// Based on https://wiki.postgresql.org/wiki/Count_estimate +func (q *Query) CountEstimate(threshold int) (int, error) { + if q.stickyErr != nil { + return 0, q.stickyErr + } + + query, err := q.countSelectQuery(placeholder).AppendQuery(q.db.Formatter(), nil) + if err != nil { + return 0, err + } + + for i := 0; i < 3; i++ { + var count int + _, err = q.db.QueryOneContext( + q.ctx, + Scan(&count), + "SELECT _go_pg_count_estimate_v2(?, ?)", + string(query), threshold, + ) + if err != nil { + if pgerr, ok := err.(internal.PGError); ok && pgerr.Field('C') == "42883" { + // undefined_function + err = q.createCountEstimateFunc() + if err != nil { + pgerr, ok := err.(internal.PGError) + if !ok || !pgerr.IntegrityViolation() { + return 0, err + } + } + continue + } + } + return count, err + } + + return 0, err +} + +func (q *Query) createCountEstimateFunc() error { + _, err := q.db.ExecContext(q.ctx, pgCountEstimateFunc) + return err +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/delete.go b/vendor/github.com/go-pg/pg/v10/orm/delete.go new file mode 100644 index 000000000..c54cd10f8 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/delete.go @@ -0,0 +1,158 @@ +package orm + +import ( + "reflect" + + "github.com/go-pg/pg/v10/types" +) + +type DeleteQuery struct { + q *Query + placeholder bool +} + +var ( + _ QueryAppender = (*DeleteQuery)(nil) + _ QueryCommand = (*DeleteQuery)(nil) +) + +func NewDeleteQuery(q *Query) *DeleteQuery { + return &DeleteQuery{ + q: q, + } +} + +func (q *DeleteQuery) String() string { + b, err := q.AppendQuery(defaultFmter, nil) + if err != nil { + panic(err) + } + return string(b) +} + +func (q *DeleteQuery) Operation() QueryOp { + return DeleteOp +} + +func (q *DeleteQuery) Clone() QueryCommand { + return &DeleteQuery{ + q: q.q.Clone(), + placeholder: q.placeholder, + } +} + +func (q *DeleteQuery) Query() *Query { + return q.q +} + +func (q *DeleteQuery) AppendTemplate(b []byte) ([]byte, error) { + cp := q.Clone().(*DeleteQuery) + cp.placeholder = true + return cp.AppendQuery(dummyFormatter{}, b) +} + +func (q *DeleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if q.q.stickyErr != nil { + return nil, q.q.stickyErr + } + + if len(q.q.with) > 0 { + b, err = q.q.appendWith(fmter, b) + if err != nil { + return nil, err + } + } + + b = append(b, "DELETE FROM "...) + b, err = q.q.appendFirstTableWithAlias(fmter, b) + if err != nil { + return nil, err + } + + if q.q.hasMultiTables() { + b = append(b, " USING "...) + b, err = q.q.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + } + + b = append(b, " WHERE "...) + value := q.q.tableModel.Value() + + if q.q.isSliceModelWithData() { + if len(q.q.where) > 0 { + b, err = q.q.appendWhere(fmter, b) + if err != nil { + return nil, err + } + } else { + table := q.q.tableModel.Table() + err = table.checkPKs() + if err != nil { + return nil, err + } + + b = appendColumnAndSliceValue(fmter, b, value, table.Alias, table.PKs) + } + } else { + b, err = q.q.mustAppendWhere(fmter, b) + if err != nil { + return nil, err + } + } + + if len(q.q.returning) > 0 { + b, err = q.q.appendReturning(fmter, b) + if err != nil { + return nil, err + } + } + + return b, q.q.stickyErr +} + +func appendColumnAndSliceValue( + fmter QueryFormatter, b []byte, slice reflect.Value, alias types.Safe, fields []*Field, +) []byte { + if len(fields) > 1 { + b = append(b, '(') + } + b = appendColumns(b, alias, fields) + if len(fields) > 1 { + b = append(b, ')') + } + + b = append(b, " IN ("...) + + isPlaceholder := isTemplateFormatter(fmter) + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, ", "...) + } + + el := indirect(slice.Index(i)) + + if len(fields) > 1 { + b = append(b, '(') + } + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + if isPlaceholder { + b = append(b, '?') + } else { + b = f.AppendValue(b, el, 1) + } + } + if len(fields) > 1 { + b = append(b, ')') + } + } + + b = append(b, ')') + + return b +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/field.go b/vendor/github.com/go-pg/pg/v10/orm/field.go new file mode 100644 index 000000000..fe9b4abea --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/field.go @@ -0,0 +1,146 @@ +package orm + +import ( + "fmt" + "reflect" + + "github.com/go-pg/pg/v10/types" + "github.com/go-pg/zerochecker" +) + +const ( + PrimaryKeyFlag = uint8(1) << iota + ForeignKeyFlag + NotNullFlag + UseZeroFlag + UniqueFlag + ArrayFlag +) + +type Field struct { + Field reflect.StructField + Type reflect.Type + Index []int + + GoName string // struct field name, e.g. Id + SQLName string // SQL name, .e.g. id + Column types.Safe // escaped SQL name, e.g. "id" + SQLType string + UserSQLType string + Default types.Safe + OnDelete string + OnUpdate string + + flags uint8 + + append types.AppenderFunc + scan types.ScannerFunc + + isZero zerochecker.Func +} + +func indexEqual(ind1, ind2 []int) bool { + if len(ind1) != len(ind2) { + return false + } + for i, ind := range ind1 { + if ind != ind2[i] { + return false + } + } + return true +} + +func (f *Field) Clone() *Field { + cp := *f + cp.Index = cp.Index[:len(f.Index):len(f.Index)] + return &cp +} + +func (f *Field) setFlag(flag uint8) { + f.flags |= flag +} + +func (f *Field) hasFlag(flag uint8) bool { + return f.flags&flag != 0 +} + +func (f *Field) Value(strct reflect.Value) reflect.Value { + return fieldByIndexAlloc(strct, f.Index) +} + +func (f *Field) HasZeroValue(strct reflect.Value) bool { + return f.hasZeroValue(strct, f.Index) +} + +func (f *Field) hasZeroValue(v reflect.Value, index []int) bool { + for _, idx := range index { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return true + } + v = v.Elem() + } + v = v.Field(idx) + } + return f.isZero(v) +} + +func (f *Field) NullZero() bool { + return !f.hasFlag(UseZeroFlag) +} + +func (f *Field) AppendValue(b []byte, strct reflect.Value, quote int) []byte { + fv, ok := fieldByIndex(strct, f.Index) + if !ok { + return types.AppendNull(b, quote) + } + + if f.NullZero() && f.isZero(fv) { + return types.AppendNull(b, quote) + } + if f.append == nil { + panic(fmt.Errorf("pg: AppendValue(unsupported %s)", fv.Type())) + } + return f.append(b, fv, quote) +} + +func (f *Field) ScanValue(strct reflect.Value, rd types.Reader, n int) error { + if f.scan == nil { + return fmt.Errorf("pg: ScanValue(unsupported %s)", f.Type) + } + + var fv reflect.Value + if n == -1 { + var ok bool + fv, ok = fieldByIndex(strct, f.Index) + if !ok { + return nil + } + } else { + fv = fieldByIndexAlloc(strct, f.Index) + } + + return f.scan(fv, rd, n) +} + +type Method struct { + Index int + + flags int8 + + appender func([]byte, reflect.Value, int) []byte +} + +func (m *Method) Has(flag int8) bool { + return m.flags&flag != 0 +} + +func (m *Method) Value(strct reflect.Value) reflect.Value { + return strct.Method(m.Index).Call(nil)[0] +} + +func (m *Method) AppendValue(dst []byte, strct reflect.Value, quote int) []byte { + mv := m.Value(strct) + return m.appender(dst, mv, quote) +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/format.go b/vendor/github.com/go-pg/pg/v10/orm/format.go new file mode 100644 index 000000000..9945f6e1d --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/format.go @@ -0,0 +1,333 @@ +package orm + +import ( + "bytes" + "fmt" + "sort" + "strconv" + "strings" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/internal/parser" + "github.com/go-pg/pg/v10/types" +) + +var defaultFmter = NewFormatter() + +type queryWithSepAppender interface { + QueryAppender + AppendSep([]byte) []byte +} + +//------------------------------------------------------------------------------ + +type SafeQueryAppender struct { + query string + params []interface{} +} + +var ( + _ QueryAppender = (*SafeQueryAppender)(nil) + _ types.ValueAppender = (*SafeQueryAppender)(nil) +) + +//nolint +func SafeQuery(query string, params ...interface{}) *SafeQueryAppender { + return &SafeQueryAppender{query, params} +} + +func (q *SafeQueryAppender) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + return fmter.FormatQuery(b, q.query, q.params...), nil +} + +func (q *SafeQueryAppender) AppendValue(b []byte, quote int) ([]byte, error) { + return q.AppendQuery(defaultFmter, b) +} + +func (q *SafeQueryAppender) Value() types.Safe { + b, err := q.AppendValue(nil, 1) + if err != nil { + return types.Safe(err.Error()) + } + return types.Safe(internal.BytesToString(b)) +} + +//------------------------------------------------------------------------------ + +type condGroupAppender struct { + sep string + cond []queryWithSepAppender +} + +var ( + _ QueryAppender = (*condAppender)(nil) + _ queryWithSepAppender = (*condAppender)(nil) +) + +func (q *condGroupAppender) AppendSep(b []byte) []byte { + return append(b, q.sep...) +} + +func (q *condGroupAppender) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + b = append(b, '(') + for i, app := range q.cond { + if i > 0 { + b = app.AppendSep(b) + } + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + b = append(b, ')') + return b, nil +} + +//------------------------------------------------------------------------------ + +type condAppender struct { + sep string + cond string + params []interface{} +} + +var ( + _ QueryAppender = (*condAppender)(nil) + _ queryWithSepAppender = (*condAppender)(nil) +) + +func (q *condAppender) AppendSep(b []byte) []byte { + return append(b, q.sep...) +} + +func (q *condAppender) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + b = append(b, '(') + b = fmter.FormatQuery(b, q.cond, q.params...) + b = append(b, ')') + return b, nil +} + +//------------------------------------------------------------------------------ + +type fieldAppender struct { + field string +} + +var _ QueryAppender = (*fieldAppender)(nil) + +func (a fieldAppender) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + return types.AppendIdent(b, a.field, 1), nil +} + +//------------------------------------------------------------------------------ + +type dummyFormatter struct{} + +func (f dummyFormatter) FormatQuery(b []byte, query string, params ...interface{}) []byte { + return append(b, query...) +} + +func isTemplateFormatter(fmter QueryFormatter) bool { + _, ok := fmter.(dummyFormatter) + return ok +} + +//------------------------------------------------------------------------------ + +type QueryFormatter interface { + FormatQuery(b []byte, query string, params ...interface{}) []byte +} + +type Formatter struct { + namedParams map[string]interface{} + model TableModel +} + +var _ QueryFormatter = (*Formatter)(nil) + +func NewFormatter() *Formatter { + return new(Formatter) +} + +func (f *Formatter) String() string { + if len(f.namedParams) == 0 { + return "" + } + + keys := make([]string, len(f.namedParams)) + index := 0 + for k := range f.namedParams { + keys[index] = k + index++ + } + + sort.Strings(keys) + + ss := make([]string, len(keys)) + for i, k := range keys { + ss[i] = fmt.Sprintf("%s=%v", k, f.namedParams[k]) + } + return " " + strings.Join(ss, " ") +} + +func (f *Formatter) clone() *Formatter { + cp := NewFormatter() + + cp.model = f.model + if len(f.namedParams) > 0 { + cp.namedParams = make(map[string]interface{}, len(f.namedParams)) + } + for param, value := range f.namedParams { + cp.setParam(param, value) + } + + return cp +} + +func (f *Formatter) WithTableModel(model TableModel) *Formatter { + cp := f.clone() + cp.model = model + return cp +} + +func (f *Formatter) WithModel(model interface{}) *Formatter { + switch model := model.(type) { + case TableModel: + return f.WithTableModel(model) + case *Query: + return f.WithTableModel(model.tableModel) + case QueryCommand: + return f.WithTableModel(model.Query().tableModel) + default: + panic(fmt.Errorf("pg: unsupported model %T", model)) + } +} + +func (f *Formatter) setParam(param string, value interface{}) { + if f.namedParams == nil { + f.namedParams = make(map[string]interface{}) + } + f.namedParams[param] = value +} + +func (f *Formatter) WithParam(param string, value interface{}) *Formatter { + cp := f.clone() + cp.setParam(param, value) + return cp +} + +func (f *Formatter) Param(param string) interface{} { + return f.namedParams[param] +} + +func (f *Formatter) hasParams() bool { + return len(f.namedParams) > 0 || f.model != nil +} + +func (f *Formatter) FormatQueryBytes(dst, query []byte, params ...interface{}) []byte { + if (params == nil && !f.hasParams()) || bytes.IndexByte(query, '?') == -1 { + return append(dst, query...) + } + return f.append(dst, parser.New(query), params) +} + +func (f *Formatter) FormatQuery(dst []byte, query string, params ...interface{}) []byte { + if (params == nil && !f.hasParams()) || strings.IndexByte(query, '?') == -1 { + return append(dst, query...) + } + return f.append(dst, parser.NewString(query), params) +} + +func (f *Formatter) append(dst []byte, p *parser.Parser, params []interface{}) []byte { + var paramsIndex int + var namedParamsOnce bool + var tableParams *tableParams + + for p.Valid() { + b, ok := p.ReadSep('?') + if !ok { + dst = append(dst, b...) + continue + } + if len(b) > 0 && b[len(b)-1] == '\\' { + dst = append(dst, b[:len(b)-1]...) + dst = append(dst, '?') + continue + } + dst = append(dst, b...) + + id, numeric := p.ReadIdentifier() + if id != "" { + if numeric { + idx, err := strconv.Atoi(id) + if err != nil { + goto restore_param + } + + if idx >= len(params) { + goto restore_param + } + + dst = f.appendParam(dst, params[idx]) + continue + } + + if f.namedParams != nil { + param, paramOK := f.namedParams[id] + if paramOK { + dst = f.appendParam(dst, param) + continue + } + } + + if !namedParamsOnce && len(params) > 0 { + namedParamsOnce = true + tableParams, _ = newTableParams(params[len(params)-1]) + } + + if tableParams != nil { + dst, ok = tableParams.AppendParam(f, dst, id) + if ok { + continue + } + } + + if f.model != nil { + dst, ok = f.model.AppendParam(f, dst, id) + if ok { + continue + } + } + + restore_param: + dst = append(dst, '?') + dst = append(dst, id...) + continue + } + + if paramsIndex >= len(params) { + dst = append(dst, '?') + continue + } + + param := params[paramsIndex] + paramsIndex++ + + dst = f.appendParam(dst, param) + } + + return dst +} + +func (f *Formatter) appendParam(b []byte, param interface{}) []byte { + switch param := param.(type) { + case QueryAppender: + bb, err := param.AppendQuery(f, b) + if err != nil { + return types.AppendError(b, err) + } + return bb + default: + return types.Append(b, param, 1) + } +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/hook.go b/vendor/github.com/go-pg/pg/v10/orm/hook.go new file mode 100644 index 000000000..78bd10310 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/hook.go @@ -0,0 +1,248 @@ +package orm + +import ( + "context" + "reflect" +) + +type hookStubs struct{} + +var ( + _ AfterScanHook = (*hookStubs)(nil) + _ AfterSelectHook = (*hookStubs)(nil) + _ BeforeInsertHook = (*hookStubs)(nil) + _ AfterInsertHook = (*hookStubs)(nil) + _ BeforeUpdateHook = (*hookStubs)(nil) + _ AfterUpdateHook = (*hookStubs)(nil) + _ BeforeDeleteHook = (*hookStubs)(nil) + _ AfterDeleteHook = (*hookStubs)(nil) +) + +func (hookStubs) AfterScan(ctx context.Context) error { + return nil +} + +func (hookStubs) AfterSelect(ctx context.Context) error { + return nil +} + +func (hookStubs) BeforeInsert(ctx context.Context) (context.Context, error) { + return ctx, nil +} + +func (hookStubs) AfterInsert(ctx context.Context) error { + return nil +} + +func (hookStubs) BeforeUpdate(ctx context.Context) (context.Context, error) { + return ctx, nil +} + +func (hookStubs) AfterUpdate(ctx context.Context) error { + return nil +} + +func (hookStubs) BeforeDelete(ctx context.Context) (context.Context, error) { + return ctx, nil +} + +func (hookStubs) AfterDelete(ctx context.Context) error { + return nil +} + +func callHookSlice( + ctx context.Context, + slice reflect.Value, + ptr bool, + hook func(context.Context, reflect.Value) (context.Context, error), +) (context.Context, error) { + var firstErr error + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + v := slice.Index(i) + if !ptr { + v = v.Addr() + } + + var err error + ctx, err = hook(ctx, v) + if err != nil && firstErr == nil { + firstErr = err + } + } + return ctx, firstErr +} + +func callHookSlice2( + ctx context.Context, + slice reflect.Value, + ptr bool, + hook func(context.Context, reflect.Value) error, +) error { + var firstErr error + if slice.IsValid() { + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + v := slice.Index(i) + if !ptr { + v = v.Addr() + } + + err := hook(ctx, v) + if err != nil && firstErr == nil { + firstErr = err + } + } + } + return firstErr +} + +//------------------------------------------------------------------------------ + +type BeforeScanHook interface { + BeforeScan(context.Context) error +} + +var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem() + +func callBeforeScanHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(BeforeScanHook).BeforeScan(ctx) +} + +//------------------------------------------------------------------------------ + +type AfterScanHook interface { + AfterScan(context.Context) error +} + +var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem() + +func callAfterScanHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(AfterScanHook).AfterScan(ctx) +} + +//------------------------------------------------------------------------------ + +type AfterSelectHook interface { + AfterSelect(context.Context) error +} + +var afterSelectHookType = reflect.TypeOf((*AfterSelectHook)(nil)).Elem() + +func callAfterSelectHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(AfterSelectHook).AfterSelect(ctx) +} + +func callAfterSelectHookSlice( + ctx context.Context, slice reflect.Value, ptr bool, +) error { + return callHookSlice2(ctx, slice, ptr, callAfterSelectHook) +} + +//------------------------------------------------------------------------------ + +type BeforeInsertHook interface { + BeforeInsert(context.Context) (context.Context, error) +} + +var beforeInsertHookType = reflect.TypeOf((*BeforeInsertHook)(nil)).Elem() + +func callBeforeInsertHook(ctx context.Context, v reflect.Value) (context.Context, error) { + return v.Interface().(BeforeInsertHook).BeforeInsert(ctx) +} + +func callBeforeInsertHookSlice( + ctx context.Context, slice reflect.Value, ptr bool, +) (context.Context, error) { + return callHookSlice(ctx, slice, ptr, callBeforeInsertHook) +} + +//------------------------------------------------------------------------------ + +type AfterInsertHook interface { + AfterInsert(context.Context) error +} + +var afterInsertHookType = reflect.TypeOf((*AfterInsertHook)(nil)).Elem() + +func callAfterInsertHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(AfterInsertHook).AfterInsert(ctx) +} + +func callAfterInsertHookSlice( + ctx context.Context, slice reflect.Value, ptr bool, +) error { + return callHookSlice2(ctx, slice, ptr, callAfterInsertHook) +} + +//------------------------------------------------------------------------------ + +type BeforeUpdateHook interface { + BeforeUpdate(context.Context) (context.Context, error) +} + +var beforeUpdateHookType = reflect.TypeOf((*BeforeUpdateHook)(nil)).Elem() + +func callBeforeUpdateHook(ctx context.Context, v reflect.Value) (context.Context, error) { + return v.Interface().(BeforeUpdateHook).BeforeUpdate(ctx) +} + +func callBeforeUpdateHookSlice( + ctx context.Context, slice reflect.Value, ptr bool, +) (context.Context, error) { + return callHookSlice(ctx, slice, ptr, callBeforeUpdateHook) +} + +//------------------------------------------------------------------------------ + +type AfterUpdateHook interface { + AfterUpdate(context.Context) error +} + +var afterUpdateHookType = reflect.TypeOf((*AfterUpdateHook)(nil)).Elem() + +func callAfterUpdateHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(AfterUpdateHook).AfterUpdate(ctx) +} + +func callAfterUpdateHookSlice( + ctx context.Context, slice reflect.Value, ptr bool, +) error { + return callHookSlice2(ctx, slice, ptr, callAfterUpdateHook) +} + +//------------------------------------------------------------------------------ + +type BeforeDeleteHook interface { + BeforeDelete(context.Context) (context.Context, error) +} + +var beforeDeleteHookType = reflect.TypeOf((*BeforeDeleteHook)(nil)).Elem() + +func callBeforeDeleteHook(ctx context.Context, v reflect.Value) (context.Context, error) { + return v.Interface().(BeforeDeleteHook).BeforeDelete(ctx) +} + +func callBeforeDeleteHookSlice( + ctx context.Context, slice reflect.Value, ptr bool, +) (context.Context, error) { + return callHookSlice(ctx, slice, ptr, callBeforeDeleteHook) +} + +//------------------------------------------------------------------------------ + +type AfterDeleteHook interface { + AfterDelete(context.Context) error +} + +var afterDeleteHookType = reflect.TypeOf((*AfterDeleteHook)(nil)).Elem() + +func callAfterDeleteHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(AfterDeleteHook).AfterDelete(ctx) +} + +func callAfterDeleteHookSlice( + ctx context.Context, slice reflect.Value, ptr bool, +) error { + return callHookSlice2(ctx, slice, ptr, callAfterDeleteHook) +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/insert.go b/vendor/github.com/go-pg/pg/v10/orm/insert.go new file mode 100644 index 000000000..a7a543576 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/insert.go @@ -0,0 +1,345 @@ +package orm + +import ( + "fmt" + "reflect" + "sort" + + "github.com/go-pg/pg/v10/types" +) + +type InsertQuery struct { + q *Query + returningFields []*Field + placeholder bool +} + +var _ QueryCommand = (*InsertQuery)(nil) + +func NewInsertQuery(q *Query) *InsertQuery { + return &InsertQuery{ + q: q, + } +} + +func (q *InsertQuery) String() string { + b, err := q.AppendQuery(defaultFmter, nil) + if err != nil { + panic(err) + } + return string(b) +} + +func (q *InsertQuery) Operation() QueryOp { + return InsertOp +} + +func (q *InsertQuery) Clone() QueryCommand { + return &InsertQuery{ + q: q.q.Clone(), + placeholder: q.placeholder, + } +} + +func (q *InsertQuery) Query() *Query { + return q.q +} + +var _ TemplateAppender = (*InsertQuery)(nil) + +func (q *InsertQuery) AppendTemplate(b []byte) ([]byte, error) { + cp := q.Clone().(*InsertQuery) + cp.placeholder = true + return cp.AppendQuery(dummyFormatter{}, b) +} + +var _ QueryAppender = (*InsertQuery)(nil) + +func (q *InsertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if q.q.stickyErr != nil { + return nil, q.q.stickyErr + } + + if len(q.q.with) > 0 { + b, err = q.q.appendWith(fmter, b) + if err != nil { + return nil, err + } + } + + b = append(b, "INSERT INTO "...) + if q.q.onConflict != nil { + b, err = q.q.appendFirstTableWithAlias(fmter, b) + } else { + b, err = q.q.appendFirstTable(fmter, b) + } + if err != nil { + return nil, err + } + + b, err = q.appendColumnsValues(fmter, b) + if err != nil { + return nil, err + } + + if q.q.onConflict != nil { + b = append(b, " ON CONFLICT "...) + b, err = q.q.onConflict.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + if q.q.onConflictDoUpdate() { + if len(q.q.set) > 0 { + b, err = q.q.appendSet(fmter, b) + if err != nil { + return nil, err + } + } else { + fields, err := q.q.getDataFields() + if err != nil { + return nil, err + } + + if len(fields) == 0 { + fields = q.q.tableModel.Table().DataFields + } + + b = q.appendSetExcluded(b, fields) + } + + if len(q.q.updWhere) > 0 { + b = append(b, " WHERE "...) + b, err = q.q.appendUpdWhere(fmter, b) + if err != nil { + return nil, err + } + } + } + } + + if len(q.q.returning) > 0 { + b, err = q.q.appendReturning(fmter, b) + if err != nil { + return nil, err + } + } else if len(q.returningFields) > 0 { + b = appendReturningFields(b, q.returningFields) + } + + return b, q.q.stickyErr +} + +func (q *InsertQuery) appendColumnsValues(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if q.q.hasMultiTables() { + if q.q.columns != nil { + b = append(b, " ("...) + b, err = q.q.appendColumns(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ")"...) + } + + b = append(b, " SELECT * FROM "...) + b, err = q.q.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + + return b, nil + } + + if m, ok := q.q.model.(*mapModel); ok { + return q.appendMapColumnsValues(b, m.m), nil + } + + if !q.q.hasTableModel() { + return nil, errModelNil + } + + fields, err := q.q.getFields() + if err != nil { + return nil, err + } + + if len(fields) == 0 { + fields = q.q.tableModel.Table().Fields + } + value := q.q.tableModel.Value() + + b = append(b, " ("...) + b = q.appendColumns(b, fields) + b = append(b, ") VALUES ("...) + if m, ok := q.q.tableModel.(*sliceTableModel); ok { + if m.sliceLen == 0 { + err = fmt.Errorf("pg: can't bulk-insert empty slice %s", value.Type()) + return nil, err + } + b, err = q.appendSliceValues(fmter, b, fields, value) + if err != nil { + return nil, err + } + } else { + b, err = q.appendValues(fmter, b, fields, value) + if err != nil { + return nil, err + } + } + b = append(b, ")"...) + + return b, nil +} + +func (q *InsertQuery) appendMapColumnsValues(b []byte, m map[string]interface{}) []byte { + keys := make([]string, 0, len(m)) + + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + + b = append(b, " ("...) + + for i, k := range keys { + if i > 0 { + b = append(b, ", "...) + } + b = types.AppendIdent(b, k, 1) + } + + b = append(b, ") VALUES ("...) + + for i, k := range keys { + if i > 0 { + b = append(b, ", "...) + } + if q.placeholder { + b = append(b, '?') + } else { + b = types.Append(b, m[k], 1) + } + } + + b = append(b, ")"...) + + return b +} + +func (q *InsertQuery) appendValues( + fmter QueryFormatter, b []byte, fields []*Field, strct reflect.Value, +) (_ []byte, err error) { + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + app, ok := q.q.modelValues[f.SQLName] + if ok { + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + q.addReturningField(f) + continue + } + + switch { + case q.placeholder: + b = append(b, '?') + case (f.Default != "" || f.NullZero()) && f.HasZeroValue(strct): + b = append(b, "DEFAULT"...) + q.addReturningField(f) + default: + b = f.AppendValue(b, strct, 1) + } + } + + for i, v := range q.q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + + b, err = v.value.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *InsertQuery) appendSliceValues( + fmter QueryFormatter, b []byte, fields []*Field, slice reflect.Value, +) (_ []byte, err error) { + if q.placeholder { + return q.appendValues(fmter, b, fields, reflect.Value{}) + } + + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, "), ("...) + } + el := indirect(slice.Index(i)) + b, err = q.appendValues(fmter, b, fields, el) + if err != nil { + return nil, err + } + } + + for i, v := range q.q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + + b, err = v.value.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *InsertQuery) addReturningField(field *Field) { + if len(q.q.returning) > 0 { + return + } + for _, f := range q.returningFields { + if f == field { + return + } + } + q.returningFields = append(q.returningFields, field) +} + +func (q *InsertQuery) appendSetExcluded(b []byte, fields []*Field) []byte { + b = append(b, " SET "...) + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + b = append(b, f.Column...) + b = append(b, " = EXCLUDED."...) + b = append(b, f.Column...) + } + return b +} + +func (q *InsertQuery) appendColumns(b []byte, fields []*Field) []byte { + b = appendColumns(b, "", fields) + for i, v := range q.q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + b = types.AppendIdent(b, v.column, 1) + } + return b +} + +func appendReturningFields(b []byte, fields []*Field) []byte { + b = append(b, " RETURNING "...) + b = appendColumns(b, "", fields) + return b +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/join.go b/vendor/github.com/go-pg/pg/v10/orm/join.go new file mode 100644 index 000000000..2b64ba1b8 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/join.go @@ -0,0 +1,351 @@ +package orm + +import ( + "reflect" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/types" +) + +type join struct { + Parent *join + BaseModel TableModel + JoinModel TableModel + Rel *Relation + + ApplyQuery func(*Query) (*Query, error) + Columns []string + on []*condAppender +} + +func (j *join) AppendOn(app *condAppender) { + j.on = append(j.on, app) +} + +func (j *join) Select(fmter QueryFormatter, q *Query) error { + switch j.Rel.Type { + case HasManyRelation: + return j.selectMany(fmter, q) + case Many2ManyRelation: + return j.selectM2M(fmter, q) + } + panic("not reached") +} + +func (j *join) selectMany(_ QueryFormatter, q *Query) error { + q, err := j.manyQuery(q) + if err != nil { + return err + } + if q == nil { + return nil + } + return q.Select() +} + +func (j *join) manyQuery(q *Query) (*Query, error) { + manyModel := newManyModel(j) + if manyModel == nil { + return nil, nil + } + + q = q.Model(manyModel) + if j.ApplyQuery != nil { + var err error + q, err = j.ApplyQuery(q) + if err != nil { + return nil, err + } + } + + if len(q.columns) == 0 { + q.columns = append(q.columns, &hasManyColumnsAppender{j}) + } + + baseTable := j.BaseModel.Table() + var where []byte + if len(j.Rel.JoinFKs) > 1 { + where = append(where, '(') + } + where = appendColumns(where, j.JoinModel.Table().Alias, j.Rel.JoinFKs) + if len(j.Rel.JoinFKs) > 1 { + where = append(where, ')') + } + where = append(where, " IN ("...) + where = appendChildValues( + where, j.JoinModel.Root(), j.JoinModel.ParentIndex(), j.Rel.BaseFKs) + where = append(where, ")"...) + q = q.Where(internal.BytesToString(where)) + + if j.Rel.Polymorphic != nil { + q = q.Where(`? IN (?, ?)`, + j.Rel.Polymorphic.Column, + baseTable.ModelName, baseTable.TypeName) + } + + return q, nil +} + +func (j *join) selectM2M(fmter QueryFormatter, q *Query) error { + q, err := j.m2mQuery(fmter, q) + if err != nil { + return err + } + if q == nil { + return nil + } + return q.Select() +} + +func (j *join) m2mQuery(fmter QueryFormatter, q *Query) (*Query, error) { + m2mModel := newM2MModel(j) + if m2mModel == nil { + return nil, nil + } + + q = q.Model(m2mModel) + if j.ApplyQuery != nil { + var err error + q, err = j.ApplyQuery(q) + if err != nil { + return nil, err + } + } + + if len(q.columns) == 0 { + q.columns = append(q.columns, &hasManyColumnsAppender{j}) + } + + index := j.JoinModel.ParentIndex() + baseTable := j.BaseModel.Table() + + //nolint + var join []byte + join = append(join, "JOIN "...) + join = fmter.FormatQuery(join, string(j.Rel.M2MTableName)) + join = append(join, " AS "...) + join = append(join, j.Rel.M2MTableAlias...) + join = append(join, " ON ("...) + for i, col := range j.Rel.M2MBaseFKs { + if i > 0 { + join = append(join, ", "...) + } + join = append(join, j.Rel.M2MTableAlias...) + join = append(join, '.') + join = types.AppendIdent(join, col, 1) + } + join = append(join, ") IN ("...) + join = appendChildValues(join, j.BaseModel.Root(), index, baseTable.PKs) + join = append(join, ")"...) + q = q.Join(internal.BytesToString(join)) + + joinTable := j.JoinModel.Table() + for i, col := range j.Rel.M2MJoinFKs { + pk := joinTable.PKs[i] + q = q.Where("?.? = ?.?", + joinTable.Alias, pk.Column, + j.Rel.M2MTableAlias, types.Ident(col)) + } + + return q, nil +} + +func (j *join) hasParent() bool { + if j.Parent != nil { + switch j.Parent.Rel.Type { + case HasOneRelation, BelongsToRelation: + return true + } + } + return false +} + +func (j *join) appendAlias(b []byte) []byte { + b = append(b, '"') + b = appendAlias(b, j) + b = append(b, '"') + return b +} + +func (j *join) appendAliasColumn(b []byte, column string) []byte { + b = append(b, '"') + b = appendAlias(b, j) + b = append(b, "__"...) + b = append(b, column...) + b = append(b, '"') + return b +} + +func (j *join) appendBaseAlias(b []byte) []byte { + if j.hasParent() { + b = append(b, '"') + b = appendAlias(b, j.Parent) + b = append(b, '"') + return b + } + return append(b, j.BaseModel.Table().Alias...) +} + +func (j *join) appendSoftDelete(b []byte, flags queryFlag) []byte { + b = append(b, '.') + b = append(b, j.JoinModel.Table().SoftDeleteField.Column...) + if hasFlag(flags, deletedFlag) { + b = append(b, " IS NOT NULL"...) + } else { + b = append(b, " IS NULL"...) + } + return b +} + +func appendAlias(b []byte, j *join) []byte { + if j.hasParent() { + b = appendAlias(b, j.Parent) + b = append(b, "__"...) + } + b = append(b, j.Rel.Field.SQLName...) + return b +} + +func (j *join) appendHasOneColumns(b []byte) []byte { + if j.Columns == nil { + for i, f := range j.JoinModel.Table().Fields { + if i > 0 { + b = append(b, ", "...) + } + b = j.appendAlias(b) + b = append(b, '.') + b = append(b, f.Column...) + b = append(b, " AS "...) + b = j.appendAliasColumn(b, f.SQLName) + } + return b + } + + for i, column := range j.Columns { + if i > 0 { + b = append(b, ", "...) + } + b = j.appendAlias(b) + b = append(b, '.') + b = types.AppendIdent(b, column, 1) + b = append(b, " AS "...) + b = j.appendAliasColumn(b, column) + } + + return b +} + +func (j *join) appendHasOneJoin(fmter QueryFormatter, b []byte, q *Query) (_ []byte, err error) { + isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.hasFlag(allWithDeletedFlag) + + b = append(b, "LEFT JOIN "...) + b = fmter.FormatQuery(b, string(j.JoinModel.Table().SQLNameForSelects)) + b = append(b, " AS "...) + b = j.appendAlias(b) + + b = append(b, " ON "...) + + if isSoftDelete { + b = append(b, '(') + } + + if len(j.Rel.BaseFKs) > 1 { + b = append(b, '(') + } + for i, baseFK := range j.Rel.BaseFKs { + if i > 0 { + b = append(b, " AND "...) + } + b = j.appendAlias(b) + b = append(b, '.') + b = append(b, j.Rel.JoinFKs[i].Column...) + b = append(b, " = "...) + b = j.appendBaseAlias(b) + b = append(b, '.') + b = append(b, baseFK.Column...) + } + if len(j.Rel.BaseFKs) > 1 { + b = append(b, ')') + } + + for _, on := range j.on { + b = on.AppendSep(b) + b, err = on.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + if isSoftDelete { + b = append(b, ')') + } + + if isSoftDelete { + b = append(b, " AND "...) + b = j.appendAlias(b) + b = j.appendSoftDelete(b, q.flags) + } + + return b, nil +} + +type hasManyColumnsAppender struct { + *join +} + +var _ QueryAppender = (*hasManyColumnsAppender)(nil) + +func (q *hasManyColumnsAppender) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + if q.Rel.M2MTableAlias != "" { + b = append(b, q.Rel.M2MTableAlias...) + b = append(b, ".*, "...) + } + + joinTable := q.JoinModel.Table() + + if q.Columns != nil { + for i, column := range q.Columns { + if i > 0 { + b = append(b, ", "...) + } + b = append(b, joinTable.Alias...) + b = append(b, '.') + b = types.AppendIdent(b, column, 1) + } + return b, nil + } + + b = appendColumns(b, joinTable.Alias, joinTable.Fields) + return b, nil +} + +func appendChildValues(b []byte, v reflect.Value, index []int, fields []*Field) []byte { + seen := make(map[string]struct{}) + walk(v, index, func(v reflect.Value) { + start := len(b) + + if len(fields) > 1 { + b = append(b, '(') + } + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + b = f.AppendValue(b, v, 1) + } + if len(fields) > 1 { + b = append(b, ')') + } + b = append(b, ", "...) + + if _, ok := seen[string(b[start:])]; ok { + b = b[:start] + } else { + seen[string(b[start:])] = struct{}{} + } + }) + if len(seen) > 0 { + b = b[:len(b)-2] // trim ", " + } + return b +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model.go b/vendor/github.com/go-pg/pg/v10/orm/model.go new file mode 100644 index 000000000..333a90dd7 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model.go @@ -0,0 +1,150 @@ +package orm + +import ( + "database/sql" + "errors" + "fmt" + "reflect" + + "github.com/go-pg/pg/v10/types" +) + +var errModelNil = errors.New("pg: Model(nil)") + +type useQueryOne interface { + useQueryOne() bool +} + +type HooklessModel interface { + // Init is responsible to initialize/reset model state. + // It is called only once no matter how many rows were returned. + Init() error + + // NextColumnScanner returns a ColumnScanner that is used to scan columns + // from the current row. It is called once for every row. + NextColumnScanner() ColumnScanner + + // AddColumnScanner adds the ColumnScanner to the model. + AddColumnScanner(ColumnScanner) error +} + +type Model interface { + HooklessModel + + AfterScanHook + AfterSelectHook + + BeforeInsertHook + AfterInsertHook + + BeforeUpdateHook + AfterUpdateHook + + BeforeDeleteHook + AfterDeleteHook +} + +func NewModel(value interface{}) (Model, error) { + return newModel(value, false) +} + +func newScanModel(values []interface{}) (Model, error) { + if len(values) > 1 { + return Scan(values...), nil + } + return newModel(values[0], true) +} + +func newModel(value interface{}, scan bool) (Model, error) { + switch value := value.(type) { + case Model: + return value, nil + case HooklessModel: + return newModelWithHookStubs(value), nil + case types.ValueScanner, sql.Scanner: + if !scan { + return nil, fmt.Errorf("pg: Model(unsupported %T)", value) + } + return Scan(value), nil + } + + v := reflect.ValueOf(value) + if !v.IsValid() { + return nil, errModelNil + } + if v.Kind() != reflect.Ptr { + return nil, fmt.Errorf("pg: Model(non-pointer %T)", value) + } + + if v.IsNil() { + typ := v.Type().Elem() + if typ.Kind() == reflect.Struct { + return newStructTableModel(GetTable(typ)), nil + } + return nil, errModelNil + } + + v = v.Elem() + + if v.Kind() == reflect.Interface { + if !v.IsNil() { + v = v.Elem() + if v.Kind() != reflect.Ptr { + return nil, fmt.Errorf("pg: Model(non-pointer %s)", v.Type().String()) + } + } + } + + switch v.Kind() { + case reflect.Struct: + if v.Type() != timeType { + return newStructTableModelValue(v), nil + } + case reflect.Slice: + elemType := sliceElemType(v) + switch elemType.Kind() { + case reflect.Struct: + if elemType != timeType { + return newSliceTableModel(v, elemType), nil + } + case reflect.Map: + if err := validMap(elemType); err != nil { + return nil, err + } + slicePtr := v.Addr().Interface().(*[]map[string]interface{}) + return newMapSliceModel(slicePtr), nil + } + return newSliceModel(v, elemType), nil + case reflect.Map: + typ := v.Type() + if err := validMap(typ); err != nil { + return nil, err + } + mapPtr := v.Addr().Interface().(*map[string]interface{}) + return newMapModel(mapPtr), nil + } + + if !scan { + return nil, fmt.Errorf("pg: Model(unsupported %T)", value) + } + return Scan(value), nil +} + +type modelWithHookStubs struct { + hookStubs + HooklessModel +} + +func newModelWithHookStubs(m HooklessModel) Model { + return modelWithHookStubs{ + HooklessModel: m, + } +} + +func validMap(typ reflect.Type) error { + if typ.Key().Kind() != reflect.String || typ.Elem().Kind() != reflect.Interface { + return fmt.Errorf("pg: Model(unsupported %s, expected *map[string]interface{})", + typ.String()) + } + return nil +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_discard.go b/vendor/github.com/go-pg/pg/v10/orm/model_discard.go new file mode 100644 index 000000000..92e5c566c --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_discard.go @@ -0,0 +1,27 @@ +package orm + +import ( + "github.com/go-pg/pg/v10/types" +) + +type Discard struct { + hookStubs +} + +var _ Model = (*Discard)(nil) + +func (Discard) Init() error { + return nil +} + +func (m Discard) NextColumnScanner() ColumnScanner { + return m +} + +func (m Discard) AddColumnScanner(ColumnScanner) error { + return nil +} + +func (m Discard) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { + return nil +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_func.go b/vendor/github.com/go-pg/pg/v10/orm/model_func.go new file mode 100644 index 000000000..8427bdea2 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_func.go @@ -0,0 +1,89 @@ +package orm + +import ( + "fmt" + "reflect" +) + +var errorType = reflect.TypeOf((*error)(nil)).Elem() + +type funcModel struct { + Model + fnv reflect.Value + fnIn []reflect.Value +} + +var _ Model = (*funcModel)(nil) + +func newFuncModel(fn interface{}) *funcModel { + m := &funcModel{ + fnv: reflect.ValueOf(fn), + } + + fnt := m.fnv.Type() + if fnt.Kind() != reflect.Func { + panic(fmt.Errorf("ForEach expects a %s, got a %s", + reflect.Func, fnt.Kind())) + } + + if fnt.NumIn() < 1 { + panic(fmt.Errorf("ForEach expects at least 1 arg, got %d", fnt.NumIn())) + } + + if fnt.NumOut() != 1 { + panic(fmt.Errorf("ForEach must return 1 error value, got %d", fnt.NumOut())) + } + if fnt.Out(0) != errorType { + panic(fmt.Errorf("ForEach must return an error, got %T", fnt.Out(0))) + } + + if fnt.NumIn() > 1 { + initFuncModelScan(m, fnt) + return m + } + + t0 := fnt.In(0) + var v0 reflect.Value + if t0.Kind() == reflect.Ptr { + t0 = t0.Elem() + v0 = reflect.New(t0) + } else { + v0 = reflect.New(t0).Elem() + } + + m.fnIn = []reflect.Value{v0} + + model, ok := v0.Interface().(Model) + if ok { + m.Model = model + return m + } + + if v0.Kind() == reflect.Ptr { + v0 = v0.Elem() + } + if v0.Kind() != reflect.Struct { + panic(fmt.Errorf("ForEach accepts a %s, got %s", + reflect.Struct, v0.Kind())) + } + m.Model = newStructTableModelValue(v0) + + return m +} + +func initFuncModelScan(m *funcModel, fnt reflect.Type) { + m.fnIn = make([]reflect.Value, fnt.NumIn()) + for i := 0; i < fnt.NumIn(); i++ { + m.fnIn[i] = reflect.New(fnt.In(i)).Elem() + } + m.Model = scanReflectValues(m.fnIn) +} + +func (m *funcModel) AddColumnScanner(_ ColumnScanner) error { + out := m.fnv.Call(m.fnIn) + errv := out[0] + if !errv.IsNil() { + return errv.Interface().(error) + } + return nil +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_map.go b/vendor/github.com/go-pg/pg/v10/orm/model_map.go new file mode 100644 index 000000000..24533d43c --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_map.go @@ -0,0 +1,53 @@ +package orm + +import ( + "github.com/go-pg/pg/v10/types" +) + +type mapModel struct { + hookStubs + ptr *map[string]interface{} + m map[string]interface{} +} + +var _ Model = (*mapModel)(nil) + +func newMapModel(ptr *map[string]interface{}) *mapModel { + model := &mapModel{ + ptr: ptr, + } + if ptr != nil { + model.m = *ptr + } + return model +} + +func (m *mapModel) Init() error { + return nil +} + +func (m *mapModel) NextColumnScanner() ColumnScanner { + if m.m == nil { + m.m = make(map[string]interface{}) + *m.ptr = m.m + } + return m +} + +func (m mapModel) AddColumnScanner(ColumnScanner) error { + return nil +} + +func (m *mapModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { + val, err := types.ReadColumnValue(col, rd, n) + if err != nil { + return err + } + + m.m[col.Name] = val + return nil +} + +func (mapModel) useQueryOne() bool { + return true +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_map_slice.go b/vendor/github.com/go-pg/pg/v10/orm/model_map_slice.go new file mode 100644 index 000000000..ea14c9b6b --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_map_slice.go @@ -0,0 +1,45 @@ +package orm + +type mapSliceModel struct { + mapModel + slice *[]map[string]interface{} +} + +var _ Model = (*mapSliceModel)(nil) + +func newMapSliceModel(ptr *[]map[string]interface{}) *mapSliceModel { + return &mapSliceModel{ + slice: ptr, + } +} + +func (m *mapSliceModel) Init() error { + slice := *m.slice + if len(slice) > 0 { + *m.slice = slice[:0] + } + return nil +} + +func (m *mapSliceModel) NextColumnScanner() ColumnScanner { + slice := *m.slice + if len(slice) == cap(slice) { + m.mapModel.m = make(map[string]interface{}) + *m.slice = append(slice, m.mapModel.m) //nolint:gocritic + return m + } + + slice = slice[:len(slice)+1] + el := slice[len(slice)-1] + if el != nil { + m.mapModel.m = el + } else { + el = make(map[string]interface{}) + slice[len(slice)-1] = el + m.mapModel.m = el + } + *m.slice = slice + return m +} + +func (mapSliceModel) useQueryOne() {} //nolint:unused diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_scan.go b/vendor/github.com/go-pg/pg/v10/orm/model_scan.go new file mode 100644 index 000000000..08f66beba --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_scan.go @@ -0,0 +1,69 @@ +package orm + +import ( + "fmt" + "reflect" + + "github.com/go-pg/pg/v10/types" +) + +type scanValuesModel struct { + Discard + values []interface{} +} + +var _ Model = scanValuesModel{} + +//nolint +func Scan(values ...interface{}) scanValuesModel { + return scanValuesModel{ + values: values, + } +} + +func (scanValuesModel) useQueryOne() bool { + return true +} + +func (m scanValuesModel) NextColumnScanner() ColumnScanner { + return m +} + +func (m scanValuesModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { + if int(col.Index) >= len(m.values) { + return fmt.Errorf("pg: no Scan var for column index=%d name=%q", + col.Index, col.Name) + } + return types.Scan(m.values[col.Index], rd, n) +} + +//------------------------------------------------------------------------------ + +type scanReflectValuesModel struct { + Discard + values []reflect.Value +} + +var _ Model = scanReflectValuesModel{} + +func scanReflectValues(values []reflect.Value) scanReflectValuesModel { + return scanReflectValuesModel{ + values: values, + } +} + +func (scanReflectValuesModel) useQueryOne() bool { + return true +} + +func (m scanReflectValuesModel) NextColumnScanner() ColumnScanner { + return m +} + +func (m scanReflectValuesModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { + if int(col.Index) >= len(m.values) { + return fmt.Errorf("pg: no Scan var for column index=%d name=%q", + col.Index, col.Name) + } + return types.ScanValue(m.values[col.Index], rd, n) +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_slice.go b/vendor/github.com/go-pg/pg/v10/orm/model_slice.go new file mode 100644 index 000000000..1e163629e --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_slice.go @@ -0,0 +1,43 @@ +package orm + +import ( + "reflect" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/types" +) + +type sliceModel struct { + Discard + slice reflect.Value + nextElem func() reflect.Value + scan func(reflect.Value, types.Reader, int) error +} + +var _ Model = (*sliceModel)(nil) + +func newSliceModel(slice reflect.Value, elemType reflect.Type) *sliceModel { + return &sliceModel{ + slice: slice, + scan: types.Scanner(elemType), + } +} + +func (m *sliceModel) Init() error { + if m.slice.IsValid() && m.slice.Len() > 0 { + m.slice.Set(m.slice.Slice(0, 0)) + } + return nil +} + +func (m *sliceModel) NextColumnScanner() ColumnScanner { + return m +} + +func (m *sliceModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { + if m.nextElem == nil { + m.nextElem = internal.MakeSliceNextElemFunc(m.slice) + } + v := m.nextElem() + return m.scan(v, rd, n) +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_table.go b/vendor/github.com/go-pg/pg/v10/orm/model_table.go new file mode 100644 index 000000000..afdc15ccc --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_table.go @@ -0,0 +1,65 @@ +package orm + +import ( + "fmt" + "reflect" + + "github.com/go-pg/pg/v10/types" +) + +type TableModel interface { + Model + + IsNil() bool + Table() *Table + Relation() *Relation + AppendParam(QueryFormatter, []byte, string) ([]byte, bool) + + Join(string, func(*Query) (*Query, error)) *join + GetJoin(string) *join + GetJoins() []join + AddJoin(join) *join + + Root() reflect.Value + Index() []int + ParentIndex() []int + Mount(reflect.Value) + Kind() reflect.Kind + Value() reflect.Value + + setSoftDeleteField() error + scanColumn(types.ColumnInfo, types.Reader, int) (bool, error) +} + +func newTableModelIndex(typ reflect.Type, root reflect.Value, index []int, rel *Relation) (TableModel, error) { + typ = typeByIndex(typ, index) + + if typ.Kind() == reflect.Struct { + return &structTableModel{ + table: GetTable(typ), + rel: rel, + + root: root, + index: index, + }, nil + } + + if typ.Kind() == reflect.Slice { + structType := indirectType(typ.Elem()) + if structType.Kind() == reflect.Struct { + m := sliceTableModel{ + structTableModel: structTableModel{ + table: GetTable(structType), + rel: rel, + + root: root, + index: index, + }, + } + m.init(typ) + return &m, nil + } + } + + return nil, fmt.Errorf("pg: NewModel(%s)", typ) +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_table_m2m.go b/vendor/github.com/go-pg/pg/v10/orm/model_table_m2m.go new file mode 100644 index 000000000..83ac73bde --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_table_m2m.go @@ -0,0 +1,111 @@ +package orm + +import ( + "fmt" + "reflect" + + "github.com/go-pg/pg/v10/internal/pool" + "github.com/go-pg/pg/v10/types" +) + +type m2mModel struct { + *sliceTableModel + baseTable *Table + rel *Relation + + buf []byte + dstValues map[string][]reflect.Value + columns map[string]string +} + +var _ TableModel = (*m2mModel)(nil) + +func newM2MModel(j *join) *m2mModel { + baseTable := j.BaseModel.Table() + joinModel := j.JoinModel.(*sliceTableModel) + dstValues := dstValues(joinModel, baseTable.PKs) + if len(dstValues) == 0 { + return nil + } + m := &m2mModel{ + sliceTableModel: joinModel, + baseTable: baseTable, + rel: j.Rel, + + dstValues: dstValues, + columns: make(map[string]string), + } + if !m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } + return m +} + +func (m *m2mModel) NextColumnScanner() ColumnScanner { + if m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } else { + m.strct.Set(m.table.zeroStruct) + } + m.structInited = false + return m +} + +func (m *m2mModel) AddColumnScanner(_ ColumnScanner) error { + buf, err := m.modelIDMap(m.buf[:0]) + if err != nil { + return err + } + m.buf = buf + + dstValues, ok := m.dstValues[string(buf)] + if !ok { + return fmt.Errorf( + "pg: relation=%q does not have base %s with id=%q (check join conditions)", + m.rel.Field.GoName, m.baseTable, buf) + } + + for _, v := range dstValues { + if m.sliceOfPtr { + v.Set(reflect.Append(v, m.strct.Addr())) + } else { + v.Set(reflect.Append(v, m.strct)) + } + } + + return nil +} + +func (m *m2mModel) modelIDMap(b []byte) ([]byte, error) { + for i, col := range m.rel.M2MBaseFKs { + if i > 0 { + b = append(b, ',') + } + if s, ok := m.columns[col]; ok { + b = append(b, s...) + } else { + return nil, fmt.Errorf("pg: %s does not have column=%q", + m.sliceTableModel, col) + } + } + return b, nil +} + +func (m *m2mModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { + if n > 0 { + b, err := rd.ReadFullTemp() + if err != nil { + return err + } + + m.columns[col.Name] = string(b) + rd = pool.NewBytesReader(b) + } else { + m.columns[col.Name] = "" + } + + if ok, err := m.sliceTableModel.scanColumn(col, rd, n); ok { + return err + } + return nil +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_table_many.go b/vendor/github.com/go-pg/pg/v10/orm/model_table_many.go new file mode 100644 index 000000000..561384bba --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_table_many.go @@ -0,0 +1,75 @@ +package orm + +import ( + "fmt" + "reflect" +) + +type manyModel struct { + *sliceTableModel + baseTable *Table + rel *Relation + + buf []byte + dstValues map[string][]reflect.Value +} + +var _ TableModel = (*manyModel)(nil) + +func newManyModel(j *join) *manyModel { + baseTable := j.BaseModel.Table() + joinModel := j.JoinModel.(*sliceTableModel) + dstValues := dstValues(joinModel, j.Rel.BaseFKs) + if len(dstValues) == 0 { + return nil + } + m := manyModel{ + sliceTableModel: joinModel, + baseTable: baseTable, + rel: j.Rel, + + dstValues: dstValues, + } + if !m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } + return &m +} + +func (m *manyModel) NextColumnScanner() ColumnScanner { + if m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } else { + m.strct.Set(m.table.zeroStruct) + } + m.structInited = false + return m +} + +func (m *manyModel) AddColumnScanner(model ColumnScanner) error { + m.buf = modelID(m.buf[:0], m.strct, m.rel.JoinFKs) + dstValues, ok := m.dstValues[string(m.buf)] + if !ok { + return fmt.Errorf( + "pg: relation=%q does not have base %s with id=%q (check join conditions)", + m.rel.Field.GoName, m.baseTable, m.buf) + } + + for i, v := range dstValues { + if !m.sliceOfPtr { + v.Set(reflect.Append(v, m.strct)) + continue + } + + if i == 0 { + v.Set(reflect.Append(v, m.strct.Addr())) + continue + } + + clone := reflect.New(m.strct.Type()).Elem() + clone.Set(m.strct) + v.Set(reflect.Append(v, clone.Addr())) + } + + return nil +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_table_slice.go b/vendor/github.com/go-pg/pg/v10/orm/model_table_slice.go new file mode 100644 index 000000000..c50be8252 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_table_slice.go @@ -0,0 +1,156 @@ +package orm + +import ( + "context" + "reflect" + + "github.com/go-pg/pg/v10/internal" +) + +type sliceTableModel struct { + structTableModel + + slice reflect.Value + sliceLen int + sliceOfPtr bool + nextElem func() reflect.Value +} + +var _ TableModel = (*sliceTableModel)(nil) + +func newSliceTableModel(slice reflect.Value, elemType reflect.Type) *sliceTableModel { + m := &sliceTableModel{ + structTableModel: structTableModel{ + table: GetTable(elemType), + root: slice, + }, + slice: slice, + sliceLen: slice.Len(), + nextElem: internal.MakeSliceNextElemFunc(slice), + } + m.init(slice.Type()) + return m +} + +func (m *sliceTableModel) init(sliceType reflect.Type) { + switch sliceType.Elem().Kind() { + case reflect.Ptr, reflect.Interface: + m.sliceOfPtr = true + } +} + +//nolint +func (*sliceTableModel) useQueryOne() {} + +func (m *sliceTableModel) IsNil() bool { + return false +} + +func (m *sliceTableModel) AppendParam(fmter QueryFormatter, b []byte, name string) ([]byte, bool) { + if field, ok := m.table.FieldsMap[name]; ok { + b = append(b, "_data."...) + b = append(b, field.Column...) + return b, true + } + return m.structTableModel.AppendParam(fmter, b, name) +} + +func (m *sliceTableModel) Join(name string, apply func(*Query) (*Query, error)) *join { + return m.join(m.Value(), name, apply) +} + +func (m *sliceTableModel) Bind(bind reflect.Value) { + m.slice = bind.Field(m.index[len(m.index)-1]) +} + +func (m *sliceTableModel) Kind() reflect.Kind { + return reflect.Slice +} + +func (m *sliceTableModel) Value() reflect.Value { + return m.slice +} + +func (m *sliceTableModel) Init() error { + if m.slice.IsValid() && m.slice.Len() > 0 { + m.slice.Set(m.slice.Slice(0, 0)) + } + return nil +} + +func (m *sliceTableModel) NextColumnScanner() ColumnScanner { + m.strct = m.nextElem() + m.structInited = false + return m +} + +func (m *sliceTableModel) AddColumnScanner(_ ColumnScanner) error { + return nil +} + +// Inherit these hooks from structTableModel. +var ( + _ BeforeScanHook = (*sliceTableModel)(nil) + _ AfterScanHook = (*sliceTableModel)(nil) +) + +func (m *sliceTableModel) AfterSelect(ctx context.Context) error { + if m.table.hasFlag(afterSelectHookFlag) { + return callAfterSelectHookSlice(ctx, m.slice, m.sliceOfPtr) + } + return nil +} + +func (m *sliceTableModel) BeforeInsert(ctx context.Context) (context.Context, error) { + if m.table.hasFlag(beforeInsertHookFlag) { + return callBeforeInsertHookSlice(ctx, m.slice, m.sliceOfPtr) + } + return ctx, nil +} + +func (m *sliceTableModel) AfterInsert(ctx context.Context) error { + if m.table.hasFlag(afterInsertHookFlag) { + return callAfterInsertHookSlice(ctx, m.slice, m.sliceOfPtr) + } + return nil +} + +func (m *sliceTableModel) BeforeUpdate(ctx context.Context) (context.Context, error) { + if m.table.hasFlag(beforeUpdateHookFlag) && !m.IsNil() { + return callBeforeUpdateHookSlice(ctx, m.slice, m.sliceOfPtr) + } + return ctx, nil +} + +func (m *sliceTableModel) AfterUpdate(ctx context.Context) error { + if m.table.hasFlag(afterUpdateHookFlag) { + return callAfterUpdateHookSlice(ctx, m.slice, m.sliceOfPtr) + } + return nil +} + +func (m *sliceTableModel) BeforeDelete(ctx context.Context) (context.Context, error) { + if m.table.hasFlag(beforeDeleteHookFlag) && !m.IsNil() { + return callBeforeDeleteHookSlice(ctx, m.slice, m.sliceOfPtr) + } + return ctx, nil +} + +func (m *sliceTableModel) AfterDelete(ctx context.Context) error { + if m.table.hasFlag(afterDeleteHookFlag) && !m.IsNil() { + return callAfterDeleteHookSlice(ctx, m.slice, m.sliceOfPtr) + } + return nil +} + +func (m *sliceTableModel) setSoftDeleteField() error { + sliceLen := m.slice.Len() + for i := 0; i < sliceLen; i++ { + strct := indirect(m.slice.Index(i)) + fv := m.table.SoftDeleteField.Value(strct) + if err := m.table.SetSoftDeleteField(fv); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_table_struct.go b/vendor/github.com/go-pg/pg/v10/orm/model_table_struct.go new file mode 100644 index 000000000..fce7cc6b7 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_table_struct.go @@ -0,0 +1,399 @@ +package orm + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/go-pg/pg/v10/types" +) + +type structTableModel struct { + table *Table + rel *Relation + joins []join + + root reflect.Value + index []int + + strct reflect.Value + structInited bool + structInitErr error +} + +var _ TableModel = (*structTableModel)(nil) + +func newStructTableModel(table *Table) *structTableModel { + return &structTableModel{ + table: table, + } +} + +func newStructTableModelValue(v reflect.Value) *structTableModel { + return &structTableModel{ + table: GetTable(v.Type()), + root: v, + strct: v, + } +} + +func (*structTableModel) useQueryOne() bool { + return true +} + +func (m *structTableModel) String() string { + return m.table.String() +} + +func (m *structTableModel) IsNil() bool { + return !m.strct.IsValid() +} + +func (m *structTableModel) Table() *Table { + return m.table +} + +func (m *structTableModel) Relation() *Relation { + return m.rel +} + +func (m *structTableModel) AppendParam(fmter QueryFormatter, b []byte, name string) ([]byte, bool) { + b, ok := m.table.AppendParam(b, m.strct, name) + if ok { + return b, true + } + + switch name { + case "TableName": + b = fmter.FormatQuery(b, string(m.table.SQLName)) + return b, true + case "TableAlias": + b = append(b, m.table.Alias...) + return b, true + case "TableColumns": + b = appendColumns(b, m.table.Alias, m.table.Fields) + return b, true + case "Columns": + b = appendColumns(b, "", m.table.Fields) + return b, true + case "TablePKs": + b = appendColumns(b, m.table.Alias, m.table.PKs) + return b, true + case "PKs": + b = appendColumns(b, "", m.table.PKs) + return b, true + } + + return b, false +} + +func (m *structTableModel) Root() reflect.Value { + return m.root +} + +func (m *structTableModel) Index() []int { + return m.index +} + +func (m *structTableModel) ParentIndex() []int { + return m.index[:len(m.index)-len(m.rel.Field.Index)] +} + +func (m *structTableModel) Kind() reflect.Kind { + return reflect.Struct +} + +func (m *structTableModel) Value() reflect.Value { + return m.strct +} + +func (m *structTableModel) Mount(host reflect.Value) { + m.strct = host.FieldByIndex(m.rel.Field.Index) + m.structInited = false +} + +func (m *structTableModel) initStruct() error { + if m.structInited { + return m.structInitErr + } + m.structInited = true + + switch m.strct.Kind() { + case reflect.Invalid: + m.structInitErr = errModelNil + return m.structInitErr + case reflect.Interface: + m.strct = m.strct.Elem() + } + + if m.strct.Kind() == reflect.Ptr { + if m.strct.IsNil() { + m.strct.Set(reflect.New(m.strct.Type().Elem())) + m.strct = m.strct.Elem() + } else { + m.strct = m.strct.Elem() + } + } + + m.mountJoins() + + return nil +} + +func (m *structTableModel) mountJoins() { + for i := range m.joins { + j := &m.joins[i] + switch j.Rel.Type { + case HasOneRelation, BelongsToRelation: + j.JoinModel.Mount(m.strct) + } + } +} + +func (structTableModel) Init() error { + return nil +} + +func (m *structTableModel) NextColumnScanner() ColumnScanner { + return m +} + +func (m *structTableModel) AddColumnScanner(_ ColumnScanner) error { + return nil +} + +var _ BeforeScanHook = (*structTableModel)(nil) + +func (m *structTableModel) BeforeScan(ctx context.Context) error { + if !m.table.hasFlag(beforeScanHookFlag) { + return nil + } + return callBeforeScanHook(ctx, m.strct.Addr()) +} + +var _ AfterScanHook = (*structTableModel)(nil) + +func (m *structTableModel) AfterScan(ctx context.Context) error { + if !m.table.hasFlag(afterScanHookFlag) || !m.structInited { + return nil + } + + var firstErr error + + if err := callAfterScanHook(ctx, m.strct.Addr()); err != nil && firstErr == nil { + firstErr = err + } + + for _, j := range m.joins { + switch j.Rel.Type { + case HasOneRelation, BelongsToRelation: + if err := j.JoinModel.AfterScan(ctx); err != nil && firstErr == nil { + firstErr = err + } + } + } + + return firstErr +} + +func (m *structTableModel) AfterSelect(ctx context.Context) error { + if m.table.hasFlag(afterSelectHookFlag) { + return callAfterSelectHook(ctx, m.strct.Addr()) + } + return nil +} + +func (m *structTableModel) BeforeInsert(ctx context.Context) (context.Context, error) { + if m.table.hasFlag(beforeInsertHookFlag) { + return callBeforeInsertHook(ctx, m.strct.Addr()) + } + return ctx, nil +} + +func (m *structTableModel) AfterInsert(ctx context.Context) error { + if m.table.hasFlag(afterInsertHookFlag) { + return callAfterInsertHook(ctx, m.strct.Addr()) + } + return nil +} + +func (m *structTableModel) BeforeUpdate(ctx context.Context) (context.Context, error) { + if m.table.hasFlag(beforeUpdateHookFlag) && !m.IsNil() { + return callBeforeUpdateHook(ctx, m.strct.Addr()) + } + return ctx, nil +} + +func (m *structTableModel) AfterUpdate(ctx context.Context) error { + if m.table.hasFlag(afterUpdateHookFlag) && !m.IsNil() { + return callAfterUpdateHook(ctx, m.strct.Addr()) + } + return nil +} + +func (m *structTableModel) BeforeDelete(ctx context.Context) (context.Context, error) { + if m.table.hasFlag(beforeDeleteHookFlag) && !m.IsNil() { + return callBeforeDeleteHook(ctx, m.strct.Addr()) + } + return ctx, nil +} + +func (m *structTableModel) AfterDelete(ctx context.Context) error { + if m.table.hasFlag(afterDeleteHookFlag) && !m.IsNil() { + return callAfterDeleteHook(ctx, m.strct.Addr()) + } + return nil +} + +func (m *structTableModel) ScanColumn( + col types.ColumnInfo, rd types.Reader, n int, +) error { + ok, err := m.scanColumn(col, rd, n) + if ok { + return err + } + if m.table.hasFlag(discardUnknownColumnsFlag) || col.Name[0] == '_' { + return nil + } + return fmt.Errorf( + "pg: can't find column=%s in %s "+ + "(prefix the column with underscore or use discard_unknown_columns)", + col.Name, m.table, + ) +} + +func (m *structTableModel) scanColumn(col types.ColumnInfo, rd types.Reader, n int) (bool, error) { + // Don't init nil struct if value is NULL. + if n == -1 && + !m.structInited && + m.strct.Kind() == reflect.Ptr && + m.strct.IsNil() { + return true, nil + } + + if err := m.initStruct(); err != nil { + return true, err + } + + joinName, fieldName := splitColumn(col.Name) + if joinName != "" { + if join := m.GetJoin(joinName); join != nil { + joinCol := col + joinCol.Name = fieldName + return join.JoinModel.scanColumn(joinCol, rd, n) + } + if m.table.ModelName == joinName { + joinCol := col + joinCol.Name = fieldName + return m.scanColumn(joinCol, rd, n) + } + } + + field, ok := m.table.FieldsMap[col.Name] + if !ok { + return false, nil + } + + return true, field.ScanValue(m.strct, rd, n) +} + +func (m *structTableModel) GetJoin(name string) *join { + for i := range m.joins { + j := &m.joins[i] + if j.Rel.Field.GoName == name || j.Rel.Field.SQLName == name { + return j + } + } + return nil +} + +func (m *structTableModel) GetJoins() []join { + return m.joins +} + +func (m *structTableModel) AddJoin(j join) *join { + m.joins = append(m.joins, j) + return &m.joins[len(m.joins)-1] +} + +func (m *structTableModel) Join(name string, apply func(*Query) (*Query, error)) *join { + return m.join(m.Value(), name, apply) +} + +func (m *structTableModel) join( + bind reflect.Value, name string, apply func(*Query) (*Query, error), +) *join { + path := strings.Split(name, ".") + index := make([]int, 0, len(path)) + + currJoin := join{ + BaseModel: m, + JoinModel: m, + } + var lastJoin *join + var hasColumnName bool + + for _, name := range path { + rel, ok := currJoin.JoinModel.Table().Relations[name] + if !ok { + hasColumnName = true + break + } + + currJoin.Rel = rel + index = append(index, rel.Field.Index...) + + if j := currJoin.JoinModel.GetJoin(name); j != nil { + currJoin.BaseModel = j.BaseModel + currJoin.JoinModel = j.JoinModel + + lastJoin = j + } else { + model, err := newTableModelIndex(m.table.Type, bind, index, rel) + if err != nil { + return nil + } + + currJoin.Parent = lastJoin + currJoin.BaseModel = currJoin.JoinModel + currJoin.JoinModel = model + + lastJoin = currJoin.BaseModel.AddJoin(currJoin) + } + } + + // No joins with such name. + if lastJoin == nil { + return nil + } + if apply != nil { + lastJoin.ApplyQuery = apply + } + + if hasColumnName { + column := path[len(path)-1] + if column == "_" { + if lastJoin.Columns == nil { + lastJoin.Columns = make([]string, 0) + } + } else { + lastJoin.Columns = append(lastJoin.Columns, column) + } + } + + return lastJoin +} + +func (m *structTableModel) setSoftDeleteField() error { + fv := m.table.SoftDeleteField.Value(m.strct) + return m.table.SetSoftDeleteField(fv) +} + +func splitColumn(s string) (string, string) { + ind := strings.Index(s, "__") + if ind == -1 { + return "", s + } + return s[:ind], s[ind+2:] +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/msgpack.go b/vendor/github.com/go-pg/pg/v10/orm/msgpack.go new file mode 100644 index 000000000..56c88a23e --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/msgpack.go @@ -0,0 +1,52 @@ +package orm + +import ( + "reflect" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/go-pg/pg/v10/types" +) + +func msgpackAppender(_ reflect.Type) types.AppenderFunc { + return func(b []byte, v reflect.Value, flags int) []byte { + hexEnc := types.NewHexEncoder(b, flags) + + enc := msgpack.GetEncoder() + defer msgpack.PutEncoder(enc) + + enc.Reset(hexEnc) + if err := enc.EncodeValue(v); err != nil { + return types.AppendError(b, err) + } + + if err := hexEnc.Close(); err != nil { + return types.AppendError(b, err) + } + + return hexEnc.Bytes() + } +} + +func msgpackScanner(_ reflect.Type) types.ScannerFunc { + return func(v reflect.Value, rd types.Reader, n int) error { + if n <= 0 { + return nil + } + + hexDec, err := types.NewHexDecoder(rd, n) + if err != nil { + return err + } + + dec := msgpack.GetDecoder() + defer msgpack.PutDecoder(dec) + + dec.Reset(hexDec) + if err := dec.DecodeValue(v); err != nil { + return err + } + + return nil + } +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/orm.go b/vendor/github.com/go-pg/pg/v10/orm/orm.go new file mode 100644 index 000000000..d18993d2d --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/orm.go @@ -0,0 +1,58 @@ +/* +The API in this package is not stable and may change without any notice. +*/ +package orm + +import ( + "context" + "io" + + "github.com/go-pg/pg/v10/types" +) + +// ColumnScanner is used to scan column values. +type ColumnScanner interface { + // Scan assigns a column value from a row. + // + // An error should be returned if the value can not be stored + // without loss of information. + ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error +} + +type QueryAppender interface { + AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) +} + +type TemplateAppender interface { + AppendTemplate(b []byte) ([]byte, error) +} + +type QueryCommand interface { + QueryAppender + TemplateAppender + String() string + Operation() QueryOp + Clone() QueryCommand + Query() *Query +} + +// DB is a common interface for pg.DB and pg.Tx types. +type DB interface { + Model(model ...interface{}) *Query + ModelContext(c context.Context, model ...interface{}) *Query + + Exec(query interface{}, params ...interface{}) (Result, error) + ExecContext(c context.Context, query interface{}, params ...interface{}) (Result, error) + ExecOne(query interface{}, params ...interface{}) (Result, error) + ExecOneContext(c context.Context, query interface{}, params ...interface{}) (Result, error) + Query(model, query interface{}, params ...interface{}) (Result, error) + QueryContext(c context.Context, model, query interface{}, params ...interface{}) (Result, error) + QueryOne(model, query interface{}, params ...interface{}) (Result, error) + QueryOneContext(c context.Context, model, query interface{}, params ...interface{}) (Result, error) + + CopyFrom(r io.Reader, query interface{}, params ...interface{}) (Result, error) + CopyTo(w io.Writer, query interface{}, params ...interface{}) (Result, error) + + Context() context.Context + Formatter() QueryFormatter +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/query.go b/vendor/github.com/go-pg/pg/v10/orm/query.go new file mode 100644 index 000000000..8a9231f65 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/query.go @@ -0,0 +1,1680 @@ +package orm + +import ( + "context" + "errors" + "fmt" + "io" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/types" +) + +type QueryOp string + +const ( + SelectOp QueryOp = "SELECT" + InsertOp QueryOp = "INSERT" + UpdateOp QueryOp = "UPDATE" + DeleteOp QueryOp = "DELETE" + CreateTableOp QueryOp = "CREATE TABLE" + DropTableOp QueryOp = "DROP TABLE" + CreateCompositeOp QueryOp = "CREATE COMPOSITE" + DropCompositeOp QueryOp = "DROP COMPOSITE" +) + +type queryFlag uint8 + +const ( + implicitModelFlag queryFlag = 1 << iota + deletedFlag + allWithDeletedFlag +) + +type withQuery struct { + name string + query QueryAppender +} + +type columnValue struct { + column string + value *SafeQueryAppender +} + +type union struct { + expr string + query *Query +} + +type Query struct { + ctx context.Context + db DB + stickyErr error + + model Model + tableModel TableModel + flags queryFlag + + with []withQuery + tables []QueryAppender + distinctOn []*SafeQueryAppender + columns []QueryAppender + set []QueryAppender + modelValues map[string]*SafeQueryAppender + extraValues []*columnValue + where []queryWithSepAppender + updWhere []queryWithSepAppender + group []QueryAppender + having []*SafeQueryAppender + union []*union + joins []QueryAppender + joinAppendOn func(app *condAppender) + order []QueryAppender + limit int + offset int + selFor *SafeQueryAppender + + onConflict *SafeQueryAppender + returning []*SafeQueryAppender +} + +func NewQuery(db DB, model ...interface{}) *Query { + ctx := context.Background() + if db != nil { + ctx = db.Context() + } + q := &Query{ctx: ctx} + return q.DB(db).Model(model...) +} + +func NewQueryContext(ctx context.Context, db DB, model ...interface{}) *Query { + return NewQuery(db, model...).Context(ctx) +} + +// New returns new zero Query bound to the current db. +func (q *Query) New() *Query { + clone := &Query{ + ctx: q.ctx, + db: q.db, + + model: q.model, + tableModel: cloneTableModelJoins(q.tableModel), + flags: q.flags, + } + return clone.withFlag(implicitModelFlag) +} + +// Clone clones the Query. +func (q *Query) Clone() *Query { + var modelValues map[string]*SafeQueryAppender + if len(q.modelValues) > 0 { + modelValues = make(map[string]*SafeQueryAppender, len(q.modelValues)) + for k, v := range q.modelValues { + modelValues[k] = v + } + } + + clone := &Query{ + ctx: q.ctx, + db: q.db, + stickyErr: q.stickyErr, + + model: q.model, + tableModel: cloneTableModelJoins(q.tableModel), + flags: q.flags, + + with: q.with[:len(q.with):len(q.with)], + tables: q.tables[:len(q.tables):len(q.tables)], + distinctOn: q.distinctOn[:len(q.distinctOn):len(q.distinctOn)], + columns: q.columns[:len(q.columns):len(q.columns)], + set: q.set[:len(q.set):len(q.set)], + modelValues: modelValues, + extraValues: q.extraValues[:len(q.extraValues):len(q.extraValues)], + where: q.where[:len(q.where):len(q.where)], + updWhere: q.updWhere[:len(q.updWhere):len(q.updWhere)], + joins: q.joins[:len(q.joins):len(q.joins)], + group: q.group[:len(q.group):len(q.group)], + having: q.having[:len(q.having):len(q.having)], + union: q.union[:len(q.union):len(q.union)], + order: q.order[:len(q.order):len(q.order)], + limit: q.limit, + offset: q.offset, + selFor: q.selFor, + + onConflict: q.onConflict, + returning: q.returning[:len(q.returning):len(q.returning)], + } + + return clone +} + +func cloneTableModelJoins(tm TableModel) TableModel { + switch tm := tm.(type) { + case *structTableModel: + if len(tm.joins) == 0 { + return tm + } + clone := *tm + clone.joins = clone.joins[:len(clone.joins):len(clone.joins)] + return &clone + case *sliceTableModel: + if len(tm.joins) == 0 { + return tm + } + clone := *tm + clone.joins = clone.joins[:len(clone.joins):len(clone.joins)] + return &clone + } + return tm +} + +func (q *Query) err(err error) *Query { + if q.stickyErr == nil { + q.stickyErr = err + } + return q +} + +func (q *Query) hasFlag(flag queryFlag) bool { + return hasFlag(q.flags, flag) +} + +func hasFlag(flags, flag queryFlag) bool { + return flags&flag != 0 +} + +func (q *Query) withFlag(flag queryFlag) *Query { + q.flags |= flag + return q +} + +func (q *Query) withoutFlag(flag queryFlag) *Query { + q.flags &= ^flag + return q +} + +func (q *Query) Context(c context.Context) *Query { + q.ctx = c + return q +} + +func (q *Query) DB(db DB) *Query { + q.db = db + return q +} + +func (q *Query) Model(model ...interface{}) *Query { + var err error + switch l := len(model); { + case l == 0: + q.model = nil + case l == 1: + q.model, err = NewModel(model[0]) + case l > 1: + q.model, err = NewModel(&model) + default: + panic("not reached") + } + if err != nil { + q = q.err(err) + } + + q.tableModel, _ = q.model.(TableModel) + + return q.withoutFlag(implicitModelFlag) +} + +func (q *Query) TableModel() TableModel { + return q.tableModel +} + +func (q *Query) isSoftDelete() bool { + if q.tableModel != nil { + return q.tableModel.Table().SoftDeleteField != nil && !q.hasFlag(allWithDeletedFlag) + } + return false +} + +// Deleted adds `WHERE deleted_at IS NOT NULL` clause for soft deleted models. +func (q *Query) Deleted() *Query { + if q.tableModel != nil { + if err := q.tableModel.Table().mustSoftDelete(); err != nil { + return q.err(err) + } + } + return q.withFlag(deletedFlag).withoutFlag(allWithDeletedFlag) +} + +// AllWithDeleted changes query to return all rows including soft deleted ones. +func (q *Query) AllWithDeleted() *Query { + if q.tableModel != nil { + if err := q.tableModel.Table().mustSoftDelete(); err != nil { + return q.err(err) + } + } + return q.withFlag(allWithDeletedFlag).withoutFlag(deletedFlag) +} + +// With adds subq as common table expression with the given name. +func (q *Query) With(name string, subq *Query) *Query { + return q._with(name, NewSelectQuery(subq)) +} + +func (q *Query) WithInsert(name string, subq *Query) *Query { + return q._with(name, NewInsertQuery(subq)) +} + +func (q *Query) WithUpdate(name string, subq *Query) *Query { + return q._with(name, NewUpdateQuery(subq, false)) +} + +func (q *Query) WithDelete(name string, subq *Query) *Query { + return q._with(name, NewDeleteQuery(subq)) +} + +func (q *Query) _with(name string, subq QueryAppender) *Query { + q.with = append(q.with, withQuery{ + name: name, + query: subq, + }) + return q +} + +// WrapWith creates new Query and adds to it current query as +// common table expression with the given name. +func (q *Query) WrapWith(name string) *Query { + wrapper := q.New() + wrapper.with = q.with + q.with = nil + wrapper = wrapper.With(name, q) + return wrapper +} + +func (q *Query) Table(tables ...string) *Query { + for _, table := range tables { + q.tables = append(q.tables, fieldAppender{table}) + } + return q +} + +func (q *Query) TableExpr(expr string, params ...interface{}) *Query { + q.tables = append(q.tables, SafeQuery(expr, params...)) + return q +} + +func (q *Query) Distinct() *Query { + q.distinctOn = make([]*SafeQueryAppender, 0) + return q +} + +func (q *Query) DistinctOn(expr string, params ...interface{}) *Query { + q.distinctOn = append(q.distinctOn, SafeQuery(expr, params...)) + return q +} + +// Column adds a column to the Query quoting it according to PostgreSQL rules. +// Does not expand params like ?TableAlias etc. +// ColumnExpr can be used to bypass quoting restriction or for params expansion. +// Column name can be: +// - column_name, +// - table_alias.column_name, +// - table_alias.*. +func (q *Query) Column(columns ...string) *Query { + for _, column := range columns { + if column == "_" { + if q.columns == nil { + q.columns = make([]QueryAppender, 0) + } + continue + } + + q.columns = append(q.columns, fieldAppender{column}) + } + return q +} + +// ColumnExpr adds column expression to the Query. +func (q *Query) ColumnExpr(expr string, params ...interface{}) *Query { + q.columns = append(q.columns, SafeQuery(expr, params...)) + return q +} + +// ExcludeColumn excludes a column from the list of to be selected columns. +func (q *Query) ExcludeColumn(columns ...string) *Query { + if q.columns == nil { + for _, f := range q.tableModel.Table().Fields { + q.columns = append(q.columns, fieldAppender{f.SQLName}) + } + } + + for _, col := range columns { + if !q.excludeColumn(col) { + return q.err(fmt.Errorf("pg: can't find column=%q", col)) + } + } + return q +} + +func (q *Query) excludeColumn(column string) bool { + for i := 0; i < len(q.columns); i++ { + app, ok := q.columns[i].(fieldAppender) + if ok && app.field == column { + q.columns = append(q.columns[:i], q.columns[i+1:]...) + return true + } + } + return false +} + +func (q *Query) getFields() ([]*Field, error) { + return q._getFields(false) +} + +func (q *Query) getDataFields() ([]*Field, error) { + return q._getFields(true) +} + +func (q *Query) _getFields(omitPKs bool) ([]*Field, error) { + table := q.tableModel.Table() + columns := make([]*Field, 0, len(q.columns)) + for _, col := range q.columns { + f, ok := col.(fieldAppender) + if !ok { + continue + } + + field, err := table.GetField(f.field) + if err != nil { + return nil, err + } + + if omitPKs && field.hasFlag(PrimaryKeyFlag) { + continue + } + + columns = append(columns, field) + } + return columns, nil +} + +// Relation adds a relation to the query. Relation name can be: +// - RelationName to select all columns, +// - RelationName.column_name, +// - RelationName._ to join relation without selecting relation columns. +func (q *Query) Relation(name string, apply ...func(*Query) (*Query, error)) *Query { + var fn func(*Query) (*Query, error) + if len(apply) == 1 { + fn = apply[0] + } else if len(apply) > 1 { + panic("only one apply function is supported") + } + + join := q.tableModel.Join(name, fn) + if join == nil { + return q.err(fmt.Errorf("%s does not have relation=%q", + q.tableModel.Table(), name)) + } + + if fn == nil { + return q + } + + switch join.Rel.Type { + case HasOneRelation, BelongsToRelation: + q.joinAppendOn = join.AppendOn + return q.Apply(fn) + default: + q.joinAppendOn = nil + return q + } +} + +func (q *Query) Set(set string, params ...interface{}) *Query { + q.set = append(q.set, SafeQuery(set, params...)) + return q +} + +// Value overwrites model value for the column in INSERT and UPDATE queries. +func (q *Query) Value(column string, value string, params ...interface{}) *Query { + if !q.hasTableModel() { + q.err(errModelNil) + return q + } + + table := q.tableModel.Table() + if _, ok := table.FieldsMap[column]; ok { + if q.modelValues == nil { + q.modelValues = make(map[string]*SafeQueryAppender) + } + q.modelValues[column] = SafeQuery(value, params...) + } else { + q.extraValues = append(q.extraValues, &columnValue{ + column: column, + value: SafeQuery(value, params...), + }) + } + + return q +} + +func (q *Query) Where(condition string, params ...interface{}) *Query { + q.addWhere(&condAppender{ + sep: " AND ", + cond: condition, + params: params, + }) + return q +} + +func (q *Query) WhereOr(condition string, params ...interface{}) *Query { + q.addWhere(&condAppender{ + sep: " OR ", + cond: condition, + params: params, + }) + return q +} + +// WhereGroup encloses conditions added in the function in parentheses. +// +// q.Where("TRUE"). +// WhereGroup(func(q *pg.Query) (*pg.Query, error) { +// q = q.WhereOr("FALSE").WhereOr("TRUE"). +// return q, nil +// }) +// +// generates +// +// WHERE TRUE AND (FALSE OR TRUE) +func (q *Query) WhereGroup(fn func(*Query) (*Query, error)) *Query { + return q.whereGroup(" AND ", fn) +} + +// WhereGroup encloses conditions added in the function in parentheses. +// +// q.Where("TRUE"). +// WhereNotGroup(func(q *pg.Query) (*pg.Query, error) { +// q = q.WhereOr("FALSE").WhereOr("TRUE"). +// return q, nil +// }) +// +// generates +// +// WHERE TRUE AND NOT (FALSE OR TRUE) +func (q *Query) WhereNotGroup(fn func(*Query) (*Query, error)) *Query { + return q.whereGroup(" AND NOT ", fn) +} + +// WhereOrGroup encloses conditions added in the function in parentheses. +// +// q.Where("TRUE"). +// WhereOrGroup(func(q *pg.Query) (*pg.Query, error) { +// q = q.Where("FALSE").Where("TRUE"). +// return q, nil +// }) +// +// generates +// +// WHERE TRUE OR (FALSE AND TRUE) +func (q *Query) WhereOrGroup(fn func(*Query) (*Query, error)) *Query { + return q.whereGroup(" OR ", fn) +} + +// WhereOrGroup encloses conditions added in the function in parentheses. +// +// q.Where("TRUE"). +// WhereOrGroup(func(q *pg.Query) (*pg.Query, error) { +// q = q.Where("FALSE").Where("TRUE"). +// return q, nil +// }) +// +// generates +// +// WHERE TRUE OR NOT (FALSE AND TRUE) +func (q *Query) WhereOrNotGroup(fn func(*Query) (*Query, error)) *Query { + return q.whereGroup(" OR NOT ", fn) +} + +func (q *Query) whereGroup(conj string, fn func(*Query) (*Query, error)) *Query { + saved := q.where + q.where = nil + + newq, err := fn(q) + if err != nil { + q.err(err) + return q + } + + if len(newq.where) == 0 { + newq.where = saved + return newq + } + + f := &condGroupAppender{ + sep: conj, + cond: newq.where, + } + newq.where = saved + newq.addWhere(f) + + return newq +} + +// WhereIn is a shortcut for Where and pg.In. +func (q *Query) WhereIn(where string, slice interface{}) *Query { + return q.Where(where, types.In(slice)) +} + +// WhereInMulti is a shortcut for Where and pg.InMulti. +func (q *Query) WhereInMulti(where string, values ...interface{}) *Query { + return q.Where(where, types.InMulti(values...)) +} + +func (q *Query) addWhere(f queryWithSepAppender) { + if q.onConflictDoUpdate() { + q.updWhere = append(q.updWhere, f) + } else { + q.where = append(q.where, f) + } +} + +// WherePK adds condition based on the model primary keys. +// Usually it is the same as: +// +// Where("id = ?id") +func (q *Query) WherePK() *Query { + if !q.hasTableModel() { + q.err(errModelNil) + return q + } + + if err := q.tableModel.Table().checkPKs(); err != nil { + q.err(err) + return q + } + + switch q.tableModel.Kind() { + case reflect.Struct: + q.where = append(q.where, wherePKStructQuery{q}) + return q + case reflect.Slice: + q.joins = append(q.joins, joinPKSliceQuery{q: q}) + q.where = append(q.where, wherePKSliceQuery{q: q}) + q = q.OrderExpr(`"_data"."ordering" ASC`) + return q + } + + panic("not reached") +} + +func (q *Query) Join(join string, params ...interface{}) *Query { + j := &joinQuery{ + join: SafeQuery(join, params...), + } + q.joins = append(q.joins, j) + q.joinAppendOn = j.AppendOn + return q +} + +// JoinOn appends join condition to the last join. +func (q *Query) JoinOn(condition string, params ...interface{}) *Query { + if q.joinAppendOn == nil { + q.err(errors.New("pg: no joins to apply JoinOn")) + return q + } + q.joinAppendOn(&condAppender{ + sep: " AND ", + cond: condition, + params: params, + }) + return q +} + +func (q *Query) JoinOnOr(condition string, params ...interface{}) *Query { + if q.joinAppendOn == nil { + q.err(errors.New("pg: no joins to apply JoinOn")) + return q + } + q.joinAppendOn(&condAppender{ + sep: " OR ", + cond: condition, + params: params, + }) + return q +} + +func (q *Query) Group(columns ...string) *Query { + for _, column := range columns { + q.group = append(q.group, fieldAppender{column}) + } + return q +} + +func (q *Query) GroupExpr(group string, params ...interface{}) *Query { + q.group = append(q.group, SafeQuery(group, params...)) + return q +} + +func (q *Query) Having(having string, params ...interface{}) *Query { + q.having = append(q.having, SafeQuery(having, params...)) + return q +} + +func (q *Query) Union(other *Query) *Query { + return q.addUnion(" UNION ", other) +} + +func (q *Query) UnionAll(other *Query) *Query { + return q.addUnion(" UNION ALL ", other) +} + +func (q *Query) Intersect(other *Query) *Query { + return q.addUnion(" INTERSECT ", other) +} + +func (q *Query) IntersectAll(other *Query) *Query { + return q.addUnion(" INTERSECT ALL ", other) +} + +func (q *Query) Except(other *Query) *Query { + return q.addUnion(" EXCEPT ", other) +} + +func (q *Query) ExceptAll(other *Query) *Query { + return q.addUnion(" EXCEPT ALL ", other) +} + +func (q *Query) addUnion(expr string, other *Query) *Query { + q.union = append(q.union, &union{ + expr: expr, + query: other, + }) + return q +} + +// Order adds sort order to the Query quoting column name. Does not expand params like ?TableAlias etc. +// OrderExpr can be used to bypass quoting restriction or for params expansion. +func (q *Query) Order(orders ...string) *Query { +loop: + for _, order := range orders { + if order == "" { + continue + } + ind := strings.Index(order, " ") + if ind != -1 { + field := order[:ind] + sort := order[ind+1:] + switch internal.UpperString(sort) { + case "ASC", "DESC", "ASC NULLS FIRST", "DESC NULLS FIRST", + "ASC NULLS LAST", "DESC NULLS LAST": + q = q.OrderExpr("? ?", types.Ident(field), types.Safe(sort)) + continue loop + } + } + + q.order = append(q.order, fieldAppender{order}) + } + return q +} + +// Order adds sort order to the Query. +func (q *Query) OrderExpr(order string, params ...interface{}) *Query { + if order != "" { + q.order = append(q.order, SafeQuery(order, params...)) + } + return q +} + +func (q *Query) Limit(n int) *Query { + q.limit = n + return q +} + +func (q *Query) Offset(n int) *Query { + q.offset = n + return q +} + +func (q *Query) OnConflict(s string, params ...interface{}) *Query { + q.onConflict = SafeQuery(s, params...) + return q +} + +func (q *Query) onConflictDoUpdate() bool { + return q.onConflict != nil && + strings.HasSuffix(internal.UpperString(q.onConflict.query), "DO UPDATE") +} + +// Returning adds a RETURNING clause to the query. +// +// `Returning("NULL")` can be used to suppress default returning clause +// generated by go-pg for INSERT queries to get values for null columns. +func (q *Query) Returning(s string, params ...interface{}) *Query { + q.returning = append(q.returning, SafeQuery(s, params...)) + return q +} + +func (q *Query) For(s string, params ...interface{}) *Query { + q.selFor = SafeQuery(s, params...) + return q +} + +// Apply calls the fn passing the Query as an argument. +func (q *Query) Apply(fn func(*Query) (*Query, error)) *Query { + qq, err := fn(q) + if err != nil { + q.err(err) + return q + } + return qq +} + +// Count returns number of rows matching the query using count aggregate function. +func (q *Query) Count() (int, error) { + if q.stickyErr != nil { + return 0, q.stickyErr + } + + var count int + _, err := q.db.QueryOneContext( + q.ctx, Scan(&count), q.countSelectQuery("count(*)"), q.tableModel) + return count, err +} + +func (q *Query) countSelectQuery(column string) *SelectQuery { + return &SelectQuery{ + q: q, + count: column, + } +} + +// First sorts rows by primary key and selects the first row. +// It is a shortcut for: +// +// q.OrderExpr("id ASC").Limit(1) +func (q *Query) First() error { + table := q.tableModel.Table() + + if err := table.checkPKs(); err != nil { + return err + } + + b := appendColumns(nil, table.Alias, table.PKs) + return q.OrderExpr(internal.BytesToString(b)).Limit(1).Select() +} + +// Last sorts rows by primary key and selects the last row. +// It is a shortcut for: +// +// q.OrderExpr("id DESC").Limit(1) +func (q *Query) Last() error { + table := q.tableModel.Table() + + if err := table.checkPKs(); err != nil { + return err + } + + // TODO: fix for multi columns + b := appendColumns(nil, table.Alias, table.PKs) + b = append(b, " DESC"...) + return q.OrderExpr(internal.BytesToString(b)).Limit(1).Select() +} + +// Select selects the model. +func (q *Query) Select(values ...interface{}) error { + if q.stickyErr != nil { + return q.stickyErr + } + + model, err := q.newModel(values) + if err != nil { + return err + } + + res, err := q.query(q.ctx, model, NewSelectQuery(q)) + if err != nil { + return err + } + + if res.RowsReturned() > 0 { + if q.tableModel != nil { + if err := q.selectJoins(q.tableModel.GetJoins()); err != nil { + return err + } + } + } + + if err := model.AfterSelect(q.ctx); err != nil { + return err + } + + return nil +} + +func (q *Query) newModel(values []interface{}) (Model, error) { + if len(values) > 0 { + return newScanModel(values) + } + return q.tableModel, nil +} + +func (q *Query) query(ctx context.Context, model Model, query interface{}) (Result, error) { + if _, ok := model.(useQueryOne); ok { + return q.db.QueryOneContext(ctx, model, query, q.tableModel) + } + return q.db.QueryContext(ctx, model, query, q.tableModel) +} + +// SelectAndCount runs Select and Count in two goroutines, +// waits for them to finish and returns the result. If query limit is -1 +// it does not select any data and only counts the results. +func (q *Query) SelectAndCount(values ...interface{}) (count int, firstErr error) { + if q.stickyErr != nil { + return 0, q.stickyErr + } + + var wg sync.WaitGroup + var mu sync.Mutex + + if q.limit >= 0 { + wg.Add(1) + go func() { + defer wg.Done() + err := q.Select(values...) + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + } + + wg.Add(1) + go func() { + defer wg.Done() + var err error + count, err = q.Count() + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + + wg.Wait() + return count, firstErr +} + +// SelectAndCountEstimate runs Select and CountEstimate in two goroutines, +// waits for them to finish and returns the result. If query limit is -1 +// it does not select any data and only counts the results. +func (q *Query) SelectAndCountEstimate(threshold int, values ...interface{}) (count int, firstErr error) { + if q.stickyErr != nil { + return 0, q.stickyErr + } + + var wg sync.WaitGroup + var mu sync.Mutex + + if q.limit >= 0 { + wg.Add(1) + go func() { + defer wg.Done() + err := q.Select(values...) + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + } + + wg.Add(1) + go func() { + defer wg.Done() + var err error + count, err = q.CountEstimate(threshold) + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + + wg.Wait() + return count, firstErr +} + +// ForEach calls the function for each row returned by the query +// without loading all rows into the memory. +// +// Function can accept a struct, a pointer to a struct, an orm.Model, +// or values for the columns in a row. Function must return an error. +func (q *Query) ForEach(fn interface{}) error { + m := newFuncModel(fn) + return q.Select(m) +} + +func (q *Query) forEachHasOneJoin(fn func(*join) error) error { + if q.tableModel == nil { + return nil + } + return q._forEachHasOneJoin(fn, q.tableModel.GetJoins()) +} + +func (q *Query) _forEachHasOneJoin(fn func(*join) error, joins []join) error { + for i := range joins { + j := &joins[i] + switch j.Rel.Type { + case HasOneRelation, BelongsToRelation: + err := fn(j) + if err != nil { + return err + } + + err = q._forEachHasOneJoin(fn, j.JoinModel.GetJoins()) + if err != nil { + return err + } + } + } + return nil +} + +func (q *Query) selectJoins(joins []join) error { + var err error + for i := range joins { + j := &joins[i] + if j.Rel.Type == HasOneRelation || j.Rel.Type == BelongsToRelation { + err = q.selectJoins(j.JoinModel.GetJoins()) + } else { + err = j.Select(q.db.Formatter(), q.New()) + } + if err != nil { + return err + } + } + return nil +} + +// Insert inserts the model. +func (q *Query) Insert(values ...interface{}) (Result, error) { + if q.stickyErr != nil { + return nil, q.stickyErr + } + + model, err := q.newModel(values) + if err != nil { + return nil, err + } + + ctx := q.ctx + + if q.tableModel != nil && q.tableModel.Table().hasFlag(beforeInsertHookFlag) { + ctx, err = q.tableModel.BeforeInsert(ctx) + if err != nil { + return nil, err + } + } + + query := NewInsertQuery(q) + res, err := q.returningQuery(ctx, model, query) + if err != nil { + return nil, err + } + + if q.tableModel != nil { + if err := q.tableModel.AfterInsert(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +// SelectOrInsert selects the model inserting one if it does not exist. +// It returns true when model was inserted. +func (q *Query) SelectOrInsert(values ...interface{}) (inserted bool, _ error) { + if q.stickyErr != nil { + return false, q.stickyErr + } + + var insertq *Query + var insertErr error + for i := 0; i < 5; i++ { + if i >= 2 { + dur := internal.RetryBackoff(i-2, 250*time.Millisecond, 5*time.Second) + if err := internal.Sleep(q.ctx, dur); err != nil { + return false, err + } + } + + err := q.Select(values...) + if err == nil { + return false, nil + } + if err != internal.ErrNoRows { + return false, err + } + + if insertq == nil { + insertq = q + if len(insertq.columns) > 0 { + insertq = insertq.Clone() + insertq.columns = nil + } + } + + res, err := insertq.Insert(values...) + if err != nil { + insertErr = err + if err == internal.ErrNoRows { + continue + } + if pgErr, ok := err.(internal.PGError); ok { + if pgErr.IntegrityViolation() { + continue + } + if pgErr.Field('C') == "55000" { + // Retry on "#55000 attempted to delete invisible tuple". + continue + } + } + return false, err + } + if res.RowsAffected() == 1 { + return true, nil + } + } + + err := fmt.Errorf( + "pg: SelectOrInsert: select returns no rows (insert fails with err=%q)", + insertErr) + return false, err +} + +// Update updates the model. +func (q *Query) Update(scan ...interface{}) (Result, error) { + return q.update(scan, false) +} + +// Update updates the model omitting fields with zero values such as: +// - empty string, +// - 0, +// - zero time, +// - empty map or slice, +// - byte array with all zeroes, +// - nil ptr, +// - types with method `IsZero() == true`. +func (q *Query) UpdateNotZero(scan ...interface{}) (Result, error) { + return q.update(scan, true) +} + +func (q *Query) update(values []interface{}, omitZero bool) (Result, error) { + if q.stickyErr != nil { + return nil, q.stickyErr + } + + model, err := q.newModel(values) + if err != nil { + return nil, err + } + + c := q.ctx + + if q.tableModel != nil { + c, err = q.tableModel.BeforeUpdate(c) + if err != nil { + return nil, err + } + } + + query := NewUpdateQuery(q, omitZero) + res, err := q.returningQuery(c, model, query) + if err != nil { + return nil, err + } + + if q.tableModel != nil { + err = q.tableModel.AfterUpdate(c) + if err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *Query) returningQuery(c context.Context, model Model, query interface{}) (Result, error) { + if !q.hasReturning() { + return q.db.QueryContext(c, model, query, q.tableModel) + } + if _, ok := model.(useQueryOne); ok { + return q.db.QueryOneContext(c, model, query, q.tableModel) + } + return q.db.QueryContext(c, model, query, q.tableModel) +} + +// Delete deletes the model. When model has deleted_at column the row +// is soft deleted instead. +func (q *Query) Delete(values ...interface{}) (Result, error) { + if q.tableModel == nil { + return q.ForceDelete(values...) + } + + table := q.tableModel.Table() + if table.SoftDeleteField == nil { + return q.ForceDelete(values...) + } + + clone := q.Clone() + if q.tableModel.IsNil() { + if table.SoftDeleteField.SQLType == pgTypeBigint { + clone = clone.Set("? = ?", table.SoftDeleteField.Column, time.Now().UnixNano()) + } else { + clone = clone.Set("? = ?", table.SoftDeleteField.Column, time.Now()) + } + } else { + if err := clone.tableModel.setSoftDeleteField(); err != nil { + return nil, err + } + clone = clone.Column(table.SoftDeleteField.SQLName) + } + return clone.Update(values...) +} + +// Delete forces delete of the model with deleted_at column. +func (q *Query) ForceDelete(values ...interface{}) (Result, error) { + if q.stickyErr != nil { + return nil, q.stickyErr + } + if q.tableModel == nil { + return nil, errModelNil + } + q = q.withFlag(deletedFlag) + + model, err := q.newModel(values) + if err != nil { + return nil, err + } + + ctx := q.ctx + + if q.tableModel != nil { + ctx, err = q.tableModel.BeforeDelete(ctx) + if err != nil { + return nil, err + } + } + + res, err := q.returningQuery(ctx, model, NewDeleteQuery(q)) + if err != nil { + return nil, err + } + + if q.tableModel != nil { + if err := q.tableModel.AfterDelete(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *Query) CreateTable(opt *CreateTableOptions) error { + _, err := q.db.ExecContext(q.ctx, NewCreateTableQuery(q, opt)) + return err +} + +func (q *Query) DropTable(opt *DropTableOptions) error { + _, err := q.db.ExecContext(q.ctx, NewDropTableQuery(q, opt)) + return err +} + +func (q *Query) CreateComposite(opt *CreateCompositeOptions) error { + _, err := q.db.ExecContext(q.ctx, NewCreateCompositeQuery(q, opt)) + return err +} + +func (q *Query) DropComposite(opt *DropCompositeOptions) error { + _, err := q.db.ExecContext(q.ctx, NewDropCompositeQuery(q, opt)) + return err +} + +// Exec is an alias for DB.Exec. +func (q *Query) Exec(query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.ExecContext(q.ctx, query, params...) +} + +// ExecOne is an alias for DB.ExecOne. +func (q *Query) ExecOne(query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.ExecOneContext(q.ctx, query, params...) +} + +// Query is an alias for DB.Query. +func (q *Query) Query(model, query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.QueryContext(q.ctx, model, query, params...) +} + +// QueryOne is an alias for DB.QueryOne. +func (q *Query) QueryOne(model, query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.QueryOneContext(q.ctx, model, query, params...) +} + +// CopyFrom is an alias from DB.CopyFrom. +func (q *Query) CopyFrom(r io.Reader, query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.CopyFrom(r, query, params...) +} + +// CopyTo is an alias from DB.CopyTo. +func (q *Query) CopyTo(w io.Writer, query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.CopyTo(w, query, params...) +} + +var _ QueryAppender = (*Query)(nil) + +func (q *Query) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + return NewSelectQuery(q).AppendQuery(fmter, b) +} + +// Exists returns true or false depending if there are any rows matching the query. +func (q *Query) Exists() (bool, error) { + q = q.Clone() // copy to not change original query + q.columns = []QueryAppender{SafeQuery("1")} + q.order = nil + q.limit = 1 + res, err := q.db.ExecContext(q.ctx, NewSelectQuery(q)) + if err != nil { + return false, err + } + return res.RowsAffected() > 0, nil +} + +func (q *Query) hasTableModel() bool { + return q.tableModel != nil && !q.tableModel.IsNil() +} + +func (q *Query) hasExplicitTableModel() bool { + return q.tableModel != nil && !q.hasFlag(implicitModelFlag) +} + +func (q *Query) modelHasTableName() bool { + return q.hasExplicitTableModel() && q.tableModel.Table().SQLName != "" +} + +func (q *Query) modelHasTableAlias() bool { + return q.hasExplicitTableModel() && q.tableModel.Table().Alias != "" +} + +func (q *Query) hasTables() bool { + return q.modelHasTableName() || len(q.tables) > 0 +} + +func (q *Query) appendFirstTable(fmter QueryFormatter, b []byte) ([]byte, error) { + if q.modelHasTableName() { + return fmter.FormatQuery(b, string(q.tableModel.Table().SQLName)), nil + } + if len(q.tables) > 0 { + return q.tables[0].AppendQuery(fmter, b) + } + return b, nil +} + +func (q *Query) appendFirstTableWithAlias(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if q.modelHasTableName() { + table := q.tableModel.Table() + b = fmter.FormatQuery(b, string(table.SQLName)) + if table.Alias != table.SQLName { + b = append(b, " AS "...) + b = append(b, table.Alias...) + } + return b, nil + } + + if len(q.tables) > 0 { + b, err = q.tables[0].AppendQuery(fmter, b) + if err != nil { + return nil, err + } + if q.modelHasTableAlias() { + table := q.tableModel.Table() + if table.Alias != table.SQLName { + b = append(b, " AS "...) + b = append(b, table.Alias...) + } + } + } + + return b, nil +} + +func (q *Query) hasMultiTables() bool { + if q.modelHasTableName() { + return len(q.tables) >= 1 + } + return len(q.tables) >= 2 +} + +func (q *Query) appendOtherTables(fmter QueryFormatter, b []byte) (_ []byte, err error) { + tables := q.tables + if !q.modelHasTableName() { + tables = tables[1:] + } + for i, f := range tables { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *Query) appendColumns(fmter QueryFormatter, b []byte) (_ []byte, err error) { + for i, f := range q.columns { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *Query) mustAppendWhere(fmter QueryFormatter, b []byte) ([]byte, error) { + if len(q.where) == 0 { + err := errors.New( + "pg: Update and Delete queries require Where clause (try WherePK)") + return nil, err + } + return q.appendWhere(fmter, b) +} + +func (q *Query) appendWhere(fmter QueryFormatter, b []byte) (_ []byte, err error) { + isSoftDelete := q.isSoftDelete() + + if len(q.where) > 0 { + if isSoftDelete { + b = append(b, '(') + } + + b, err = q._appendWhere(fmter, b, q.where) + if err != nil { + return nil, err + } + + if isSoftDelete { + b = append(b, ')') + } + } + + if isSoftDelete { + if len(q.where) > 0 { + b = append(b, " AND "...) + } + b = append(b, q.tableModel.Table().Alias...) + b = q.appendSoftDelete(b) + } + + return b, nil +} + +func (q *Query) appendSoftDelete(b []byte) []byte { + b = append(b, '.') + b = append(b, q.tableModel.Table().SoftDeleteField.Column...) + if q.hasFlag(deletedFlag) { + b = append(b, " IS NOT NULL"...) + } else { + b = append(b, " IS NULL"...) + } + return b +} + +func (q *Query) appendUpdWhere(fmter QueryFormatter, b []byte) ([]byte, error) { + return q._appendWhere(fmter, b, q.updWhere) +} + +func (q *Query) _appendWhere( + fmter QueryFormatter, b []byte, where []queryWithSepAppender, +) (_ []byte, err error) { + for i, f := range where { + start := len(b) + + if i > 0 { + b = f.AppendSep(b) + } + + before := len(b) + + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + if len(b) == before { + b = b[:start] + } + } + return b, nil +} + +func (q *Query) appendSet(fmter QueryFormatter, b []byte) (_ []byte, err error) { + b = append(b, " SET "...) + for i, f := range q.set { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *Query) hasReturning() bool { + if len(q.returning) == 0 { + return false + } + if len(q.returning) == 1 { + switch q.returning[0].query { + case "null", "NULL": + return false + } + } + return true +} + +func (q *Query) appendReturning(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if !q.hasReturning() { + return b, nil + } + + b = append(b, " RETURNING "...) + for i, f := range q.returning { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *Query) appendWith(fmter QueryFormatter, b []byte) (_ []byte, err error) { + b = append(b, "WITH "...) + for i, with := range q.with { + if i > 0 { + b = append(b, ", "...) + } + b = types.AppendIdent(b, with.name, 1) + b = append(b, " AS ("...) + + b, err = with.query.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, ')') + } + b = append(b, ' ') + return b, nil +} + +func (q *Query) isSliceModelWithData() bool { + if !q.hasTableModel() { + return false + } + m, ok := q.tableModel.(*sliceTableModel) + return ok && m.sliceLen > 0 +} + +//------------------------------------------------------------------------------ + +type wherePKStructQuery struct { + q *Query +} + +var _ queryWithSepAppender = (*wherePKStructQuery)(nil) + +func (wherePKStructQuery) AppendSep(b []byte) []byte { + return append(b, " AND "...) +} + +func (q wherePKStructQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + table := q.q.tableModel.Table() + value := q.q.tableModel.Value() + return appendColumnAndValue(fmter, b, value, table.Alias, table.PKs), nil +} + +func appendColumnAndValue( + fmter QueryFormatter, b []byte, v reflect.Value, alias types.Safe, fields []*Field, +) []byte { + isPlaceholder := isTemplateFormatter(fmter) + for i, f := range fields { + if i > 0 { + b = append(b, " AND "...) + } + b = append(b, alias...) + b = append(b, '.') + b = append(b, f.Column...) + b = append(b, " = "...) + if isPlaceholder { + b = append(b, '?') + } else { + b = f.AppendValue(b, v, 1) + } + } + return b +} + +//------------------------------------------------------------------------------ + +type wherePKSliceQuery struct { + q *Query +} + +var _ queryWithSepAppender = (*wherePKSliceQuery)(nil) + +func (wherePKSliceQuery) AppendSep(b []byte) []byte { + return append(b, " AND "...) +} + +func (q wherePKSliceQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + table := q.q.tableModel.Table() + + for i, f := range table.PKs { + if i > 0 { + b = append(b, " AND "...) + } + b = append(b, table.Alias...) + b = append(b, '.') + b = append(b, f.Column...) + b = append(b, " = "...) + b = append(b, `"_data".`...) + b = append(b, f.Column...) + } + + return b, nil +} + +type joinPKSliceQuery struct { + q *Query +} + +var _ QueryAppender = (*joinPKSliceQuery)(nil) + +func (q joinPKSliceQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + table := q.q.tableModel.Table() + slice := q.q.tableModel.Value() + + b = append(b, " JOIN (VALUES "...) + + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, ", "...) + } + + el := indirect(slice.Index(i)) + + b = append(b, '(') + for i, f := range table.PKs { + if i > 0 { + b = append(b, ", "...) + } + + b = f.AppendValue(b, el, 1) + + if f.UserSQLType != "" { + b = append(b, "::"...) + b = append(b, f.SQLType...) + } + } + + b = append(b, ", "...) + b = strconv.AppendInt(b, int64(i), 10) + + b = append(b, ')') + } + + b = append(b, `) AS "_data" (`...) + + for i, f := range table.PKs { + if i > 0 { + b = append(b, ", "...) + } + b = append(b, f.Column...) + } + + b = append(b, ", "...) + b = append(b, `"ordering"`...) + b = append(b, ") ON TRUE"...) + + return b, nil +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/relation.go b/vendor/github.com/go-pg/pg/v10/orm/relation.go new file mode 100644 index 000000000..28d915bcd --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/relation.go @@ -0,0 +1,33 @@ +package orm + +import ( + "fmt" + + "github.com/go-pg/pg/v10/types" +) + +const ( + InvalidRelation = iota + HasOneRelation + BelongsToRelation + HasManyRelation + Many2ManyRelation +) + +type Relation struct { + Type int + Field *Field + JoinTable *Table + BaseFKs []*Field + JoinFKs []*Field + Polymorphic *Field + + M2MTableName types.Safe + M2MTableAlias types.Safe + M2MBaseFKs []string + M2MJoinFKs []string +} + +func (r *Relation) String() string { + return fmt.Sprintf("relation=%s", r.Field.GoName) +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/result.go b/vendor/github.com/go-pg/pg/v10/orm/result.go new file mode 100644 index 000000000..9d82815ef --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/result.go @@ -0,0 +1,14 @@ +package orm + +// Result summarizes an executed SQL command. +type Result interface { + Model() Model + + // RowsAffected returns the number of rows affected by SELECT, INSERT, UPDATE, + // or DELETE queries. It returns -1 if query can't possibly affect any rows, + // e.g. in case of CREATE or SHOW queries. + RowsAffected() int + + // RowsReturned returns the number of rows returned by the query. + RowsReturned() int +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/select.go b/vendor/github.com/go-pg/pg/v10/orm/select.go new file mode 100644 index 000000000..d3b38742d --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/select.go @@ -0,0 +1,346 @@ +package orm + +import ( + "bytes" + "fmt" + "strconv" + "strings" + + "github.com/go-pg/pg/v10/types" +) + +type SelectQuery struct { + q *Query + count string +} + +var ( + _ QueryAppender = (*SelectQuery)(nil) + _ QueryCommand = (*SelectQuery)(nil) +) + +func NewSelectQuery(q *Query) *SelectQuery { + return &SelectQuery{ + q: q, + } +} + +func (q *SelectQuery) String() string { + b, err := q.AppendQuery(defaultFmter, nil) + if err != nil { + panic(err) + } + return string(b) +} + +func (q *SelectQuery) Operation() QueryOp { + return SelectOp +} + +func (q *SelectQuery) Clone() QueryCommand { + return &SelectQuery{ + q: q.q.Clone(), + count: q.count, + } +} + +func (q *SelectQuery) Query() *Query { + return q.q +} + +func (q *SelectQuery) AppendTemplate(b []byte) ([]byte, error) { + return q.AppendQuery(dummyFormatter{}, b) +} + +func (q *SelectQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { //nolint:gocyclo + if q.q.stickyErr != nil { + return nil, q.q.stickyErr + } + + cteCount := q.count != "" && (len(q.q.group) > 0 || q.isDistinct()) + if cteCount { + b = append(b, `WITH "_count_wrapper" AS (`...) + } + + if len(q.q.with) > 0 { + b, err = q.q.appendWith(fmter, b) + if err != nil { + return nil, err + } + } + + if len(q.q.union) > 0 { + b = append(b, '(') + } + + b = append(b, "SELECT "...) + + if len(q.q.distinctOn) > 0 { + b = append(b, "DISTINCT ON ("...) + for i, app := range q.q.distinctOn { + if i > 0 { + b = append(b, ", "...) + } + b, err = app.AppendQuery(fmter, b) + } + b = append(b, ") "...) + } else if q.q.distinctOn != nil { + b = append(b, "DISTINCT "...) + } + + if q.count != "" && !cteCount { + b = append(b, q.count...) + } else { + b, err = q.appendColumns(fmter, b) + if err != nil { + return nil, err + } + } + + if q.q.hasTables() { + b = append(b, " FROM "...) + b, err = q.appendTables(fmter, b) + if err != nil { + return nil, err + } + } + + err = q.q.forEachHasOneJoin(func(j *join) error { + b = append(b, ' ') + b, err = j.appendHasOneJoin(fmter, b, q.q) + return err + }) + if err != nil { + return nil, err + } + + for _, j := range q.q.joins { + b, err = j.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + if len(q.q.where) > 0 || q.q.isSoftDelete() { + b = append(b, " WHERE "...) + b, err = q.q.appendWhere(fmter, b) + if err != nil { + return nil, err + } + } + + if len(q.q.group) > 0 { + b = append(b, " GROUP BY "...) + for i, f := range q.q.group { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + } + + if len(q.q.having) > 0 { + b = append(b, " HAVING "...) + for i, f := range q.q.having { + if i > 0 { + b = append(b, " AND "...) + } + b = append(b, '(') + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ')') + } + } + + if q.count == "" { + if len(q.q.order) > 0 { + b = append(b, " ORDER BY "...) + for i, f := range q.q.order { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + } + + if q.q.limit != 0 { + b = append(b, " LIMIT "...) + b = strconv.AppendInt(b, int64(q.q.limit), 10) + } + + if q.q.offset != 0 { + b = append(b, " OFFSET "...) + b = strconv.AppendInt(b, int64(q.q.offset), 10) + } + + if q.q.selFor != nil { + b = append(b, " FOR "...) + b, err = q.q.selFor.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + } else if cteCount { + b = append(b, `) SELECT `...) + b = append(b, q.count...) + b = append(b, ` FROM "_count_wrapper"`...) + } + + if len(q.q.union) > 0 { + b = append(b, ")"...) + + for _, u := range q.q.union { + b = append(b, u.expr...) + b = append(b, '(') + b, err = u.query.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ")"...) + } + } + + return b, q.q.stickyErr +} + +func (q SelectQuery) appendColumns(fmter QueryFormatter, b []byte) (_ []byte, err error) { + start := len(b) + + switch { + case q.q.columns != nil: + b, err = q.q.appendColumns(fmter, b) + if err != nil { + return nil, err + } + case q.q.hasExplicitTableModel(): + table := q.q.tableModel.Table() + if len(table.Fields) > 10 && isTemplateFormatter(fmter) { + b = append(b, table.Alias...) + b = append(b, '.') + b = types.AppendString(b, fmt.Sprintf("%d columns", len(table.Fields)), 2) + } else { + b = appendColumns(b, table.Alias, table.Fields) + } + default: + b = append(b, '*') + } + + err = q.q.forEachHasOneJoin(func(j *join) error { + if len(b) != start { + b = append(b, ", "...) + start = len(b) + } + + b = j.appendHasOneColumns(b) + return nil + }) + if err != nil { + return nil, err + } + + b = bytes.TrimSuffix(b, []byte(", ")) + + return b, nil +} + +func (q *SelectQuery) isDistinct() bool { + if q.q.distinctOn != nil { + return true + } + for _, column := range q.q.columns { + column, ok := column.(*SafeQueryAppender) + if ok { + if strings.Contains(column.query, "DISTINCT") || + strings.Contains(column.query, "distinct") { + return true + } + } + } + return false +} + +func (q *SelectQuery) appendTables(fmter QueryFormatter, b []byte) (_ []byte, err error) { + tables := q.q.tables + + if q.q.modelHasTableName() { + table := q.q.tableModel.Table() + b = fmter.FormatQuery(b, string(table.SQLNameForSelects)) + if table.Alias != "" { + b = append(b, " AS "...) + b = append(b, table.Alias...) + } + + if len(tables) > 0 { + b = append(b, ", "...) + } + } else if len(tables) > 0 { + b, err = tables[0].AppendQuery(fmter, b) + if err != nil { + return nil, err + } + if q.q.modelHasTableAlias() { + b = append(b, " AS "...) + b = append(b, q.q.tableModel.Table().Alias...) + } + + tables = tables[1:] + if len(tables) > 0 { + b = append(b, ", "...) + } + } + + for i, f := range tables { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +//------------------------------------------------------------------------------ + +type joinQuery struct { + join *SafeQueryAppender + on []*condAppender +} + +func (j *joinQuery) AppendOn(app *condAppender) { + j.on = append(j.on, app) +} + +func (j *joinQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + b = append(b, ' ') + + b, err = j.join.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + if len(j.on) > 0 { + b = append(b, " ON "...) + for i, on := range j.on { + if i > 0 { + b = on.AppendSep(b) + } + b, err = on.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + } + + return b, nil +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/table.go b/vendor/github.com/go-pg/pg/v10/orm/table.go new file mode 100644 index 000000000..8b57bbfc0 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/table.go @@ -0,0 +1,1560 @@ +package orm + +import ( + "database/sql" + "encoding/json" + "fmt" + "net" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/jinzhu/inflection" + "github.com/vmihailenco/tagparser" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/internal/pool" + "github.com/go-pg/pg/v10/pgjson" + "github.com/go-pg/pg/v10/types" + "github.com/go-pg/zerochecker" +) + +const ( + beforeScanHookFlag = uint16(1) << iota + afterScanHookFlag + afterSelectHookFlag + beforeInsertHookFlag + afterInsertHookFlag + beforeUpdateHookFlag + afterUpdateHookFlag + beforeDeleteHookFlag + afterDeleteHookFlag + discardUnknownColumnsFlag +) + +var ( + timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + nullTimeType = reflect.TypeOf((*types.NullTime)(nil)).Elem() + sqlNullTimeType = reflect.TypeOf((*sql.NullTime)(nil)).Elem() + ipType = reflect.TypeOf((*net.IP)(nil)).Elem() + ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() + scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + nullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem() + nullFloatType = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem() + nullIntType = reflect.TypeOf((*sql.NullInt64)(nil)).Elem() + nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem() + jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() +) + +var tableNameInflector = inflection.Plural + +// SetTableNameInflector overrides the default func that pluralizes +// model name to get table name, e.g. my_article becomes my_articles. +func SetTableNameInflector(fn func(string) string) { + tableNameInflector = fn +} + +// Table represents a SQL table created from Go struct. +type Table struct { + Type reflect.Type + zeroStruct reflect.Value + + TypeName string + Alias types.Safe + ModelName string + + SQLName types.Safe + SQLNameForSelects types.Safe + + Tablespace types.Safe + + PartitionBy string + + allFields []*Field // read only + skippedFields []*Field + + Fields []*Field // PKs + DataFields + PKs []*Field + DataFields []*Field + fieldsMapMu sync.RWMutex + FieldsMap map[string]*Field + + Methods map[string]*Method + Relations map[string]*Relation + Unique map[string][]*Field + + SoftDeleteField *Field + SetSoftDeleteField func(fv reflect.Value) error + + flags uint16 +} + +func newTable(typ reflect.Type) *Table { + t := new(Table) + t.Type = typ + t.zeroStruct = reflect.New(t.Type).Elem() + t.TypeName = internal.ToExported(t.Type.Name()) + t.ModelName = internal.Underscore(t.Type.Name()) + tableName := tableNameInflector(t.ModelName) + t.setName(quoteIdent(tableName)) + t.Alias = quoteIdent(t.ModelName) + + typ = reflect.PtrTo(t.Type) + if typ.Implements(beforeScanHookType) { + t.setFlag(beforeScanHookFlag) + } + if typ.Implements(afterScanHookType) { + t.setFlag(afterScanHookFlag) + } + if typ.Implements(afterSelectHookType) { + t.setFlag(afterSelectHookFlag) + } + if typ.Implements(beforeInsertHookType) { + t.setFlag(beforeInsertHookFlag) + } + if typ.Implements(afterInsertHookType) { + t.setFlag(afterInsertHookFlag) + } + if typ.Implements(beforeUpdateHookType) { + t.setFlag(beforeUpdateHookFlag) + } + if typ.Implements(afterUpdateHookType) { + t.setFlag(afterUpdateHookFlag) + } + if typ.Implements(beforeDeleteHookType) { + t.setFlag(beforeDeleteHookFlag) + } + if typ.Implements(afterDeleteHookType) { + t.setFlag(afterDeleteHookFlag) + } + + return t +} + +func (t *Table) init1() { + t.initFields() + t.initMethods() +} + +func (t *Table) init2() { + t.initInlines() + t.initRelations() + t.skippedFields = nil +} + +func (t *Table) setName(name types.Safe) { + t.SQLName = name + t.SQLNameForSelects = name + if t.Alias == "" { + t.Alias = name + } +} + +func (t *Table) String() string { + return "model=" + t.TypeName +} + +func (t *Table) setFlag(flag uint16) { + t.flags |= flag +} + +func (t *Table) hasFlag(flag uint16) bool { + if t == nil { + return false + } + return t.flags&flag != 0 +} + +func (t *Table) checkPKs() error { + if len(t.PKs) == 0 { + return fmt.Errorf("pg: %s does not have primary keys", t) + } + return nil +} + +func (t *Table) mustSoftDelete() error { + if t.SoftDeleteField == nil { + return fmt.Errorf("pg: %s does not support soft deletes", t) + } + return nil +} + +func (t *Table) AddField(field *Field) { + t.Fields = append(t.Fields, field) + if field.hasFlag(PrimaryKeyFlag) { + t.PKs = append(t.PKs, field) + } else { + t.DataFields = append(t.DataFields, field) + } + t.FieldsMap[field.SQLName] = field +} + +func (t *Table) RemoveField(field *Field) { + t.Fields = removeField(t.Fields, field) + if field.hasFlag(PrimaryKeyFlag) { + t.PKs = removeField(t.PKs, field) + } else { + t.DataFields = removeField(t.DataFields, field) + } + delete(t.FieldsMap, field.SQLName) +} + +func removeField(fields []*Field, field *Field) []*Field { + for i, f := range fields { + if f == field { + fields = append(fields[:i], fields[i+1:]...) + } + } + return fields +} + +func (t *Table) getField(name string) *Field { + t.fieldsMapMu.RLock() + field := t.FieldsMap[name] + t.fieldsMapMu.RUnlock() + return field +} + +func (t *Table) HasField(name string) bool { + _, ok := t.FieldsMap[name] + return ok +} + +func (t *Table) GetField(name string) (*Field, error) { + field, ok := t.FieldsMap[name] + if !ok { + return nil, fmt.Errorf("pg: %s does not have column=%s", t, name) + } + return field, nil +} + +func (t *Table) AppendParam(b []byte, strct reflect.Value, name string) ([]byte, bool) { + field, ok := t.FieldsMap[name] + if ok { + b = field.AppendValue(b, strct, 1) + return b, true + } + + method, ok := t.Methods[name] + if ok { + b = method.AppendValue(b, strct.Addr(), 1) + return b, true + } + + return b, false +} + +func (t *Table) initFields() { + t.Fields = make([]*Field, 0, t.Type.NumField()) + t.FieldsMap = make(map[string]*Field, t.Type.NumField()) + t.addFields(t.Type, nil) +} + +func (t *Table) addFields(typ reflect.Type, baseIndex []int) { + for i := 0; i < typ.NumField(); i++ { + f := typ.Field(i) + + // Make a copy so slice is not shared between fields. + index := make([]int, len(baseIndex)) + copy(index, baseIndex) + + if f.Anonymous { + if f.Tag.Get("sql") == "-" || f.Tag.Get("pg") == "-" { + continue + } + + fieldType := indirectType(f.Type) + if fieldType.Kind() != reflect.Struct { + continue + } + t.addFields(fieldType, append(index, f.Index...)) + + pgTag := tagparser.Parse(f.Tag.Get("pg")) + if _, inherit := pgTag.Options["inherit"]; inherit { + embeddedTable := _tables.get(fieldType, true) + t.TypeName = embeddedTable.TypeName + t.SQLName = embeddedTable.SQLName + t.SQLNameForSelects = embeddedTable.SQLNameForSelects + t.Alias = embeddedTable.Alias + t.ModelName = embeddedTable.ModelName + } + + continue + } + + field := t.newField(f, index) + if field != nil { + t.AddField(field) + } + } +} + +//nolint +func (t *Table) newField(f reflect.StructField, index []int) *Field { + pgTag := tagparser.Parse(f.Tag.Get("pg")) + + switch f.Name { + case "tableName": + if len(index) > 0 { + return nil + } + + if isKnownTableOption(pgTag.Name) { + internal.Warn.Printf( + "%s.%s tag name %q is also an option name; is it a mistake?", + t.TypeName, f.Name, pgTag.Name, + ) + } + + for name := range pgTag.Options { + if !isKnownTableOption(name) { + internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) + } + } + + if tableSpace, ok := pgTag.Options["tablespace"]; ok { + s, _ := tagparser.Unquote(tableSpace) + t.Tablespace = quoteIdent(s) + } + + partitionBy, ok := pgTag.Options["partition_by"] + if !ok { + partitionBy, ok = pgTag.Options["partitionBy"] + if ok { + internal.Deprecated.Printf("partitionBy is renamed to partition_by") + } + } + if ok { + s, _ := tagparser.Unquote(partitionBy) + t.PartitionBy = s + } + + if pgTag.Name == "_" { + t.setName("") + } else if pgTag.Name != "" { + s, _ := tagparser.Unquote(pgTag.Name) + t.setName(types.Safe(quoteTableName(s))) + } + + if s, ok := pgTag.Options["select"]; ok { + s, _ = tagparser.Unquote(s) + t.SQLNameForSelects = types.Safe(quoteTableName(s)) + } + + if v, ok := pgTag.Options["alias"]; ok { + v, _ = tagparser.Unquote(v) + t.Alias = quoteIdent(v) + } + + pgTag := tagparser.Parse(f.Tag.Get("pg")) + if _, ok := pgTag.Options["discard_unknown_columns"]; ok { + t.setFlag(discardUnknownColumnsFlag) + } + + return nil + } + + if f.PkgPath != "" { + return nil + } + + sqlName := internal.Underscore(f.Name) + + if pgTag.Name != sqlName && isKnownFieldOption(pgTag.Name) { + internal.Warn.Printf( + "%s.%s tag name %q is also an option name; is it a mistake?", + t.TypeName, f.Name, pgTag.Name, + ) + } + + for name := range pgTag.Options { + if !isKnownFieldOption(name) { + internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) + } + } + + skip := pgTag.Name == "-" + if !skip && pgTag.Name != "" { + sqlName = pgTag.Name + } + + index = append(index, f.Index...) + if field := t.getField(sqlName); field != nil { + if indexEqual(field.Index, index) { + return field + } + t.RemoveField(field) + } + + field := &Field{ + Field: f, + Type: indirectType(f.Type), + + GoName: f.Name, + SQLName: sqlName, + Column: quoteIdent(sqlName), + + Index: index, + } + + if _, ok := pgTag.Options["notnull"]; ok { + field.setFlag(NotNullFlag) + } + if v, ok := pgTag.Options["unique"]; ok { + if v == "" { + field.setFlag(UniqueFlag) + } + // Split the value by comma, this will allow multiple names to be specified. + // We can use this to create multiple named unique constraints where a single column + // might be included in multiple constraints. + v, _ = tagparser.Unquote(v) + for _, uniqueName := range strings.Split(v, ",") { + if t.Unique == nil { + t.Unique = make(map[string][]*Field) + } + t.Unique[uniqueName] = append(t.Unique[uniqueName], field) + } + } + if v, ok := pgTag.Options["default"]; ok { + v, ok = tagparser.Unquote(v) + if ok { + field.Default = types.Safe(types.AppendString(nil, v, 1)) + } else { + field.Default = types.Safe(v) + } + } + + //nolint + if _, ok := pgTag.Options["pk"]; ok { + field.setFlag(PrimaryKeyFlag) + } else if strings.HasSuffix(field.SQLName, "_id") || + strings.HasSuffix(field.SQLName, "_uuid") { + field.setFlag(ForeignKeyFlag) + } else if strings.HasPrefix(field.SQLName, "fk_") { + field.setFlag(ForeignKeyFlag) + } else if len(t.PKs) == 0 && !pgTag.HasOption("nopk") { + switch field.SQLName { + case "id", "uuid", "pk_" + t.ModelName: + field.setFlag(PrimaryKeyFlag) + } + } + + if _, ok := pgTag.Options["use_zero"]; ok { + field.setFlag(UseZeroFlag) + } + if _, ok := pgTag.Options["array"]; ok { + field.setFlag(ArrayFlag) + } + + field.SQLType = fieldSQLType(field, pgTag) + if strings.HasSuffix(field.SQLType, "[]") { + field.setFlag(ArrayFlag) + } + + if v, ok := pgTag.Options["on_delete"]; ok { + field.OnDelete = v + } + + if v, ok := pgTag.Options["on_update"]; ok { + field.OnUpdate = v + } + + if _, ok := pgTag.Options["composite"]; ok { + field.append = compositeAppender(f.Type) + field.scan = compositeScanner(f.Type) + } else if _, ok := pgTag.Options["json_use_number"]; ok { + field.append = types.Appender(f.Type) + field.scan = scanJSONValue + } else if field.hasFlag(ArrayFlag) { + field.append = types.ArrayAppender(f.Type) + field.scan = types.ArrayScanner(f.Type) + } else if _, ok := pgTag.Options["hstore"]; ok { + field.append = types.HstoreAppender(f.Type) + field.scan = types.HstoreScanner(f.Type) + } else if field.SQLType == pgTypeBigint && field.Type.Kind() == reflect.Uint64 { + if f.Type.Kind() == reflect.Ptr { + field.append = appendUintPtrAsInt + } else { + field.append = appendUintAsInt + } + field.scan = types.Scanner(f.Type) + } else if _, ok := pgTag.Options["msgpack"]; ok { + field.append = msgpackAppender(f.Type) + field.scan = msgpackScanner(f.Type) + } else { + field.append = types.Appender(f.Type) + field.scan = types.Scanner(f.Type) + } + field.isZero = zerochecker.Checker(f.Type) + + if v, ok := pgTag.Options["alias"]; ok { + v, _ = tagparser.Unquote(v) + t.FieldsMap[v] = field + } + + t.allFields = append(t.allFields, field) + if skip { + t.skippedFields = append(t.skippedFields, field) + t.FieldsMap[field.SQLName] = field + return nil + } + + if _, ok := pgTag.Options["soft_delete"]; ok { + t.SetSoftDeleteField = setSoftDeleteFieldFunc(f.Type) + if t.SetSoftDeleteField == nil { + err := fmt.Errorf( + "pg: soft_delete is only supported for time.Time, pg.NullTime, sql.NullInt64, and int64 (or implement ValueScanner that scans time)") + panic(err) + } + t.SoftDeleteField = field + } + + return field +} + +func (t *Table) initMethods() { + t.Methods = make(map[string]*Method) + typ := reflect.PtrTo(t.Type) + for i := 0; i < typ.NumMethod(); i++ { + m := typ.Method(i) + if m.PkgPath != "" { + continue + } + if m.Type.NumIn() > 1 { + continue + } + if m.Type.NumOut() != 1 { + continue + } + + retType := m.Type.Out(0) + t.Methods[m.Name] = &Method{ + Index: m.Index, + + appender: types.Appender(retType), + } + } +} + +func (t *Table) initInlines() { + for _, f := range t.skippedFields { + if f.Type.Kind() == reflect.Struct { + t.inlineFields(f, nil) + } + } +} + +func (t *Table) initRelations() { + for i := 0; i < len(t.Fields); { + f := t.Fields[i] + if t.tryRelation(f) { + t.Fields = removeField(t.Fields, f) + t.DataFields = removeField(t.DataFields, f) + } else { + i++ + } + + if f.Type.Kind() == reflect.Struct { + t.inlineFields(f, nil) + } + } +} + +func (t *Table) tryRelation(field *Field) bool { + pgTag := tagparser.Parse(field.Field.Tag.Get("pg")) + + if rel, ok := pgTag.Options["rel"]; ok { + return t.tryRelationType(field, rel, pgTag) + } + if _, ok := pgTag.Options["many2many"]; ok { + return t.tryRelationType(field, "many2many", pgTag) + } + + if field.UserSQLType != "" || isScanner(field.Type) { + return false + } + + switch field.Type.Kind() { + case reflect.Slice: + return t.tryRelationSlice(field, pgTag) + case reflect.Struct: + return t.tryRelationStruct(field, pgTag) + } + return false +} + +func (t *Table) tryRelationType(field *Field, rel string, pgTag *tagparser.Tag) bool { + switch rel { + case "has-one": + return t.mustHasOneRelation(field, pgTag) + case "belongs-to": + return t.mustBelongsToRelation(field, pgTag) + case "has-many": + return t.mustHasManyRelation(field, pgTag) + case "many2many": + return t.mustM2MRelation(field, pgTag) + default: + panic(fmt.Errorf("pg: unknown relation=%s on field=%s", rel, field.GoName)) + } +} + +func (t *Table) mustHasOneRelation(field *Field, pgTag *tagparser.Tag) bool { + joinTable := _tables.get(field.Type, true) + if err := joinTable.checkPKs(); err != nil { + panic(err) + } + fkPrefix, fkOK := pgTag.Options["fk"] + + if fkOK && len(joinTable.PKs) == 1 { + fk := t.getField(fkPrefix) + if fk == nil { + panic(fmt.Errorf( + "pg: %s has-one %s: %s must have column %s "+ + "(use fk:custom_column tag on %s field to specify custom column)", + t.TypeName, field.GoName, t.TypeName, fkPrefix, field.GoName, + )) + } + + t.addRelation(&Relation{ + Type: HasOneRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: []*Field{fk}, + JoinFKs: joinTable.PKs, + }) + return true + } + + if !fkOK { + fkPrefix = internal.Underscore(field.GoName) + "_" + } + fks := make([]*Field, 0, len(joinTable.PKs)) + + for _, joinPK := range joinTable.PKs { + fkName := fkPrefix + joinPK.SQLName + if fk := t.getField(fkName); fk != nil { + fks = append(fks, fk) + continue + } + + if fk := t.getField(joinPK.SQLName); fk != nil { + fks = append(fks, fk) + continue + } + + panic(fmt.Errorf( + "pg: %s has-one %s: %s must have column %s "+ + "(use fk:custom_column tag on %s field to specify custom column)", + t.TypeName, field.GoName, t.TypeName, fkName, field.GoName, + )) + } + + t.addRelation(&Relation{ + Type: HasOneRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: fks, + JoinFKs: joinTable.PKs, + }) + return true +} + +func (t *Table) mustBelongsToRelation(field *Field, pgTag *tagparser.Tag) bool { + if err := t.checkPKs(); err != nil { + panic(err) + } + joinTable := _tables.get(field.Type, true) + fkPrefix, fkOK := pgTag.Options["join_fk"] + + if fkOK && len(t.PKs) == 1 { + fk := joinTable.getField(fkPrefix) + if fk == nil { + panic(fmt.Errorf( + "pg: %s belongs-to %s: %s must have column %s "+ + "(use join_fk:custom_column tag on %s field to specify custom column)", + field.GoName, t.TypeName, joinTable.TypeName, fkPrefix, field.GoName, + )) + } + + t.addRelation(&Relation{ + Type: BelongsToRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: t.PKs, + JoinFKs: []*Field{fk}, + }) + return true + } + + if !fkOK { + fkPrefix = internal.Underscore(t.ModelName) + "_" + } + fks := make([]*Field, 0, len(t.PKs)) + + for _, pk := range t.PKs { + fkName := fkPrefix + pk.SQLName + if fk := joinTable.getField(fkName); fk != nil { + fks = append(fks, fk) + continue + } + + if fk := joinTable.getField(pk.SQLName); fk != nil { + fks = append(fks, fk) + continue + } + + panic(fmt.Errorf( + "pg: %s belongs-to %s: %s must have column %s "+ + "(use join_fk:custom_column tag on %s field to specify custom column)", + field.GoName, t.TypeName, joinTable.TypeName, fkName, field.GoName, + )) + } + + t.addRelation(&Relation{ + Type: BelongsToRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: t.PKs, + JoinFKs: fks, + }) + return true +} + +func (t *Table) mustHasManyRelation(field *Field, pgTag *tagparser.Tag) bool { + if err := t.checkPKs(); err != nil { + panic(err) + } + if field.Type.Kind() != reflect.Slice { + panic(fmt.Errorf( + "pg: %s.%s has-many relation requires slice, got %q", + t.TypeName, field.GoName, field.Type.Kind(), + )) + } + + joinTable := _tables.get(indirectType(field.Type.Elem()), true) + fkPrefix, fkOK := pgTag.Options["join_fk"] + _, polymorphic := pgTag.Options["polymorphic"] + + if fkOK && !polymorphic && len(t.PKs) == 1 { + fk := joinTable.getField(fkPrefix) + if fk == nil { + panic(fmt.Errorf( + "pg: %s has-many %s: %s must have column %s "+ + "(use join_fk:custom_column tag on %s field to specify custom column)", + t.TypeName, field.GoName, joinTable.TypeName, fkPrefix, field.GoName, + )) + } + + t.addRelation(&Relation{ + Type: HasManyRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: t.PKs, + JoinFKs: []*Field{fk}, + }) + return true + } + + if !fkOK { + fkPrefix = internal.Underscore(t.ModelName) + "_" + } + fks := make([]*Field, 0, len(t.PKs)) + + for _, pk := range t.PKs { + fkName := fkPrefix + pk.SQLName + if fk := joinTable.getField(fkName); fk != nil { + fks = append(fks, fk) + continue + } + + if fk := joinTable.getField(pk.SQLName); fk != nil { + fks = append(fks, fk) + continue + } + + panic(fmt.Errorf( + "pg: %s has-many %s: %s must have column %s "+ + "(use join_fk:custom_column tag on %s field to specify custom column)", + t.TypeName, field.GoName, joinTable.TypeName, fkName, field.GoName, + )) + } + + var typeField *Field + + if polymorphic { + typeFieldName := fkPrefix + "type" + typeField = joinTable.getField(typeFieldName) + if typeField == nil { + panic(fmt.Errorf( + "pg: %s has-many %s: %s must have polymorphic column %s", + t.TypeName, field.GoName, joinTable.TypeName, typeFieldName, + )) + } + } + + t.addRelation(&Relation{ + Type: HasManyRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: t.PKs, + JoinFKs: fks, + Polymorphic: typeField, + }) + return true +} + +func (t *Table) mustM2MRelation(field *Field, pgTag *tagparser.Tag) bool { + if field.Type.Kind() != reflect.Slice { + panic(fmt.Errorf( + "pg: %s.%s many2many relation requires slice, got %q", + t.TypeName, field.GoName, field.Type.Kind(), + )) + } + joinTable := _tables.get(indirectType(field.Type.Elem()), true) + + if err := t.checkPKs(); err != nil { + panic(err) + } + if err := joinTable.checkPKs(); err != nil { + panic(err) + } + + m2mTableNameString, ok := pgTag.Options["many2many"] + if !ok { + panic(fmt.Errorf("pg: %s must have many2many tag option", field.GoName)) + } + m2mTableName := quoteTableName(m2mTableNameString) + + m2mTable := _tables.getByName(m2mTableName) + if m2mTable == nil { + panic(fmt.Errorf( + "pg: can't find %s table (use orm.RegisterTable to register the model)", + m2mTableName, + )) + } + + var baseFKs []string + var joinFKs []string + + { + fkPrefix, ok := pgTag.Options["fk"] + if !ok { + fkPrefix = internal.Underscore(t.ModelName) + "_" + } + + if ok && len(t.PKs) == 1 { + if m2mTable.getField(fkPrefix) == nil { + panic(fmt.Errorf( + "pg: %s many2many %s: %s must have column %s "+ + "(use fk:custom_column tag on %s field to specify custom column)", + t.TypeName, field.GoName, m2mTable.TypeName, fkPrefix, field.GoName, + )) + } + baseFKs = []string{fkPrefix} + } else { + for _, pk := range t.PKs { + fkName := fkPrefix + pk.SQLName + if m2mTable.getField(fkName) != nil { + baseFKs = append(baseFKs, fkName) + continue + } + + if m2mTable.getField(pk.SQLName) != nil { + baseFKs = append(baseFKs, pk.SQLName) + continue + } + + panic(fmt.Errorf( + "pg: %s many2many %s: %s must have column %s "+ + "(use fk:custom_column tag on %s field to specify custom column)", + t.TypeName, field.GoName, m2mTable.TypeName, fkName, field.GoName, + )) + } + } + } + + { + joinFKPrefix, ok := pgTag.Options["join_fk"] + if !ok { + joinFKPrefix = internal.Underscore(joinTable.ModelName) + "_" + } + + if ok && len(joinTable.PKs) == 1 { + if m2mTable.getField(joinFKPrefix) == nil { + panic(fmt.Errorf( + "pg: %s many2many %s: %s must have column %s "+ + "(use join_fk:custom_column tag on %s field to specify custom column)", + joinTable.TypeName, field.GoName, m2mTable.TypeName, joinFKPrefix, field.GoName, + )) + } + joinFKs = []string{joinFKPrefix} + } else { + for _, joinPK := range joinTable.PKs { + fkName := joinFKPrefix + joinPK.SQLName + if m2mTable.getField(fkName) != nil { + joinFKs = append(joinFKs, fkName) + continue + } + + if m2mTable.getField(joinPK.SQLName) != nil { + joinFKs = append(joinFKs, joinPK.SQLName) + continue + } + + panic(fmt.Errorf( + "pg: %s many2many %s: %s must have column %s "+ + "(use join_fk:custom_column tag on %s field to specify custom column)", + t.TypeName, field.GoName, m2mTable.TypeName, fkName, field.GoName, + )) + } + } + } + + t.addRelation(&Relation{ + Type: Many2ManyRelation, + Field: field, + JoinTable: joinTable, + M2MTableName: m2mTableName, + M2MTableAlias: m2mTable.Alias, + M2MBaseFKs: baseFKs, + M2MJoinFKs: joinFKs, + }) + return true +} + +//nolint +func (t *Table) tryRelationSlice(field *Field, pgTag *tagparser.Tag) bool { + if t.tryM2MRelation(field, pgTag) { + internal.Deprecated.Printf( + `add pg:"rel:many2many" to %s.%s field tag`, t.TypeName, field.GoName) + return true + } + if t.tryHasManyRelation(field, pgTag) { + internal.Deprecated.Printf( + `add pg:"rel:has-many" to %s.%s field tag`, t.TypeName, field.GoName) + return true + } + return false +} + +func (t *Table) tryM2MRelation(field *Field, pgTag *tagparser.Tag) bool { + elemType := indirectType(field.Type.Elem()) + if elemType.Kind() != reflect.Struct { + return false + } + + joinTable := _tables.get(elemType, true) + + fk, fkOK := pgTag.Options["fk"] + if fkOK { + if fk == "-" { + return false + } + fk = tryUnderscorePrefix(fk) + } + + m2mTableName := pgTag.Options["many2many"] + if m2mTableName == "" { + return false + } + + m2mTable := _tables.getByName(quoteIdent(m2mTableName)) + + var m2mTableAlias types.Safe + if m2mTable != nil { + m2mTableAlias = m2mTable.Alias + } else if ind := strings.IndexByte(m2mTableName, '.'); ind >= 0 { + m2mTableAlias = quoteIdent(m2mTableName[ind+1:]) + } else { + m2mTableAlias = quoteIdent(m2mTableName) + } + + var fks []string + if !fkOK { + fk = t.ModelName + "_" + } + if m2mTable != nil { + keys := foreignKeys(t, m2mTable, fk, fkOK) + if len(keys) == 0 { + return false + } + for _, fk := range keys { + fks = append(fks, fk.SQLName) + } + } else { + if fkOK && len(t.PKs) == 1 { + fks = append(fks, fk) + } else { + for _, pk := range t.PKs { + fks = append(fks, fk+pk.SQLName) + } + } + } + + joinFK, joinFKOk := pgTag.Options["join_fk"] + if !joinFKOk { + joinFK, joinFKOk = pgTag.Options["joinFK"] + if joinFKOk { + internal.Deprecated.Printf("joinFK is renamed to join_fk") + } + } + if joinFKOk { + joinFK = tryUnderscorePrefix(joinFK) + } else { + joinFK = joinTable.ModelName + "_" + } + + var joinFKs []string + if m2mTable != nil { + keys := foreignKeys(joinTable, m2mTable, joinFK, joinFKOk) + if len(keys) == 0 { + return false + } + for _, fk := range keys { + joinFKs = append(joinFKs, fk.SQLName) + } + } else { + if joinFKOk && len(joinTable.PKs) == 1 { + joinFKs = append(joinFKs, joinFK) + } else { + for _, pk := range joinTable.PKs { + joinFKs = append(joinFKs, joinFK+pk.SQLName) + } + } + } + + t.addRelation(&Relation{ + Type: Many2ManyRelation, + Field: field, + JoinTable: joinTable, + M2MTableName: quoteIdent(m2mTableName), + M2MTableAlias: m2mTableAlias, + M2MBaseFKs: fks, + M2MJoinFKs: joinFKs, + }) + return true +} + +func (t *Table) tryHasManyRelation(field *Field, pgTag *tagparser.Tag) bool { + elemType := indirectType(field.Type.Elem()) + if elemType.Kind() != reflect.Struct { + return false + } + + joinTable := _tables.get(elemType, true) + + fk, fkOK := pgTag.Options["fk"] + if fkOK { + if fk == "-" { + return false + } + fk = tryUnderscorePrefix(fk) + } + + s, polymorphic := pgTag.Options["polymorphic"] + var typeField *Field + if polymorphic { + fk = tryUnderscorePrefix(s) + + typeField = joinTable.getField(fk + "type") + if typeField == nil { + return false + } + } else if !fkOK { + fk = t.ModelName + "_" + } + + fks := foreignKeys(t, joinTable, fk, fkOK || polymorphic) + if len(fks) == 0 { + return false + } + + var fkValues []*Field + fkValue, ok := pgTag.Options["fk_value"] + if ok { + if len(fks) > 1 { + panic(fmt.Errorf("got fk_value, but there are %d fks", len(fks))) + } + + f := t.getField(fkValue) + if f == nil { + panic(fmt.Errorf("fk_value=%q not found in %s", fkValue, t)) + } + fkValues = append(fkValues, f) + } else { + fkValues = t.PKs + } + + if len(fks) != len(fkValues) { + panic("len(fks) != len(fkValues)") + } + + if len(fks) > 0 { + t.addRelation(&Relation{ + Type: HasManyRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: fkValues, + JoinFKs: fks, + Polymorphic: typeField, + }) + return true + } + + return false +} + +func (t *Table) tryRelationStruct(field *Field, pgTag *tagparser.Tag) bool { + joinTable := _tables.get(field.Type, true) + + if len(joinTable.allFields) == 0 { + return false + } + + if t.tryHasOne(joinTable, field, pgTag) { + internal.Deprecated.Printf( + `add pg:"rel:has-one" to %s.%s field tag`, t.TypeName, field.GoName) + t.inlineFields(field, nil) + return true + } + + if t.tryBelongsToOne(joinTable, field, pgTag) { + internal.Deprecated.Printf( + `add pg:"rel:belongs-to" to %s.%s field tag`, t.TypeName, field.GoName) + t.inlineFields(field, nil) + return true + } + + t.inlineFields(field, nil) + return false +} + +func (t *Table) inlineFields(strct *Field, path map[reflect.Type]struct{}) { + if path == nil { + path = map[reflect.Type]struct{}{ + t.Type: {}, + } + } + + if _, ok := path[strct.Type]; ok { + return + } + path[strct.Type] = struct{}{} + + joinTable := _tables.get(strct.Type, true) + for _, f := range joinTable.allFields { + f = f.Clone() + f.GoName = strct.GoName + "_" + f.GoName + f.SQLName = strct.SQLName + "__" + f.SQLName + f.Column = quoteIdent(f.SQLName) + f.Index = appendNew(strct.Index, f.Index...) + + t.fieldsMapMu.Lock() + if _, ok := t.FieldsMap[f.SQLName]; !ok { + t.FieldsMap[f.SQLName] = f + } + t.fieldsMapMu.Unlock() + + if f.Type.Kind() != reflect.Struct { + continue + } + + if _, ok := path[f.Type]; !ok { + t.inlineFields(f, path) + } + } +} + +func appendNew(dst []int, src ...int) []int { + cp := make([]int, len(dst)+len(src)) + copy(cp, dst) + copy(cp[len(dst):], src) + return cp +} + +func isScanner(typ reflect.Type) bool { + return typ.Implements(scannerType) || reflect.PtrTo(typ).Implements(scannerType) +} + +func fieldSQLType(field *Field, pgTag *tagparser.Tag) string { + if typ, ok := pgTag.Options["type"]; ok { + typ, _ = tagparser.Unquote(typ) + field.UserSQLType = typ + typ = normalizeSQLType(typ) + return typ + } + + if typ, ok := pgTag.Options["composite"]; ok { + typ, _ = tagparser.Unquote(typ) + return typ + } + + if _, ok := pgTag.Options["hstore"]; ok { + return "hstore" + } else if _, ok := pgTag.Options["hstore"]; ok { + return "hstore" + } + + if field.hasFlag(ArrayFlag) { + switch field.Type.Kind() { + case reflect.Slice, reflect.Array: + sqlType := sqlType(field.Type.Elem()) + return sqlType + "[]" + } + } + + sqlType := sqlType(field.Type) + return sqlType +} + +func sqlType(typ reflect.Type) string { + switch typ { + case timeType, nullTimeType, sqlNullTimeType: + return pgTypeTimestampTz + case ipType: + return pgTypeInet + case ipNetType: + return pgTypeCidr + case nullBoolType: + return pgTypeBoolean + case nullFloatType: + return pgTypeDoublePrecision + case nullIntType: + return pgTypeBigint + case nullStringType: + return pgTypeText + case jsonRawMessageType: + return pgTypeJSONB + } + + switch typ.Kind() { + case reflect.Int8, reflect.Uint8, reflect.Int16: + return pgTypeSmallint + case reflect.Uint16, reflect.Int32: + return pgTypeInteger + case reflect.Uint32, reflect.Int64, reflect.Int: + return pgTypeBigint + case reflect.Uint, reflect.Uint64: + // Unsigned bigint is not supported - use bigint. + return pgTypeBigint + case reflect.Float32: + return pgTypeReal + case reflect.Float64: + return pgTypeDoublePrecision + case reflect.Bool: + return pgTypeBoolean + case reflect.String: + return pgTypeText + case reflect.Map, reflect.Struct: + return pgTypeJSONB + case reflect.Array, reflect.Slice: + if typ.Elem().Kind() == reflect.Uint8 { + return pgTypeBytea + } + return pgTypeJSONB + default: + return typ.Kind().String() + } +} + +func normalizeSQLType(s string) string { + switch s { + case "int2": + return pgTypeSmallint + case "int4", "int", "serial": + return pgTypeInteger + case "int8", pgTypeBigserial: + return pgTypeBigint + case "float4": + return pgTypeReal + case "float8": + return pgTypeDoublePrecision + } + return s +} + +func sqlTypeEqual(a, b string) bool { + return a == b +} + +func (t *Table) tryHasOne(joinTable *Table, field *Field, pgTag *tagparser.Tag) bool { + fk, fkOK := pgTag.Options["fk"] + if fkOK { + if fk == "-" { + return false + } + fk = tryUnderscorePrefix(fk) + } else { + fk = internal.Underscore(field.GoName) + "_" + } + + fks := foreignKeys(joinTable, t, fk, fkOK) + if len(fks) > 0 { + t.addRelation(&Relation{ + Type: HasOneRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: fks, + JoinFKs: joinTable.PKs, + }) + return true + } + return false +} + +func (t *Table) tryBelongsToOne(joinTable *Table, field *Field, pgTag *tagparser.Tag) bool { + fk, fkOK := pgTag.Options["fk"] + if fkOK { + if fk == "-" { + return false + } + fk = tryUnderscorePrefix(fk) + } else { + fk = internal.Underscore(t.TypeName) + "_" + } + + fks := foreignKeys(t, joinTable, fk, fkOK) + if len(fks) > 0 { + t.addRelation(&Relation{ + Type: BelongsToRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: t.PKs, + JoinFKs: fks, + }) + return true + } + return false +} + +func (t *Table) addRelation(rel *Relation) { + if t.Relations == nil { + t.Relations = make(map[string]*Relation) + } + _, ok := t.Relations[rel.Field.GoName] + if ok { + panic(fmt.Errorf("%s already has %s", t, rel)) + } + t.Relations[rel.Field.GoName] = rel +} + +func foreignKeys(base, join *Table, fk string, tryFK bool) []*Field { + var fks []*Field + + for _, pk := range base.PKs { + fkName := fk + pk.SQLName + f := join.getField(fkName) + if f != nil && sqlTypeEqual(pk.SQLType, f.SQLType) { + fks = append(fks, f) + continue + } + + if strings.IndexByte(pk.SQLName, '_') == -1 { + continue + } + + f = join.getField(pk.SQLName) + if f != nil && sqlTypeEqual(pk.SQLType, f.SQLType) { + fks = append(fks, f) + continue + } + } + if len(fks) > 0 && len(fks) == len(base.PKs) { + return fks + } + + fks = nil + for _, pk := range base.PKs { + if !strings.HasPrefix(pk.SQLName, "pk_") { + continue + } + fkName := "fk_" + pk.SQLName[3:] + f := join.getField(fkName) + if f != nil && sqlTypeEqual(pk.SQLType, f.SQLType) { + fks = append(fks, f) + } + } + if len(fks) > 0 && len(fks) == len(base.PKs) { + return fks + } + + if fk == "" || len(base.PKs) != 1 { + return nil + } + + if tryFK { + f := join.getField(fk) + if f != nil && sqlTypeEqual(base.PKs[0].SQLType, f.SQLType) { + return []*Field{f} + } + } + + for _, suffix := range []string{"id", "uuid"} { + f := join.getField(fk + suffix) + if f != nil && sqlTypeEqual(base.PKs[0].SQLType, f.SQLType) { + return []*Field{f} + } + } + + return nil +} + +func scanJSONValue(v reflect.Value, rd types.Reader, n int) error { + // Zero value so it works with SelectOrInsert. + // TODO: better handle slices + v.Set(reflect.New(v.Type()).Elem()) + + if n == -1 { + return nil + } + + dec := pgjson.NewDecoder(rd) + dec.UseNumber() + return dec.Decode(v.Addr().Interface()) +} + +func appendUintAsInt(b []byte, v reflect.Value, _ int) []byte { + return strconv.AppendInt(b, int64(v.Uint()), 10) +} + +func appendUintPtrAsInt(b []byte, v reflect.Value, _ int) []byte { + return strconv.AppendInt(b, int64(v.Elem().Uint()), 10) +} + +func tryUnderscorePrefix(s string) string { + if s == "" { + return s + } + if c := s[0]; internal.IsUpper(c) { + return internal.Underscore(s) + "_" + } + return s +} + +func quoteTableName(s string) types.Safe { + // Don't quote if table name contains placeholder (?) or parentheses. + if strings.IndexByte(s, '?') >= 0 || + strings.IndexByte(s, '(') >= 0 && strings.IndexByte(s, ')') >= 0 { + return types.Safe(s) + } + return quoteIdent(s) +} + +func quoteIdent(s string) types.Safe { + return types.Safe(types.AppendIdent(nil, s, 1)) +} + +func setSoftDeleteFieldFunc(typ reflect.Type) func(fv reflect.Value) error { + switch typ { + case timeType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*time.Time) + *ptr = time.Now() + return nil + } + case nullTimeType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*types.NullTime) + *ptr = types.NullTime{Time: time.Now()} + return nil + } + case nullIntType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*sql.NullInt64) + *ptr = sql.NullInt64{Int64: time.Now().UnixNano()} + return nil + } + } + + switch typ.Kind() { + case reflect.Int64: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*int64) + *ptr = time.Now().UnixNano() + return nil + } + case reflect.Ptr: + break + default: + return setSoftDeleteFallbackFunc(typ) + } + + originalType := typ + typ = typ.Elem() + + switch typ { //nolint:gocritic + case timeType: + return func(fv reflect.Value) error { + now := time.Now() + fv.Set(reflect.ValueOf(&now)) + return nil + } + } + + switch typ.Kind() { //nolint:gocritic + case reflect.Int64: + return func(fv reflect.Value) error { + utime := time.Now().UnixNano() + fv.Set(reflect.ValueOf(&utime)) + return nil + } + } + + return setSoftDeleteFallbackFunc(originalType) +} + +func setSoftDeleteFallbackFunc(typ reflect.Type) func(fv reflect.Value) error { + scanner := types.Scanner(typ) + if scanner == nil { + return nil + } + + return func(fv reflect.Value) error { + var flags int + b := types.AppendTime(nil, time.Now(), flags) + return scanner(fv, pool.NewBytesReader(b), len(b)) + } +} + +func isKnownTableOption(name string) bool { + switch name { + case "alias", + "select", + "tablespace", + "partition_by", + "discard_unknown_columns": + return true + } + return false +} + +func isKnownFieldOption(name string) bool { + switch name { + case "alias", + "type", + "array", + "hstore", + "composite", + "json_use_number", + "msgpack", + "notnull", + "use_zero", + "default", + "unique", + "soft_delete", + "on_delete", + "on_update", + + "pk", + "nopk", + "rel", + "fk", + "join_fk", + "many2many", + "polymorphic": + return true + } + return false +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/table_create.go b/vendor/github.com/go-pg/pg/v10/orm/table_create.go new file mode 100644 index 000000000..384c729de --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/table_create.go @@ -0,0 +1,248 @@ +package orm + +import ( + "sort" + "strconv" + + "github.com/go-pg/pg/v10/types" +) + +type CreateTableOptions struct { + Varchar int // replaces PostgreSQL data type `text` with `varchar(n)` + Temp bool + IfNotExists bool + + // FKConstraints causes CreateTable to create foreign key constraints + // for has one relations. ON DELETE hook can be added using tag + // `pg:"on_delete:RESTRICT"` on foreign key field. ON UPDATE hook can be added using tag + // `pg:"on_update:CASCADE"` + FKConstraints bool +} + +type CreateTableQuery struct { + q *Query + opt *CreateTableOptions +} + +var ( + _ QueryAppender = (*CreateTableQuery)(nil) + _ QueryCommand = (*CreateTableQuery)(nil) +) + +func NewCreateTableQuery(q *Query, opt *CreateTableOptions) *CreateTableQuery { + return &CreateTableQuery{ + q: q, + opt: opt, + } +} + +func (q *CreateTableQuery) String() string { + b, err := q.AppendQuery(defaultFmter, nil) + if err != nil { + panic(err) + } + return string(b) +} + +func (q *CreateTableQuery) Operation() QueryOp { + return CreateTableOp +} + +func (q *CreateTableQuery) Clone() QueryCommand { + return &CreateTableQuery{ + q: q.q.Clone(), + opt: q.opt, + } +} + +func (q *CreateTableQuery) Query() *Query { + return q.q +} + +func (q *CreateTableQuery) AppendTemplate(b []byte) ([]byte, error) { + return q.AppendQuery(dummyFormatter{}, b) +} + +func (q *CreateTableQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if q.q.stickyErr != nil { + return nil, q.q.stickyErr + } + if q.q.tableModel == nil { + return nil, errModelNil + } + + table := q.q.tableModel.Table() + + b = append(b, "CREATE "...) + if q.opt != nil && q.opt.Temp { + b = append(b, "TEMP "...) + } + b = append(b, "TABLE "...) + if q.opt != nil && q.opt.IfNotExists { + b = append(b, "IF NOT EXISTS "...) + } + b, err = q.q.appendFirstTable(fmter, b) + if err != nil { + return nil, err + } + b = append(b, " ("...) + + for i, field := range table.Fields { + if i > 0 { + b = append(b, ", "...) + } + + b = append(b, field.Column...) + b = append(b, " "...) + b = q.appendSQLType(b, field) + if field.hasFlag(NotNullFlag) { + b = append(b, " NOT NULL"...) + } + if field.hasFlag(UniqueFlag) { + b = append(b, " UNIQUE"...) + } + if field.Default != "" { + b = append(b, " DEFAULT "...) + b = append(b, field.Default...) + } + } + + b = appendPKConstraint(b, table.PKs) + b = appendUniqueConstraints(b, table) + + if q.opt != nil && q.opt.FKConstraints { + for _, rel := range table.Relations { + b = q.appendFKConstraint(fmter, b, rel) + } + } + + b = append(b, ")"...) + + if table.PartitionBy != "" { + b = append(b, " PARTITION BY "...) + b = append(b, table.PartitionBy...) + } + + if table.Tablespace != "" { + b = q.appendTablespace(b, table.Tablespace) + } + + return b, q.q.stickyErr +} + +func (q *CreateTableQuery) appendSQLType(b []byte, field *Field) []byte { + if field.UserSQLType != "" { + return append(b, field.UserSQLType...) + } + if q.opt != nil && q.opt.Varchar > 0 && + field.SQLType == "text" { + b = append(b, "varchar("...) + b = strconv.AppendInt(b, int64(q.opt.Varchar), 10) + b = append(b, ")"...) + return b + } + if field.hasFlag(PrimaryKeyFlag) { + return append(b, pkSQLType(field.SQLType)...) + } + return append(b, field.SQLType...) +} + +func pkSQLType(s string) string { + switch s { + case pgTypeSmallint: + return pgTypeSmallserial + case pgTypeInteger: + return pgTypeSerial + case pgTypeBigint: + return pgTypeBigserial + } + return s +} + +func appendPKConstraint(b []byte, pks []*Field) []byte { + if len(pks) == 0 { + return b + } + + b = append(b, ", PRIMARY KEY ("...) + b = appendColumns(b, "", pks) + b = append(b, ")"...) + return b +} + +func appendUniqueConstraints(b []byte, table *Table) []byte { + keys := make([]string, 0, len(table.Unique)) + for key := range table.Unique { + keys = append(keys, key) + } + sort.Strings(keys) + + for _, key := range keys { + b = appendUnique(b, table.Unique[key]) + } + + return b +} + +func appendUnique(b []byte, fields []*Field) []byte { + b = append(b, ", UNIQUE ("...) + b = appendColumns(b, "", fields) + b = append(b, ")"...) + return b +} + +func (q *CreateTableQuery) appendFKConstraint(fmter QueryFormatter, b []byte, rel *Relation) []byte { + if rel.Type != HasOneRelation { + return b + } + + b = append(b, ", FOREIGN KEY ("...) + b = appendColumns(b, "", rel.BaseFKs) + b = append(b, ")"...) + + b = append(b, " REFERENCES "...) + b = fmter.FormatQuery(b, string(rel.JoinTable.SQLName)) + b = append(b, " ("...) + b = appendColumns(b, "", rel.JoinFKs) + b = append(b, ")"...) + + if s := onDelete(rel.BaseFKs); s != "" { + b = append(b, " ON DELETE "...) + b = append(b, s...) + } + + if s := onUpdate(rel.BaseFKs); s != "" { + b = append(b, " ON UPDATE "...) + b = append(b, s...) + } + + return b +} + +func (q *CreateTableQuery) appendTablespace(b []byte, tableSpace types.Safe) []byte { + b = append(b, " TABLESPACE "...) + b = append(b, tableSpace...) + return b +} + +func onDelete(fks []*Field) string { + var onDelete string + for _, f := range fks { + if f.OnDelete != "" { + onDelete = f.OnDelete + break + } + } + return onDelete +} + +func onUpdate(fks []*Field) string { + var onUpdate string + for _, f := range fks { + if f.OnUpdate != "" { + onUpdate = f.OnUpdate + break + } + } + return onUpdate +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/table_drop.go b/vendor/github.com/go-pg/pg/v10/orm/table_drop.go new file mode 100644 index 000000000..599ac3952 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/table_drop.go @@ -0,0 +1,73 @@ +package orm + +type DropTableOptions struct { + IfExists bool + Cascade bool +} + +type DropTableQuery struct { + q *Query + opt *DropTableOptions +} + +var ( + _ QueryAppender = (*DropTableQuery)(nil) + _ QueryCommand = (*DropTableQuery)(nil) +) + +func NewDropTableQuery(q *Query, opt *DropTableOptions) *DropTableQuery { + return &DropTableQuery{ + q: q, + opt: opt, + } +} + +func (q *DropTableQuery) String() string { + b, err := q.AppendQuery(defaultFmter, nil) + if err != nil { + panic(err) + } + return string(b) +} + +func (q *DropTableQuery) Operation() QueryOp { + return DropTableOp +} + +func (q *DropTableQuery) Clone() QueryCommand { + return &DropTableQuery{ + q: q.q.Clone(), + opt: q.opt, + } +} + +func (q *DropTableQuery) Query() *Query { + return q.q +} + +func (q *DropTableQuery) AppendTemplate(b []byte) ([]byte, error) { + return q.AppendQuery(dummyFormatter{}, b) +} + +func (q *DropTableQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if q.q.stickyErr != nil { + return nil, q.q.stickyErr + } + if q.q.tableModel == nil { + return nil, errModelNil + } + + b = append(b, "DROP TABLE "...) + if q.opt != nil && q.opt.IfExists { + b = append(b, "IF EXISTS "...) + } + b, err = q.q.appendFirstTable(fmter, b) + if err != nil { + return nil, err + } + if q.opt != nil && q.opt.Cascade { + b = append(b, " CASCADE"...) + } + + return b, q.q.stickyErr +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/table_params.go b/vendor/github.com/go-pg/pg/v10/orm/table_params.go new file mode 100644 index 000000000..46d8e064a --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/table_params.go @@ -0,0 +1,29 @@ +package orm + +import "reflect" + +type tableParams struct { + table *Table + strct reflect.Value +} + +func newTableParams(strct interface{}) (*tableParams, bool) { + v := reflect.ValueOf(strct) + if !v.IsValid() { + return nil, false + } + + v = reflect.Indirect(v) + if v.Kind() != reflect.Struct { + return nil, false + } + + return &tableParams{ + table: GetTable(v.Type()), + strct: v, + }, true +} + +func (m *tableParams) AppendParam(fmter QueryFormatter, b []byte, name string) ([]byte, bool) { + return m.table.AppendParam(b, m.strct, name) +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/tables.go b/vendor/github.com/go-pg/pg/v10/orm/tables.go new file mode 100644 index 000000000..fa937a54e --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/tables.go @@ -0,0 +1,136 @@ +package orm + +import ( + "fmt" + "reflect" + "sync" + + "github.com/go-pg/pg/v10/types" +) + +var _tables = newTables() + +type tableInProgress struct { + table *Table + + init1Once sync.Once + init2Once sync.Once +} + +func newTableInProgress(table *Table) *tableInProgress { + return &tableInProgress{ + table: table, + } +} + +func (inp *tableInProgress) init1() bool { + var inited bool + inp.init1Once.Do(func() { + inp.table.init1() + inited = true + }) + return inited +} + +func (inp *tableInProgress) init2() bool { + var inited bool + inp.init2Once.Do(func() { + inp.table.init2() + inited = true + }) + return inited +} + +// GetTable returns a Table for a struct type. +func GetTable(typ reflect.Type) *Table { + return _tables.Get(typ) +} + +// RegisterTable registers a struct as SQL table. +// It is usually used to register intermediate table +// in many to many relationship. +func RegisterTable(strct interface{}) { + _tables.Register(strct) +} + +type tables struct { + tables sync.Map + + mu sync.RWMutex + inProgress map[reflect.Type]*tableInProgress +} + +func newTables() *tables { + return &tables{ + inProgress: make(map[reflect.Type]*tableInProgress), + } +} + +func (t *tables) Register(strct interface{}) { + typ := reflect.TypeOf(strct) + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + _ = t.Get(typ) +} + +func (t *tables) get(typ reflect.Type, allowInProgress bool) *Table { + if typ.Kind() != reflect.Struct { + panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct)) + } + + if v, ok := t.tables.Load(typ); ok { + return v.(*Table) + } + + t.mu.Lock() + + if v, ok := t.tables.Load(typ); ok { + t.mu.Unlock() + return v.(*Table) + } + + var table *Table + + inProgress := t.inProgress[typ] + if inProgress == nil { + table = newTable(typ) + inProgress = newTableInProgress(table) + t.inProgress[typ] = inProgress + } else { + table = inProgress.table + } + + t.mu.Unlock() + + inProgress.init1() + if allowInProgress { + return table + } + + if inProgress.init2() { + t.mu.Lock() + delete(t.inProgress, typ) + t.tables.Store(typ, table) + t.mu.Unlock() + } + + return table +} + +func (t *tables) Get(typ reflect.Type) *Table { + return t.get(typ, false) +} + +func (t *tables) getByName(name types.Safe) *Table { + var found *Table + t.tables.Range(func(key, value interface{}) bool { + t := value.(*Table) + if t.SQLName == name { + found = t + return false + } + return true + }) + return found +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/types.go b/vendor/github.com/go-pg/pg/v10/orm/types.go new file mode 100644 index 000000000..c8e9ec375 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/types.go @@ -0,0 +1,48 @@ +package orm + +//nolint +const ( + // Date / Time + pgTypeTimestamp = "timestamp" // Timestamp without a time zone + pgTypeTimestampTz = "timestamptz" // Timestamp with a time zone + pgTypeDate = "date" // Date + pgTypeTime = "time" // Time without a time zone + pgTypeTimeTz = "time with time zone" // Time with a time zone + pgTypeInterval = "interval" // Time Interval + + // Network Addresses + pgTypeInet = "inet" // IPv4 or IPv6 hosts and networks + pgTypeCidr = "cidr" // IPv4 or IPv6 networks + pgTypeMacaddr = "macaddr" // MAC addresses + + // Boolean + pgTypeBoolean = "boolean" + + // Numeric Types + + // Floating Point Types + pgTypeReal = "real" // 4 byte floating point (6 digit precision) + pgTypeDoublePrecision = "double precision" // 8 byte floating point (15 digit precision) + + // Integer Types + pgTypeSmallint = "smallint" // 2 byte integer + pgTypeInteger = "integer" // 4 byte integer + pgTypeBigint = "bigint" // 8 byte integer + + // Serial Types + pgTypeSmallserial = "smallserial" // 2 byte autoincrementing integer + pgTypeSerial = "serial" // 4 byte autoincrementing integer + pgTypeBigserial = "bigserial" // 8 byte autoincrementing integer + + // Character Types + pgTypeVarchar = "varchar" // variable length string with limit + pgTypeChar = "char" // fixed length string (blank padded) + pgTypeText = "text" // variable length string without limit + + // JSON Types + pgTypeJSON = "json" // text representation of json data + pgTypeJSONB = "jsonb" // binary representation of json data + + // Binary Data Types + pgTypeBytea = "bytea" // binary string +) diff --git a/vendor/github.com/go-pg/pg/v10/orm/update.go b/vendor/github.com/go-pg/pg/v10/orm/update.go new file mode 100644 index 000000000..ce6396fd3 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/update.go @@ -0,0 +1,378 @@ +package orm + +import ( + "fmt" + "reflect" + "sort" + + "github.com/go-pg/pg/v10/types" +) + +type UpdateQuery struct { + q *Query + omitZero bool + placeholder bool +} + +var ( + _ QueryAppender = (*UpdateQuery)(nil) + _ QueryCommand = (*UpdateQuery)(nil) +) + +func NewUpdateQuery(q *Query, omitZero bool) *UpdateQuery { + return &UpdateQuery{ + q: q, + omitZero: omitZero, + } +} + +func (q *UpdateQuery) String() string { + b, err := q.AppendQuery(defaultFmter, nil) + if err != nil { + panic(err) + } + return string(b) +} + +func (q *UpdateQuery) Operation() QueryOp { + return UpdateOp +} + +func (q *UpdateQuery) Clone() QueryCommand { + return &UpdateQuery{ + q: q.q.Clone(), + omitZero: q.omitZero, + placeholder: q.placeholder, + } +} + +func (q *UpdateQuery) Query() *Query { + return q.q +} + +func (q *UpdateQuery) AppendTemplate(b []byte) ([]byte, error) { + cp := q.Clone().(*UpdateQuery) + cp.placeholder = true + return cp.AppendQuery(dummyFormatter{}, b) +} + +func (q *UpdateQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if q.q.stickyErr != nil { + return nil, q.q.stickyErr + } + + if len(q.q.with) > 0 { + b, err = q.q.appendWith(fmter, b) + if err != nil { + return nil, err + } + } + + b = append(b, "UPDATE "...) + + b, err = q.q.appendFirstTableWithAlias(fmter, b) + if err != nil { + return nil, err + } + + b, err = q.mustAppendSet(fmter, b) + if err != nil { + return nil, err + } + + isSliceModelWithData := q.q.isSliceModelWithData() + if isSliceModelWithData || q.q.hasMultiTables() { + b = append(b, " FROM "...) + b, err = q.q.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + + if isSliceModelWithData { + b, err = q.appendSliceModelData(fmter, b) + if err != nil { + return nil, err + } + } + } + + b, err = q.mustAppendWhere(fmter, b, isSliceModelWithData) + if err != nil { + return nil, err + } + + if len(q.q.returning) > 0 { + b, err = q.q.appendReturning(fmter, b) + if err != nil { + return nil, err + } + } + + return b, q.q.stickyErr +} + +func (q *UpdateQuery) mustAppendWhere( + fmter QueryFormatter, b []byte, isSliceModelWithData bool, +) (_ []byte, err error) { + b = append(b, " WHERE "...) + + if !isSliceModelWithData { + return q.q.mustAppendWhere(fmter, b) + } + + if len(q.q.where) > 0 { + return q.q.appendWhere(fmter, b) + } + + table := q.q.tableModel.Table() + err = table.checkPKs() + if err != nil { + return nil, err + } + + b = appendWhereColumnAndColumn(b, table.Alias, table.PKs) + return b, nil +} + +func (q *UpdateQuery) mustAppendSet(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if len(q.q.set) > 0 { + return q.q.appendSet(fmter, b) + } + + b = append(b, " SET "...) + + if m, ok := q.q.model.(*mapModel); ok { + return q.appendMapSet(b, m.m), nil + } + + if !q.q.hasTableModel() { + return nil, errModelNil + } + + value := q.q.tableModel.Value() + if value.Kind() == reflect.Struct { + b, err = q.appendSetStruct(fmter, b, value) + } else { + if value.Len() > 0 { + b, err = q.appendSetSlice(b) + } else { + err = fmt.Errorf("pg: can't bulk-update empty slice %s", value.Type()) + } + } + if err != nil { + return nil, err + } + + return b, nil +} + +func (q *UpdateQuery) appendMapSet(b []byte, m map[string]interface{}) []byte { + keys := make([]string, 0, len(m)) + + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + + for i, k := range keys { + if i > 0 { + b = append(b, ", "...) + } + + b = types.AppendIdent(b, k, 1) + b = append(b, " = "...) + if q.placeholder { + b = append(b, '?') + } else { + b = types.Append(b, m[k], 1) + } + } + + return b +} + +func (q *UpdateQuery) appendSetStruct(fmter QueryFormatter, b []byte, strct reflect.Value) ([]byte, error) { + fields, err := q.q.getFields() + if err != nil { + return nil, err + } + + if len(fields) == 0 { + fields = q.q.tableModel.Table().DataFields + } + + pos := len(b) + for _, f := range fields { + if q.omitZero && f.NullZero() && f.HasZeroValue(strct) { + continue + } + + if len(b) != pos { + b = append(b, ", "...) + pos = len(b) + } + + b = append(b, f.Column...) + b = append(b, " = "...) + + if q.placeholder { + b = append(b, '?') + continue + } + + app, ok := q.q.modelValues[f.SQLName] + if ok { + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } else { + b = f.AppendValue(b, strct, 1) + } + } + + for i, v := range q.q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + + b = append(b, v.column...) + b = append(b, " = "...) + + b, err = v.value.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *UpdateQuery) appendSetSlice(b []byte) ([]byte, error) { + fields, err := q.q.getFields() + if err != nil { + return nil, err + } + + if len(fields) == 0 { + fields = q.q.tableModel.Table().DataFields + } + + var table *Table + if q.omitZero { + table = q.q.tableModel.Table() + } + + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + b = append(b, f.Column...) + b = append(b, " = "...) + if q.omitZero && table != nil { + b = append(b, "COALESCE("...) + } + b = append(b, "_data."...) + b = append(b, f.Column...) + if q.omitZero && table != nil { + b = append(b, ", "...) + if table.Alias != table.SQLName { + b = append(b, table.Alias...) + b = append(b, '.') + } + b = append(b, f.Column...) + b = append(b, ")"...) + } + } + + return b, nil +} + +func (q *UpdateQuery) appendSliceModelData(fmter QueryFormatter, b []byte) ([]byte, error) { + columns, err := q.q.getDataFields() + if err != nil { + return nil, err + } + + if len(columns) > 0 { + columns = append(columns, q.q.tableModel.Table().PKs...) + } else { + columns = q.q.tableModel.Table().Fields + } + + return q.appendSliceValues(fmter, b, columns, q.q.tableModel.Value()) +} + +func (q *UpdateQuery) appendSliceValues( + fmter QueryFormatter, b []byte, fields []*Field, slice reflect.Value, +) (_ []byte, err error) { + b = append(b, "(VALUES ("...) + + if q.placeholder { + b, err = q.appendValues(fmter, b, fields, reflect.Value{}) + if err != nil { + return nil, err + } + } else { + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, "), ("...) + } + b, err = q.appendValues(fmter, b, fields, slice.Index(i)) + if err != nil { + return nil, err + } + } + } + + b = append(b, ")) AS _data("...) + b = appendColumns(b, "", fields) + b = append(b, ")"...) + + return b, nil +} + +func (q *UpdateQuery) appendValues( + fmter QueryFormatter, b []byte, fields []*Field, strct reflect.Value, +) (_ []byte, err error) { + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + app, ok := q.q.modelValues[f.SQLName] + if ok { + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + continue + } + + if q.placeholder { + b = append(b, '?') + } else { + b = f.AppendValue(b, indirect(strct), 1) + } + + b = append(b, "::"...) + b = append(b, f.SQLType...) + } + return b, nil +} + +func appendWhereColumnAndColumn(b []byte, alias types.Safe, fields []*Field) []byte { + for i, f := range fields { + if i > 0 { + b = append(b, " AND "...) + } + b = append(b, alias...) + b = append(b, '.') + b = append(b, f.Column...) + b = append(b, " = _data."...) + b = append(b, f.Column...) + } + return b +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/util.go b/vendor/github.com/go-pg/pg/v10/orm/util.go new file mode 100644 index 000000000..b7963ba0b --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/util.go @@ -0,0 +1,151 @@ +package orm + +import ( + "reflect" + + "github.com/go-pg/pg/v10/types" +) + +func indirect(v reflect.Value) reflect.Value { + switch v.Kind() { + case reflect.Interface: + return indirect(v.Elem()) + case reflect.Ptr: + return v.Elem() + default: + return v + } +} + +func indirectType(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +func sliceElemType(v reflect.Value) reflect.Type { + elemType := v.Type().Elem() + if elemType.Kind() == reflect.Interface && v.Len() > 0 { + return indirect(v.Index(0).Elem()).Type() + } + return indirectType(elemType) +} + +func typeByIndex(t reflect.Type, index []int) reflect.Type { + for _, x := range index { + switch t.Kind() { + case reflect.Ptr: + t = t.Elem() + case reflect.Slice: + t = indirectType(t.Elem()) + } + t = t.Field(x).Type + } + return indirectType(t) +} + +func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, ok bool) { + if len(index) == 1 { + return v.Field(index[0]), true + } + + for i, idx := range index { + if i > 0 { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return v, false + } + v = v.Elem() + } + } + v = v.Field(idx) + } + return v, true +} + +func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value { + if len(index) == 1 { + return v.Field(index[0]) + } + + for i, idx := range index { + if i > 0 { + v = indirectNil(v) + } + v = v.Field(idx) + } + return v +} + +func indirectNil(v reflect.Value) reflect.Value { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + return v +} + +func walk(v reflect.Value, index []int, fn func(reflect.Value)) { + v = reflect.Indirect(v) + switch v.Kind() { + case reflect.Slice: + sliceLen := v.Len() + for i := 0; i < sliceLen; i++ { + visitField(v.Index(i), index, fn) + } + default: + visitField(v, index, fn) + } +} + +func visitField(v reflect.Value, index []int, fn func(reflect.Value)) { + v = reflect.Indirect(v) + if len(index) > 0 { + v = v.Field(index[0]) + if v.Kind() == reflect.Ptr && v.IsNil() { + return + } + walk(v, index[1:], fn) + } else { + fn(v) + } +} + +func dstValues(model TableModel, fields []*Field) map[string][]reflect.Value { + fieldIndex := model.Relation().Field.Index + m := make(map[string][]reflect.Value) + var id []byte + walk(model.Root(), model.ParentIndex(), func(v reflect.Value) { + id = modelID(id[:0], v, fields) + m[string(id)] = append(m[string(id)], v.FieldByIndex(fieldIndex)) + }) + return m +} + +func modelID(b []byte, v reflect.Value, fields []*Field) []byte { + for i, f := range fields { + if i > 0 { + b = append(b, ',') + } + b = f.AppendValue(b, v, 0) + } + return b +} + +func appendColumns(b []byte, table types.Safe, fields []*Field) []byte { + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + if len(table) > 0 { + b = append(b, table...) + b = append(b, '.') + } + b = append(b, f.Column...) + } + return b +} diff --git a/vendor/github.com/go-pg/pg/v10/pg.go b/vendor/github.com/go-pg/pg/v10/pg.go new file mode 100644 index 000000000..923ef6bef --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/pg.go @@ -0,0 +1,274 @@ +package pg + +import ( + "context" + "io" + "strconv" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/orm" + "github.com/go-pg/pg/v10/types" +) + +// Discard is used with Query and QueryOne to discard rows. +var Discard orm.Discard + +// NullTime is a time.Time wrapper that marshals zero time as JSON null and +// PostgreSQL NULL. +type NullTime = types.NullTime + +// Scan returns ColumnScanner that copies the columns in the +// row into the values. +func Scan(values ...interface{}) orm.ColumnScanner { + return orm.Scan(values...) +} + +// Safe represents a safe SQL query. +type Safe = types.Safe + +// Ident represents a SQL identifier, e.g. table or column name. +type Ident = types.Ident + +// SafeQuery replaces any placeholders found in the query. +func SafeQuery(query string, params ...interface{}) *orm.SafeQueryAppender { + return orm.SafeQuery(query, params...) +} + +// In accepts a slice and returns a wrapper that can be used with PostgreSQL +// IN operator: +// +// Where("id IN (?)", pg.In([]int{1, 2, 3, 4})) +// +// produces +// +// WHERE id IN (1, 2, 3, 4) +func In(slice interface{}) types.ValueAppender { + return types.In(slice) +} + +// InMulti accepts multiple values and returns a wrapper that can be used +// with PostgreSQL IN operator: +// +// Where("(id1, id2) IN (?)", pg.InMulti([]int{1, 2}, []int{3, 4})) +// +// produces +// +// WHERE (id1, id2) IN ((1, 2), (3, 4)) +func InMulti(values ...interface{}) types.ValueAppender { + return types.InMulti(values...) +} + +// Array accepts a slice and returns a wrapper for working with PostgreSQL +// array data type. +// +// For struct fields you can use array tag: +// +// Emails []string `pg:",array"` +func Array(v interface{}) *types.Array { + return types.NewArray(v) +} + +// Hstore accepts a map and returns a wrapper for working with hstore data type. +// Supported map types are: +// - map[string]string +// +// For struct fields you can use hstore tag: +// +// Attrs map[string]string `pg:",hstore"` +func Hstore(v interface{}) *types.Hstore { + return types.NewHstore(v) +} + +// SetLogger sets the logger to the given one. +func SetLogger(logger internal.Logging) { + internal.Logger = logger +} + +//------------------------------------------------------------------------------ + +type Query = orm.Query + +// Model returns a new query for the optional model. +func Model(model ...interface{}) *Query { + return orm.NewQuery(nil, model...) +} + +// ModelContext returns a new query for the optional model with a context. +func ModelContext(c context.Context, model ...interface{}) *Query { + return orm.NewQueryContext(c, nil, model...) +} + +// DBI is a DB interface implemented by *DB and *Tx. +type DBI interface { + Model(model ...interface{}) *Query + ModelContext(c context.Context, model ...interface{}) *Query + + Exec(query interface{}, params ...interface{}) (Result, error) + ExecContext(c context.Context, query interface{}, params ...interface{}) (Result, error) + ExecOne(query interface{}, params ...interface{}) (Result, error) + ExecOneContext(c context.Context, query interface{}, params ...interface{}) (Result, error) + Query(model, query interface{}, params ...interface{}) (Result, error) + QueryContext(c context.Context, model, query interface{}, params ...interface{}) (Result, error) + QueryOne(model, query interface{}, params ...interface{}) (Result, error) + QueryOneContext(c context.Context, model, query interface{}, params ...interface{}) (Result, error) + + Begin() (*Tx, error) + RunInTransaction(ctx context.Context, fn func(*Tx) error) error + + CopyFrom(r io.Reader, query interface{}, params ...interface{}) (Result, error) + CopyTo(w io.Writer, query interface{}, params ...interface{}) (Result, error) +} + +var ( + _ DBI = (*DB)(nil) + _ DBI = (*Tx)(nil) +) + +//------------------------------------------------------------------------------ + +// Strings is a type alias for a slice of strings. +type Strings []string + +var ( + _ orm.HooklessModel = (*Strings)(nil) + _ types.ValueAppender = (*Strings)(nil) +) + +// Init initializes the Strings slice. +func (strings *Strings) Init() error { + if s := *strings; len(s) > 0 { + *strings = s[:0] + } + return nil +} + +// NextColumnScanner ... +func (strings *Strings) NextColumnScanner() orm.ColumnScanner { + return strings +} + +// AddColumnScanner ... +func (Strings) AddColumnScanner(_ orm.ColumnScanner) error { + return nil +} + +// ScanColumn scans the columns and appends them to `strings`. +func (strings *Strings) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { + b := make([]byte, n) + _, err := io.ReadFull(rd, b) + if err != nil { + return err + } + + *strings = append(*strings, internal.BytesToString(b)) + return nil +} + +// AppendValue appends the values from `strings` to the given byte slice. +func (strings Strings) AppendValue(dst []byte, quote int) ([]byte, error) { + if len(strings) == 0 { + return dst, nil + } + + for _, s := range strings { + dst = types.AppendString(dst, s, 1) + dst = append(dst, ',') + } + dst = dst[:len(dst)-1] + return dst, nil +} + +//------------------------------------------------------------------------------ + +// Ints is a type alias for a slice of int64 values. +type Ints []int64 + +var ( + _ orm.HooklessModel = (*Ints)(nil) + _ types.ValueAppender = (*Ints)(nil) +) + +// Init initializes the Int slice. +func (ints *Ints) Init() error { + if s := *ints; len(s) > 0 { + *ints = s[:0] + } + return nil +} + +// NewColumnScanner ... +func (ints *Ints) NextColumnScanner() orm.ColumnScanner { + return ints +} + +// AddColumnScanner ... +func (Ints) AddColumnScanner(_ orm.ColumnScanner) error { + return nil +} + +// ScanColumn scans the columns and appends them to `ints`. +func (ints *Ints) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { + num, err := types.ScanInt64(rd, n) + if err != nil { + return err + } + + *ints = append(*ints, num) + return nil +} + +// AppendValue appends the values from `ints` to the given byte slice. +func (ints Ints) AppendValue(dst []byte, quote int) ([]byte, error) { + if len(ints) == 0 { + return dst, nil + } + + for _, v := range ints { + dst = strconv.AppendInt(dst, v, 10) + dst = append(dst, ',') + } + dst = dst[:len(dst)-1] + return dst, nil +} + +//------------------------------------------------------------------------------ + +// IntSet is a set of int64 values. +type IntSet map[int64]struct{} + +var _ orm.HooklessModel = (*IntSet)(nil) + +// Init initializes the IntSet. +func (set *IntSet) Init() error { + if len(*set) > 0 { + *set = make(map[int64]struct{}) + } + return nil +} + +// NextColumnScanner ... +func (set *IntSet) NextColumnScanner() orm.ColumnScanner { + return set +} + +// AddColumnScanner ... +func (IntSet) AddColumnScanner(_ orm.ColumnScanner) error { + return nil +} + +// ScanColumn scans the columns and appends them to `IntSet`. +func (set *IntSet) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { + num, err := types.ScanInt64(rd, n) + if err != nil { + return err + } + + setVal := *set + if setVal == nil { + *set = make(IntSet) + setVal = *set + } + + setVal[num] = struct{}{} + return nil +} diff --git a/vendor/github.com/go-pg/pg/v10/pgjson/json.go b/vendor/github.com/go-pg/pg/v10/pgjson/json.go new file mode 100644 index 000000000..c401dc946 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/pgjson/json.go @@ -0,0 +1,26 @@ +package pgjson + +import ( + "encoding/json" + "io" +) + +var _ Provider = (*StdProvider)(nil) + +type StdProvider struct{} + +func (StdProvider) Marshal(v interface{}) ([]byte, error) { + return json.Marshal(v) +} + +func (StdProvider) Unmarshal(data []byte, v interface{}) error { + return json.Unmarshal(data, v) +} + +func (StdProvider) NewEncoder(w io.Writer) Encoder { + return json.NewEncoder(w) +} + +func (StdProvider) NewDecoder(r io.Reader) Decoder { + return json.NewDecoder(r) +} diff --git a/vendor/github.com/go-pg/pg/v10/pgjson/provider.go b/vendor/github.com/go-pg/pg/v10/pgjson/provider.go new file mode 100644 index 000000000..a4b663ce4 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/pgjson/provider.go @@ -0,0 +1,43 @@ +package pgjson + +import ( + "io" +) + +var provider Provider = StdProvider{} + +func SetProvider(p Provider) { + provider = p +} + +type Provider interface { + Marshal(v interface{}) ([]byte, error) + Unmarshal(data []byte, v interface{}) error + NewEncoder(w io.Writer) Encoder + NewDecoder(r io.Reader) Decoder +} + +type Decoder interface { + Decode(v interface{}) error + UseNumber() +} + +type Encoder interface { + Encode(v interface{}) error +} + +func Marshal(v interface{}) ([]byte, error) { + return provider.Marshal(v) +} + +func Unmarshal(data []byte, v interface{}) error { + return provider.Unmarshal(data, v) +} + +func NewEncoder(w io.Writer) Encoder { + return provider.NewEncoder(w) +} + +func NewDecoder(r io.Reader) Decoder { + return provider.NewDecoder(r) +} diff --git a/vendor/github.com/go-pg/pg/v10/result.go b/vendor/github.com/go-pg/pg/v10/result.go new file mode 100644 index 000000000..b8d8d9e45 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/result.go @@ -0,0 +1,53 @@ +package pg + +import ( + "bytes" + "strconv" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/orm" +) + +// Result summarizes an executed SQL command. +type Result = orm.Result + +// A result summarizes an executed SQL command. +type result struct { + model orm.Model + + affected int + returned int +} + +var _ Result = (*result)(nil) + +//nolint +func (res *result) parse(b []byte) error { + res.affected = -1 + + ind := bytes.LastIndexByte(b, ' ') + if ind == -1 { + return nil + } + + s := internal.BytesToString(b[ind+1 : len(b)-1]) + + affected, err := strconv.Atoi(s) + if err == nil { + res.affected = affected + } + + return nil +} + +func (res *result) Model() orm.Model { + return res.model +} + +func (res *result) RowsAffected() int { + return res.affected +} + +func (res *result) RowsReturned() int { + return res.returned +} diff --git a/vendor/github.com/go-pg/pg/v10/stmt.go b/vendor/github.com/go-pg/pg/v10/stmt.go new file mode 100644 index 000000000..528788379 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/stmt.go @@ -0,0 +1,282 @@ +package pg + +import ( + "context" + "errors" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/internal/pool" + "github.com/go-pg/pg/v10/orm" + "github.com/go-pg/pg/v10/types" +) + +var errStmtClosed = errors.New("pg: statement is closed") + +// Stmt is a prepared statement. Stmt is safe for concurrent use by +// multiple goroutines. +type Stmt struct { + db *baseDB + stickyErr error + + q string + name string + columns []types.ColumnInfo +} + +func prepareStmt(db *baseDB, q string) (*Stmt, error) { + stmt := &Stmt{ + db: db, + + q: q, + } + + err := stmt.prepare(context.TODO(), q) + if err != nil { + _ = stmt.Close() + return nil, err + } + return stmt, nil +} + +func (stmt *Stmt) prepare(ctx context.Context, q string) error { + var lastErr error + for attempt := 0; attempt <= stmt.db.opt.MaxRetries; attempt++ { + if attempt > 0 { + if err := internal.Sleep(ctx, stmt.db.retryBackoff(attempt-1)); err != nil { + return err + } + + err := stmt.db.pool.(*pool.StickyConnPool).Reset(ctx) + if err != nil { + return err + } + } + + lastErr = stmt.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { + var err error + stmt.name, stmt.columns, err = stmt.db.prepare(ctx, cn, q) + return err + }) + if !stmt.db.shouldRetry(lastErr) { + break + } + } + return lastErr +} + +func (stmt *Stmt) withConn(c context.Context, fn func(context.Context, *pool.Conn) error) error { + if stmt.stickyErr != nil { + return stmt.stickyErr + } + err := stmt.db.withConn(c, fn) + if err == pool.ErrClosed { + return errStmtClosed + } + return err +} + +// Exec executes a prepared statement with the given parameters. +func (stmt *Stmt) Exec(params ...interface{}) (Result, error) { + return stmt.exec(context.TODO(), params...) +} + +// ExecContext executes a prepared statement with the given parameters. +func (stmt *Stmt) ExecContext(c context.Context, params ...interface{}) (Result, error) { + return stmt.exec(c, params...) +} + +func (stmt *Stmt) exec(ctx context.Context, params ...interface{}) (Result, error) { + ctx, evt, err := stmt.db.beforeQuery(ctx, stmt.db.db, nil, stmt.q, params, nil) + if err != nil { + return nil, err + } + + var res Result + var lastErr error + for attempt := 0; attempt <= stmt.db.opt.MaxRetries; attempt++ { + if attempt > 0 { + lastErr = internal.Sleep(ctx, stmt.db.retryBackoff(attempt-1)) + if lastErr != nil { + break + } + } + + lastErr = stmt.withConn(ctx, func(c context.Context, cn *pool.Conn) error { + res, err = stmt.extQuery(ctx, cn, stmt.name, params...) + return err + }) + if !stmt.db.shouldRetry(lastErr) { + break + } + } + + if err := stmt.db.afterQuery(ctx, evt, res, lastErr); err != nil { + return nil, err + } + return res, lastErr +} + +// ExecOne acts like Exec, but query must affect only one row. It +// returns ErrNoRows error when query returns zero rows or +// ErrMultiRows when query returns multiple rows. +func (stmt *Stmt) ExecOne(params ...interface{}) (Result, error) { + return stmt.execOne(context.Background(), params...) +} + +// ExecOneContext acts like ExecOne but additionally receives a context. +func (stmt *Stmt) ExecOneContext(c context.Context, params ...interface{}) (Result, error) { + return stmt.execOne(c, params...) +} + +func (stmt *Stmt) execOne(c context.Context, params ...interface{}) (Result, error) { + res, err := stmt.ExecContext(c, params...) + if err != nil { + return nil, err + } + + if err := internal.AssertOneRow(res.RowsAffected()); err != nil { + return nil, err + } + return res, nil +} + +// Query executes a prepared query statement with the given parameters. +func (stmt *Stmt) Query(model interface{}, params ...interface{}) (Result, error) { + return stmt.query(context.Background(), model, params...) +} + +// QueryContext acts like Query but additionally receives a context. +func (stmt *Stmt) QueryContext(c context.Context, model interface{}, params ...interface{}) (Result, error) { + return stmt.query(c, model, params...) +} + +func (stmt *Stmt) query(ctx context.Context, model interface{}, params ...interface{}) (Result, error) { + ctx, evt, err := stmt.db.beforeQuery(ctx, stmt.db.db, model, stmt.q, params, nil) + if err != nil { + return nil, err + } + + var res Result + var lastErr error + for attempt := 0; attempt <= stmt.db.opt.MaxRetries; attempt++ { + if attempt > 0 { + lastErr = internal.Sleep(ctx, stmt.db.retryBackoff(attempt-1)) + if lastErr != nil { + break + } + } + + lastErr = stmt.withConn(ctx, func(c context.Context, cn *pool.Conn) error { + res, err = stmt.extQueryData(ctx, cn, stmt.name, model, stmt.columns, params...) + return err + }) + if !stmt.db.shouldRetry(lastErr) { + break + } + } + + if err := stmt.db.afterQuery(ctx, evt, res, lastErr); err != nil { + return nil, err + } + return res, lastErr +} + +// QueryOne acts like Query, but query must return only one row. It +// returns ErrNoRows error when query returns zero rows or +// ErrMultiRows when query returns multiple rows. +func (stmt *Stmt) QueryOne(model interface{}, params ...interface{}) (Result, error) { + return stmt.queryOne(context.Background(), model, params...) +} + +// QueryOneContext acts like QueryOne but additionally receives a context. +func (stmt *Stmt) QueryOneContext(c context.Context, model interface{}, params ...interface{}) (Result, error) { + return stmt.queryOne(c, model, params...) +} + +func (stmt *Stmt) queryOne(c context.Context, model interface{}, params ...interface{}) (Result, error) { + mod, err := orm.NewModel(model) + if err != nil { + return nil, err + } + + res, err := stmt.QueryContext(c, mod, params...) + if err != nil { + return nil, err + } + + if err := internal.AssertOneRow(res.RowsAffected()); err != nil { + return nil, err + } + return res, nil +} + +// Close closes the statement. +func (stmt *Stmt) Close() error { + var firstErr error + + if stmt.name != "" { + firstErr = stmt.closeStmt() + } + + err := stmt.db.Close() + if err != nil && firstErr == nil { + firstErr = err + } + + return firstErr +} + +func (stmt *Stmt) extQuery( + c context.Context, cn *pool.Conn, name string, params ...interface{}, +) (Result, error) { + err := cn.WithWriter(c, stmt.db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + return writeBindExecuteMsg(wb, name, params...) + }) + if err != nil { + return nil, err + } + + var res Result + err = cn.WithReader(c, stmt.db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { + res, err = readExtQuery(rd) + return err + }) + if err != nil { + return nil, err + } + + return res, nil +} + +func (stmt *Stmt) extQueryData( + c context.Context, + cn *pool.Conn, + name string, + model interface{}, + columns []types.ColumnInfo, + params ...interface{}, +) (Result, error) { + err := cn.WithWriter(c, stmt.db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { + return writeBindExecuteMsg(wb, name, params...) + }) + if err != nil { + return nil, err + } + + var res *result + err = cn.WithReader(c, stmt.db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { + res, err = readExtQueryData(c, rd, model, columns) + return err + }) + if err != nil { + return nil, err + } + + return res, nil +} + +func (stmt *Stmt) closeStmt() error { + return stmt.withConn(context.TODO(), func(c context.Context, cn *pool.Conn) error { + return stmt.db.closeStmt(c, cn, stmt.name) + }) +} diff --git a/vendor/github.com/go-pg/pg/v10/tx.go b/vendor/github.com/go-pg/pg/v10/tx.go new file mode 100644 index 000000000..db444ff65 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/tx.go @@ -0,0 +1,388 @@ +package pg + +import ( + "context" + "errors" + "io" + "sync" + "sync/atomic" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/internal/pool" + "github.com/go-pg/pg/v10/orm" +) + +// ErrTxDone is returned by any operation that is performed on a transaction +// that has already been committed or rolled back. +var ErrTxDone = errors.New("pg: transaction has already been committed or rolled back") + +// Tx is an in-progress database transaction. It is safe for concurrent use +// by multiple goroutines. +// +// A transaction must end with a call to Commit or Rollback. +// +// After a call to Commit or Rollback, all operations on the transaction fail +// with ErrTxDone. +// +// The statements prepared for a transaction by calling the transaction's +// Prepare or Stmt methods are closed by the call to Commit or Rollback. +type Tx struct { + db *baseDB + ctx context.Context + + stmtsMu sync.Mutex + stmts []*Stmt + + _closed int32 +} + +var _ orm.DB = (*Tx)(nil) + +// Context returns the context.Context of the transaction. +func (tx *Tx) Context() context.Context { + return tx.ctx +} + +// Begin starts a transaction. Most callers should use RunInTransaction instead. +func (db *baseDB) Begin() (*Tx, error) { + return db.BeginContext(db.db.Context()) +} + +func (db *baseDB) BeginContext(ctx context.Context) (*Tx, error) { + tx := &Tx{ + db: db.withPool(pool.NewStickyConnPool(db.pool)), + ctx: ctx, + } + + err := tx.begin(ctx) + if err != nil { + tx.close() + return nil, err + } + + return tx, nil +} + +// RunInTransaction runs a function in a transaction. If function +// returns an error transaction is rolled back, otherwise transaction +// is committed. +func (db *baseDB) RunInTransaction(ctx context.Context, fn func(*Tx) error) error { + tx, err := db.BeginContext(ctx) + if err != nil { + return err + } + return tx.RunInTransaction(ctx, fn) +} + +// Begin returns current transaction. It does not start new transaction. +func (tx *Tx) Begin() (*Tx, error) { + return tx, nil +} + +// RunInTransaction runs a function in the transaction. If function +// returns an error transaction is rolled back, otherwise transaction +// is committed. +func (tx *Tx) RunInTransaction(ctx context.Context, fn func(*Tx) error) error { + defer func() { + if err := recover(); err != nil { + if err := tx.RollbackContext(ctx); err != nil { + internal.Logger.Printf(ctx, "tx.Rollback panicked: %s", err) + } + panic(err) + } + }() + + if err := fn(tx); err != nil { + if err := tx.RollbackContext(ctx); err != nil { + internal.Logger.Printf(ctx, "tx.Rollback failed: %s", err) + } + return err + } + return tx.CommitContext(ctx) +} + +func (tx *Tx) withConn(c context.Context, fn func(context.Context, *pool.Conn) error) error { + err := tx.db.withConn(c, fn) + if tx.closed() && err == pool.ErrClosed { + return ErrTxDone + } + return err +} + +// Stmt returns a transaction-specific prepared statement +// from an existing statement. +func (tx *Tx) Stmt(stmt *Stmt) *Stmt { + stmt, err := tx.Prepare(stmt.q) + if err != nil { + return &Stmt{stickyErr: err} + } + return stmt +} + +// Prepare creates a prepared statement for use within a transaction. +// +// The returned statement operates within the transaction and can no longer +// be used once the transaction has been committed or rolled back. +// +// To use an existing prepared statement on this transaction, see Tx.Stmt. +func (tx *Tx) Prepare(q string) (*Stmt, error) { + tx.stmtsMu.Lock() + defer tx.stmtsMu.Unlock() + + db := tx.db.withPool(pool.NewStickyConnPool(tx.db.pool)) + stmt, err := prepareStmt(db, q) + if err != nil { + return nil, err + } + tx.stmts = append(tx.stmts, stmt) + + return stmt, nil +} + +// Exec is an alias for DB.Exec. +func (tx *Tx) Exec(query interface{}, params ...interface{}) (Result, error) { + return tx.exec(tx.ctx, query, params...) +} + +// ExecContext acts like Exec but additionally receives a context. +func (tx *Tx) ExecContext(c context.Context, query interface{}, params ...interface{}) (Result, error) { + return tx.exec(c, query, params...) +} + +func (tx *Tx) exec(ctx context.Context, query interface{}, params ...interface{}) (Result, error) { + wb := pool.GetWriteBuffer() + defer pool.PutWriteBuffer(wb) + + if err := writeQueryMsg(wb, tx.db.fmter, query, params...); err != nil { + return nil, err + } + + ctx, evt, err := tx.db.beforeQuery(ctx, tx, nil, query, params, wb.Query()) + if err != nil { + return nil, err + } + + var res Result + lastErr := tx.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { + res, err = tx.db.simpleQuery(ctx, cn, wb) + return err + }) + + if err := tx.db.afterQuery(ctx, evt, res, lastErr); err != nil { + return nil, err + } + return res, lastErr +} + +// ExecOne is an alias for DB.ExecOne. +func (tx *Tx) ExecOne(query interface{}, params ...interface{}) (Result, error) { + return tx.execOne(tx.ctx, query, params...) +} + +// ExecOneContext acts like ExecOne but additionally receives a context. +func (tx *Tx) ExecOneContext(c context.Context, query interface{}, params ...interface{}) (Result, error) { + return tx.execOne(c, query, params...) +} + +func (tx *Tx) execOne(c context.Context, query interface{}, params ...interface{}) (Result, error) { + res, err := tx.ExecContext(c, query, params...) + if err != nil { + return nil, err + } + + if err := internal.AssertOneRow(res.RowsAffected()); err != nil { + return nil, err + } + return res, nil +} + +// Query is an alias for DB.Query. +func (tx *Tx) Query(model interface{}, query interface{}, params ...interface{}) (Result, error) { + return tx.query(tx.ctx, model, query, params...) +} + +// QueryContext acts like Query but additionally receives a context. +func (tx *Tx) QueryContext( + c context.Context, + model interface{}, + query interface{}, + params ...interface{}, +) (Result, error) { + return tx.query(c, model, query, params...) +} + +func (tx *Tx) query( + ctx context.Context, + model interface{}, + query interface{}, + params ...interface{}, +) (Result, error) { + wb := pool.GetWriteBuffer() + defer pool.PutWriteBuffer(wb) + + if err := writeQueryMsg(wb, tx.db.fmter, query, params...); err != nil { + return nil, err + } + + ctx, evt, err := tx.db.beforeQuery(ctx, tx, model, query, params, wb.Query()) + if err != nil { + return nil, err + } + + var res *result + lastErr := tx.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { + res, err = tx.db.simpleQueryData(ctx, cn, model, wb) + return err + }) + + if err := tx.db.afterQuery(ctx, evt, res, err); err != nil { + return nil, err + } + return res, lastErr +} + +// QueryOne is an alias for DB.QueryOne. +func (tx *Tx) QueryOne(model interface{}, query interface{}, params ...interface{}) (Result, error) { + return tx.queryOne(tx.ctx, model, query, params...) +} + +// QueryOneContext acts like QueryOne but additionally receives a context. +func (tx *Tx) QueryOneContext( + c context.Context, + model interface{}, + query interface{}, + params ...interface{}, +) (Result, error) { + return tx.queryOne(c, model, query, params...) +} + +func (tx *Tx) queryOne( + c context.Context, + model interface{}, + query interface{}, + params ...interface{}, +) (Result, error) { + mod, err := orm.NewModel(model) + if err != nil { + return nil, err + } + + res, err := tx.QueryContext(c, mod, query, params...) + if err != nil { + return nil, err + } + + if err := internal.AssertOneRow(res.RowsAffected()); err != nil { + return nil, err + } + return res, nil +} + +// Model is an alias for DB.Model. +func (tx *Tx) Model(model ...interface{}) *Query { + return orm.NewQuery(tx, model...) +} + +// ModelContext acts like Model but additionally receives a context. +func (tx *Tx) ModelContext(c context.Context, model ...interface{}) *Query { + return orm.NewQueryContext(c, tx, model...) +} + +// CopyFrom is an alias for DB.CopyFrom. +func (tx *Tx) CopyFrom(r io.Reader, query interface{}, params ...interface{}) (res Result, err error) { + err = tx.withConn(tx.ctx, func(c context.Context, cn *pool.Conn) error { + res, err = tx.db.copyFrom(c, cn, r, query, params...) + return err + }) + return res, err +} + +// CopyTo is an alias for DB.CopyTo. +func (tx *Tx) CopyTo(w io.Writer, query interface{}, params ...interface{}) (res Result, err error) { + err = tx.withConn(tx.ctx, func(c context.Context, cn *pool.Conn) error { + res, err = tx.db.copyTo(c, cn, w, query, params...) + return err + }) + return res, err +} + +// Formatter is an alias for DB.Formatter. +func (tx *Tx) Formatter() orm.QueryFormatter { + return tx.db.Formatter() +} + +func (tx *Tx) begin(ctx context.Context) error { + var lastErr error + for attempt := 0; attempt <= tx.db.opt.MaxRetries; attempt++ { + if attempt > 0 { + if err := internal.Sleep(ctx, tx.db.retryBackoff(attempt-1)); err != nil { + return err + } + + err := tx.db.pool.(*pool.StickyConnPool).Reset(ctx) + if err != nil { + return err + } + } + + _, lastErr = tx.ExecContext(ctx, "BEGIN") + if !tx.db.shouldRetry(lastErr) { + break + } + } + return lastErr +} + +func (tx *Tx) Commit() error { + return tx.CommitContext(tx.ctx) +} + +// Commit commits the transaction. +func (tx *Tx) CommitContext(ctx context.Context) error { + _, err := tx.ExecContext(internal.UndoContext(ctx), "COMMIT") + tx.close() + return err +} + +func (tx *Tx) Rollback() error { + return tx.RollbackContext(tx.ctx) +} + +// Rollback aborts the transaction. +func (tx *Tx) RollbackContext(ctx context.Context) error { + _, err := tx.ExecContext(internal.UndoContext(ctx), "ROLLBACK") + tx.close() + return err +} + +func (tx *Tx) Close() error { + return tx.CloseContext(tx.ctx) +} + +// Close calls Rollback if the tx has not already been committed or rolled back. +func (tx *Tx) CloseContext(ctx context.Context) error { + if tx.closed() { + return nil + } + return tx.RollbackContext(ctx) +} + +func (tx *Tx) close() { + if !atomic.CompareAndSwapInt32(&tx._closed, 0, 1) { + return + } + + tx.stmtsMu.Lock() + defer tx.stmtsMu.Unlock() + + for _, stmt := range tx.stmts { + _ = stmt.Close() + } + tx.stmts = nil + + _ = tx.db.Close() +} + +func (tx *Tx) closed() bool { + return atomic.LoadInt32(&tx._closed) == 1 +} diff --git a/vendor/github.com/go-pg/pg/v10/types/append.go b/vendor/github.com/go-pg/pg/v10/types/append.go new file mode 100644 index 000000000..05be2a0fa --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/append.go @@ -0,0 +1,201 @@ +package types + +import ( + "math" + "reflect" + "strconv" + "time" + "unicode/utf8" + + "github.com/tmthrgd/go-hex" +) + +func Append(b []byte, v interface{}, flags int) []byte { + switch v := v.(type) { + case nil: + return AppendNull(b, flags) + case bool: + return appendBool(b, v) + case int32: + return strconv.AppendInt(b, int64(v), 10) + case int64: + return strconv.AppendInt(b, v, 10) + case int: + return strconv.AppendInt(b, int64(v), 10) + case float32: + return appendFloat(b, float64(v), flags, 32) + case float64: + return appendFloat(b, v, flags, 64) + case string: + return AppendString(b, v, flags) + case time.Time: + return AppendTime(b, v, flags) + case []byte: + return AppendBytes(b, v, flags) + case ValueAppender: + return appendAppender(b, v, flags) + default: + return appendValue(b, reflect.ValueOf(v), flags) + } +} + +func AppendError(b []byte, err error) []byte { + b = append(b, "?!("...) + b = append(b, err.Error()...) + b = append(b, ')') + return b +} + +func AppendNull(b []byte, flags int) []byte { + if hasFlag(flags, quoteFlag) { + return append(b, "NULL"...) + } + return nil +} + +func appendBool(dst []byte, v bool) []byte { + if v { + return append(dst, "TRUE"...) + } + return append(dst, "FALSE"...) +} + +func appendFloat(dst []byte, v float64, flags int, bitSize int) []byte { + if hasFlag(flags, arrayFlag) { + return appendFloat2(dst, v, flags) + } + + switch { + case math.IsNaN(v): + if hasFlag(flags, quoteFlag) { + return append(dst, "'NaN'"...) + } + return append(dst, "NaN"...) + case math.IsInf(v, 1): + if hasFlag(flags, quoteFlag) { + return append(dst, "'Infinity'"...) + } + return append(dst, "Infinity"...) + case math.IsInf(v, -1): + if hasFlag(flags, quoteFlag) { + return append(dst, "'-Infinity'"...) + } + return append(dst, "-Infinity"...) + default: + return strconv.AppendFloat(dst, v, 'f', -1, bitSize) + } +} + +func appendFloat2(dst []byte, v float64, _ int) []byte { + switch { + case math.IsNaN(v): + return append(dst, "NaN"...) + case math.IsInf(v, 1): + return append(dst, "Infinity"...) + case math.IsInf(v, -1): + return append(dst, "-Infinity"...) + default: + return strconv.AppendFloat(dst, v, 'f', -1, 64) + } +} + +func AppendString(b []byte, s string, flags int) []byte { + if hasFlag(flags, arrayFlag) { + return appendString2(b, s, flags) + } + + if hasFlag(flags, quoteFlag) { + b = append(b, '\'') + for _, c := range s { + if c == '\000' { + continue + } + + if c == '\'' { + b = append(b, '\'', '\'') + } else { + b = appendRune(b, c) + } + } + b = append(b, '\'') + return b + } + + for _, c := range s { + if c != '\000' { + b = appendRune(b, c) + } + } + return b +} + +func appendString2(b []byte, s string, flags int) []byte { + b = append(b, '"') + for _, c := range s { + if c == '\000' { + continue + } + + switch c { + case '\'': + if hasFlag(flags, quoteFlag) { + b = append(b, '\'') + } + b = append(b, '\'') + case '"': + b = append(b, '\\', '"') + case '\\': + b = append(b, '\\', '\\') + default: + b = appendRune(b, c) + } + } + b = append(b, '"') + return b +} + +func appendRune(b []byte, r rune) []byte { + if r < utf8.RuneSelf { + return append(b, byte(r)) + } + l := len(b) + if cap(b)-l < utf8.UTFMax { + b = append(b, make([]byte, utf8.UTFMax)...) + } + n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) + return b[:l+n] +} + +func AppendBytes(b []byte, bytes []byte, flags int) []byte { + if bytes == nil { + return AppendNull(b, flags) + } + + if hasFlag(flags, arrayFlag) { + b = append(b, `"\`...) + } else if hasFlag(flags, quoteFlag) { + b = append(b, '\'') + } + + b = append(b, `\x`...) + + s := len(b) + b = append(b, make([]byte, hex.EncodedLen(len(bytes)))...) + hex.Encode(b[s:], bytes) + + if hasFlag(flags, arrayFlag) { + b = append(b, '"') + } else if hasFlag(flags, quoteFlag) { + b = append(b, '\'') + } + + return b +} + +func appendAppender(b []byte, v ValueAppender, flags int) []byte { + bb, err := v.AppendValue(b, flags) + if err != nil { + return AppendError(b, err) + } + return bb +} diff --git a/vendor/github.com/go-pg/pg/v10/types/append_ident.go b/vendor/github.com/go-pg/pg/v10/types/append_ident.go new file mode 100644 index 000000000..60b9d6784 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/append_ident.go @@ -0,0 +1,46 @@ +package types + +import "github.com/go-pg/pg/v10/internal" + +func AppendIdent(b []byte, field string, flags int) []byte { + return appendIdent(b, internal.StringToBytes(field), flags) +} + +func AppendIdentBytes(b []byte, field []byte, flags int) []byte { + return appendIdent(b, field, flags) +} + +func appendIdent(b, src []byte, flags int) []byte { + var quoted bool +loop: + for _, c := range src { + switch c { + case '*': + if !quoted { + b = append(b, '*') + continue loop + } + case '.': + if quoted && hasFlag(flags, quoteFlag) { + b = append(b, '"') + quoted = false + } + b = append(b, '.') + continue loop + } + + if !quoted && hasFlag(flags, quoteFlag) { + b = append(b, '"') + quoted = true + } + if c == '"' { + b = append(b, '"', '"') + } else { + b = append(b, c) + } + } + if quoted && hasFlag(flags, quoteFlag) { + b = append(b, '"') + } + return b +} diff --git a/vendor/github.com/go-pg/pg/v10/types/append_jsonb.go b/vendor/github.com/go-pg/pg/v10/types/append_jsonb.go new file mode 100644 index 000000000..ffe221825 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/append_jsonb.go @@ -0,0 +1,49 @@ +package types + +import "github.com/go-pg/pg/v10/internal/parser" + +func AppendJSONB(b, jsonb []byte, flags int) []byte { + if hasFlag(flags, arrayFlag) { + b = append(b, '"') + } else if hasFlag(flags, quoteFlag) { + b = append(b, '\'') + } + + p := parser.New(jsonb) + for p.Valid() { + c := p.Read() + switch c { + case '"': + if hasFlag(flags, arrayFlag) { + b = append(b, '\\') + } + b = append(b, '"') + case '\'': + if hasFlag(flags, quoteFlag) { + b = append(b, '\'') + } + b = append(b, '\'') + case '\000': + continue + case '\\': + if p.SkipBytes([]byte("u0000")) { + b = append(b, "\\\\u0000"...) + } else { + b = append(b, '\\') + if p.Valid() { + b = append(b, p.Read()) + } + } + default: + b = append(b, c) + } + } + + if hasFlag(flags, arrayFlag) { + b = append(b, '"') + } else if hasFlag(flags, quoteFlag) { + b = append(b, '\'') + } + + return b +} diff --git a/vendor/github.com/go-pg/pg/v10/types/append_value.go b/vendor/github.com/go-pg/pg/v10/types/append_value.go new file mode 100644 index 000000000..f12fc564f --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/append_value.go @@ -0,0 +1,248 @@ +package types + +import ( + "database/sql/driver" + "fmt" + "net" + "reflect" + "strconv" + "sync" + "time" + + "github.com/vmihailenco/bufpool" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/pgjson" +) + +var ( + driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + appenderType = reflect.TypeOf((*ValueAppender)(nil)).Elem() +) + +type AppenderFunc func([]byte, reflect.Value, int) []byte + +var appenders []AppenderFunc + +//nolint +func init() { + appenders = []AppenderFunc{ + reflect.Bool: appendBoolValue, + reflect.Int: appendIntValue, + reflect.Int8: appendIntValue, + reflect.Int16: appendIntValue, + reflect.Int32: appendIntValue, + reflect.Int64: appendIntValue, + reflect.Uint: appendUintValue, + reflect.Uint8: appendUintValue, + reflect.Uint16: appendUintValue, + reflect.Uint32: appendUintValue, + reflect.Uint64: appendUintValue, + reflect.Uintptr: nil, + reflect.Float32: appendFloat32Value, + reflect.Float64: appendFloat64Value, + reflect.Complex64: nil, + reflect.Complex128: nil, + reflect.Array: appendJSONValue, + reflect.Chan: nil, + reflect.Func: nil, + reflect.Interface: appendIfaceValue, + reflect.Map: appendJSONValue, + reflect.Ptr: nil, + reflect.Slice: appendJSONValue, + reflect.String: appendStringValue, + reflect.Struct: appendStructValue, + reflect.UnsafePointer: nil, + } +} + +var appendersMap sync.Map + +// RegisterAppender registers an appender func for the value type. +// Expecting to be used only during initialization, it panics +// if there is already a registered appender for the given type. +func RegisterAppender(value interface{}, fn AppenderFunc) { + registerAppender(reflect.TypeOf(value), fn) +} + +func registerAppender(typ reflect.Type, fn AppenderFunc) { + _, loaded := appendersMap.LoadOrStore(typ, fn) + if loaded { + err := fmt.Errorf("pg: appender for the type=%s is already registered", + typ.String()) + panic(err) + } +} + +func Appender(typ reflect.Type) AppenderFunc { + if v, ok := appendersMap.Load(typ); ok { + return v.(AppenderFunc) + } + fn := appender(typ, false) + _, _ = appendersMap.LoadOrStore(typ, fn) + return fn +} + +func appender(typ reflect.Type, pgArray bool) AppenderFunc { + switch typ { + case timeType: + return appendTimeValue + case ipType: + return appendIPValue + case ipNetType: + return appendIPNetValue + case jsonRawMessageType: + return appendJSONRawMessageValue + } + + if typ.Implements(appenderType) { + return appendAppenderValue + } + if typ.Implements(driverValuerType) { + return appendDriverValuerValue + } + + kind := typ.Kind() + switch kind { + case reflect.Ptr: + return ptrAppenderFunc(typ) + case reflect.Slice: + if typ.Elem().Kind() == reflect.Uint8 { + return appendBytesValue + } + if pgArray { + return ArrayAppender(typ) + } + case reflect.Array: + if typ.Elem().Kind() == reflect.Uint8 { + return appendArrayBytesValue + } + } + return appenders[kind] +} + +func ptrAppenderFunc(typ reflect.Type) AppenderFunc { + appender := Appender(typ.Elem()) + return func(b []byte, v reflect.Value, flags int) []byte { + if v.IsNil() { + return AppendNull(b, flags) + } + return appender(b, v.Elem(), flags) + } +} + +func appendValue(b []byte, v reflect.Value, flags int) []byte { + if v.Kind() == reflect.Ptr && v.IsNil() { + return AppendNull(b, flags) + } + appender := Appender(v.Type()) + return appender(b, v, flags) +} + +func appendIfaceValue(b []byte, v reflect.Value, flags int) []byte { + return Append(b, v.Interface(), flags) +} + +func appendBoolValue(b []byte, v reflect.Value, _ int) []byte { + return appendBool(b, v.Bool()) +} + +func appendIntValue(b []byte, v reflect.Value, _ int) []byte { + return strconv.AppendInt(b, v.Int(), 10) +} + +func appendUintValue(b []byte, v reflect.Value, _ int) []byte { + return strconv.AppendUint(b, v.Uint(), 10) +} + +func appendFloat32Value(b []byte, v reflect.Value, flags int) []byte { + return appendFloat(b, v.Float(), flags, 32) +} + +func appendFloat64Value(b []byte, v reflect.Value, flags int) []byte { + return appendFloat(b, v.Float(), flags, 64) +} + +func appendBytesValue(b []byte, v reflect.Value, flags int) []byte { + return AppendBytes(b, v.Bytes(), flags) +} + +func appendArrayBytesValue(b []byte, v reflect.Value, flags int) []byte { + if v.CanAddr() { + return AppendBytes(b, v.Slice(0, v.Len()).Bytes(), flags) + } + + buf := bufpool.Get(v.Len()) + + tmp := buf.Bytes() + reflect.Copy(reflect.ValueOf(tmp), v) + b = AppendBytes(b, tmp, flags) + + bufpool.Put(buf) + + return b +} + +func appendStringValue(b []byte, v reflect.Value, flags int) []byte { + return AppendString(b, v.String(), flags) +} + +func appendStructValue(b []byte, v reflect.Value, flags int) []byte { + if v.Type() == timeType { + return appendTimeValue(b, v, flags) + } + return appendJSONValue(b, v, flags) +} + +var jsonPool bufpool.Pool + +func appendJSONValue(b []byte, v reflect.Value, flags int) []byte { + buf := jsonPool.Get() + defer jsonPool.Put(buf) + + if err := pgjson.NewEncoder(buf).Encode(v.Interface()); err != nil { + return AppendError(b, err) + } + + bb := buf.Bytes() + if len(bb) > 0 && bb[len(bb)-1] == '\n' { + bb = bb[:len(bb)-1] + } + + return AppendJSONB(b, bb, flags) +} + +func appendTimeValue(b []byte, v reflect.Value, flags int) []byte { + tm := v.Interface().(time.Time) + return AppendTime(b, tm, flags) +} + +func appendIPValue(b []byte, v reflect.Value, flags int) []byte { + ip := v.Interface().(net.IP) + return AppendString(b, ip.String(), flags) +} + +func appendIPNetValue(b []byte, v reflect.Value, flags int) []byte { + ipnet := v.Interface().(net.IPNet) + return AppendString(b, ipnet.String(), flags) +} + +func appendJSONRawMessageValue(b []byte, v reflect.Value, flags int) []byte { + return AppendString(b, internal.BytesToString(v.Bytes()), flags) +} + +func appendAppenderValue(b []byte, v reflect.Value, flags int) []byte { + return appendAppender(b, v.Interface().(ValueAppender), flags) +} + +func appendDriverValuerValue(b []byte, v reflect.Value, flags int) []byte { + return appendDriverValuer(b, v.Interface().(driver.Valuer), flags) +} + +func appendDriverValuer(b []byte, v driver.Valuer, flags int) []byte { + value, err := v.Value() + if err != nil { + return AppendError(b, err) + } + return Append(b, value, flags) +} diff --git a/vendor/github.com/go-pg/pg/v10/types/array.go b/vendor/github.com/go-pg/pg/v10/types/array.go new file mode 100644 index 000000000..fb70c1f50 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/array.go @@ -0,0 +1,58 @@ +package types + +import ( + "fmt" + "reflect" +) + +type Array struct { + v reflect.Value + + append AppenderFunc + scan ScannerFunc +} + +var ( + _ ValueAppender = (*Array)(nil) + _ ValueScanner = (*Array)(nil) +) + +func NewArray(vi interface{}) *Array { + v := reflect.ValueOf(vi) + if !v.IsValid() { + panic(fmt.Errorf("pg: Array(nil)")) + } + + return &Array{ + v: v, + + append: ArrayAppender(v.Type()), + scan: ArrayScanner(v.Type()), + } +} + +func (a *Array) AppendValue(b []byte, flags int) ([]byte, error) { + if a.append == nil { + panic(fmt.Errorf("pg: Array(unsupported %s)", a.v.Type())) + } + return a.append(b, a.v, flags), nil +} + +func (a *Array) ScanValue(rd Reader, n int) error { + if a.scan == nil { + return fmt.Errorf("pg: Array(unsupported %s)", a.v.Type()) + } + + if a.v.Kind() != reflect.Ptr { + return fmt.Errorf("pg: Array(non-pointer %s)", a.v.Type()) + } + + return a.scan(a.v.Elem(), rd, n) +} + +func (a *Array) Value() interface{} { + if a.v.IsValid() { + return a.v.Interface() + } + return nil +} diff --git a/vendor/github.com/go-pg/pg/v10/types/array_append.go b/vendor/github.com/go-pg/pg/v10/types/array_append.go new file mode 100644 index 000000000..a4132eb61 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/array_append.go @@ -0,0 +1,236 @@ +package types + +import ( + "reflect" + "strconv" + "sync" +) + +var ( + stringType = reflect.TypeOf((*string)(nil)).Elem() + sliceStringType = reflect.TypeOf([]string(nil)) + + intType = reflect.TypeOf((*int)(nil)).Elem() + sliceIntType = reflect.TypeOf([]int(nil)) + + int64Type = reflect.TypeOf((*int64)(nil)).Elem() + sliceInt64Type = reflect.TypeOf([]int64(nil)) + + float64Type = reflect.TypeOf((*float64)(nil)).Elem() + sliceFloat64Type = reflect.TypeOf([]float64(nil)) +) + +var arrayAppendersMap sync.Map + +func ArrayAppender(typ reflect.Type) AppenderFunc { + if v, ok := arrayAppendersMap.Load(typ); ok { + return v.(AppenderFunc) + } + fn := arrayAppender(typ) + arrayAppendersMap.Store(typ, fn) + return fn +} + +func arrayAppender(typ reflect.Type) AppenderFunc { + kind := typ.Kind() + if kind == reflect.Ptr { + typ = typ.Elem() + kind = typ.Kind() + } + + switch kind { + case reflect.Slice, reflect.Array: + // ok: + default: + return nil + } + + elemType := typ.Elem() + + if kind == reflect.Slice { + switch elemType { + case stringType: + return appendSliceStringValue + case intType: + return appendSliceIntValue + case int64Type: + return appendSliceInt64Value + case float64Type: + return appendSliceFloat64Value + } + } + + appendElem := appender(elemType, true) + return func(b []byte, v reflect.Value, flags int) []byte { + flags |= arrayFlag + + kind := v.Kind() + switch kind { + case reflect.Ptr, reflect.Slice: + if v.IsNil() { + return AppendNull(b, flags) + } + } + + if kind == reflect.Ptr { + v = v.Elem() + } + + quote := shouldQuoteArray(flags) + if quote { + b = append(b, '\'') + } + + flags |= subArrayFlag + + b = append(b, '{') + for i := 0; i < v.Len(); i++ { + elem := v.Index(i) + b = appendElem(b, elem, flags) + b = append(b, ',') + } + if v.Len() > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + if quote { + b = append(b, '\'') + } + + return b + } +} + +func appendSliceStringValue(b []byte, v reflect.Value, flags int) []byte { + ss := v.Convert(sliceStringType).Interface().([]string) + return appendSliceString(b, ss, flags) +} + +func appendSliceString(b []byte, ss []string, flags int) []byte { + if ss == nil { + return AppendNull(b, flags) + } + + quote := shouldQuoteArray(flags) + if quote { + b = append(b, '\'') + } + + b = append(b, '{') + for _, s := range ss { + b = appendString2(b, s, flags) + b = append(b, ',') + } + if len(ss) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + if quote { + b = append(b, '\'') + } + + return b +} + +func appendSliceIntValue(b []byte, v reflect.Value, flags int) []byte { + ints := v.Convert(sliceIntType).Interface().([]int) + return appendSliceInt(b, ints, flags) +} + +func appendSliceInt(b []byte, ints []int, flags int) []byte { + if ints == nil { + return AppendNull(b, flags) + } + + quote := shouldQuoteArray(flags) + if quote { + b = append(b, '\'') + } + + b = append(b, '{') + for _, n := range ints { + b = strconv.AppendInt(b, int64(n), 10) + b = append(b, ',') + } + if len(ints) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + if quote { + b = append(b, '\'') + } + + return b +} + +func appendSliceInt64Value(b []byte, v reflect.Value, flags int) []byte { + ints := v.Convert(sliceInt64Type).Interface().([]int64) + return appendSliceInt64(b, ints, flags) +} + +func appendSliceInt64(b []byte, ints []int64, flags int) []byte { + if ints == nil { + return AppendNull(b, flags) + } + + quote := shouldQuoteArray(flags) + if quote { + b = append(b, '\'') + } + + b = append(b, '{') + for _, n := range ints { + b = strconv.AppendInt(b, n, 10) + b = append(b, ',') + } + if len(ints) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + if quote { + b = append(b, '\'') + } + + return b +} + +func appendSliceFloat64Value(b []byte, v reflect.Value, flags int) []byte { + floats := v.Convert(sliceFloat64Type).Interface().([]float64) + return appendSliceFloat64(b, floats, flags) +} + +func appendSliceFloat64(b []byte, floats []float64, flags int) []byte { + if floats == nil { + return AppendNull(b, flags) + } + + quote := shouldQuoteArray(flags) + if quote { + b = append(b, '\'') + } + + b = append(b, '{') + for _, n := range floats { + b = appendFloat2(b, n, flags) + b = append(b, ',') + } + if len(floats) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + if quote { + b = append(b, '\'') + } + + return b +} diff --git a/vendor/github.com/go-pg/pg/v10/types/array_parser.go b/vendor/github.com/go-pg/pg/v10/types/array_parser.go new file mode 100644 index 000000000..0870a6568 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/array_parser.go @@ -0,0 +1,170 @@ +package types + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + + "github.com/go-pg/pg/v10/internal/parser" +) + +var errEndOfArray = errors.New("pg: end of array") + +type arrayParser struct { + p parser.StreamingParser + + stickyErr error + buf []byte +} + +func newArrayParserErr(err error) *arrayParser { + return &arrayParser{ + stickyErr: err, + buf: make([]byte, 32), + } +} + +func newArrayParser(rd Reader) *arrayParser { + p := parser.NewStreamingParser(rd) + err := p.SkipByte('{') + if err != nil { + return newArrayParserErr(err) + } + return &arrayParser{ + p: p, + } +} + +func (p *arrayParser) NextElem() ([]byte, error) { + if p.stickyErr != nil { + return nil, p.stickyErr + } + + c, err := p.p.ReadByte() + if err != nil { + if err == io.EOF { + return nil, errEndOfArray + } + return nil, err + } + + switch c { + case '"': + b, err := p.p.ReadSubstring(p.buf[:0]) + if err != nil { + return nil, err + } + p.buf = b + + err = p.readCommaBrace() + if err != nil { + return nil, err + } + + return b, nil + case '{': + b, err := p.readSubArray(p.buf[:0]) + if err != nil { + return nil, err + } + p.buf = b + + err = p.readCommaBrace() + if err != nil { + return nil, err + } + + return b, nil + case '}': + return nil, errEndOfArray + default: + err = p.p.UnreadByte() + if err != nil { + return nil, err + } + + b, err := p.readSimple(p.buf[:0]) + if err != nil { + return nil, err + } + p.buf = b + + if bytes.Equal(b, []byte("NULL")) { + return nil, nil + } + return b, nil + } +} + +func (p *arrayParser) readSimple(b []byte) ([]byte, error) { + for { + tmp, err := p.p.ReadSlice(',') + if err == nil { + b = append(b, tmp...) + b = b[:len(b)-1] + break + } + b = append(b, tmp...) + if err == bufio.ErrBufferFull { + continue + } + if err == io.EOF { + if b[len(b)-1] == '}' { + b = b[:len(b)-1] + break + } + } + return nil, err + } + return b, nil +} + +func (p *arrayParser) readSubArray(b []byte) ([]byte, error) { + b = append(b, '{') + for { + c, err := p.p.ReadByte() + if err != nil { + return nil, err + } + + if c == '}' { + b = append(b, '}') + return b, nil + } + + if c == '"' { + b = append(b, '"') + for { + tmp, err := p.p.ReadSlice('"') + b = append(b, tmp...) + if err != nil { + if err == bufio.ErrBufferFull { + continue + } + return nil, err + } + if len(b) > 1 && b[len(b)-2] != '\\' { + break + } + } + continue + } + + b = append(b, c) + } +} + +func (p *arrayParser) readCommaBrace() error { + c, err := p.p.ReadByte() + if err != nil { + return err + } + switch c { + case ',', '}': + return nil + default: + return fmt.Errorf("pg: got %q, wanted ',' or '}'", c) + } +} diff --git a/vendor/github.com/go-pg/pg/v10/types/array_scan.go b/vendor/github.com/go-pg/pg/v10/types/array_scan.go new file mode 100644 index 000000000..dbccafc06 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/array_scan.go @@ -0,0 +1,334 @@ +package types + +import ( + "fmt" + "reflect" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/internal/pool" +) + +var arrayValueScannerType = reflect.TypeOf((*ArrayValueScanner)(nil)).Elem() + +type ArrayValueScanner interface { + BeforeScanArrayValue(rd Reader, n int) error + ScanArrayValue(rd Reader, n int) error + AfterScanArrayValue() error +} + +func ArrayScanner(typ reflect.Type) ScannerFunc { + if typ.Implements(arrayValueScannerType) { + return scanArrayValueScannerValue + } + + kind := typ.Kind() + if kind == reflect.Ptr { + typ = typ.Elem() + kind = typ.Kind() + } + + switch kind { + case reflect.Slice, reflect.Array: + // ok: + default: + return nil + } + + elemType := typ.Elem() + + if kind == reflect.Slice { + switch elemType { + case stringType: + return scanStringArrayValue + case intType: + return scanIntArrayValue + case int64Type: + return scanInt64ArrayValue + case float64Type: + return scanFloat64ArrayValue + } + } + + scanElem := scanner(elemType, true) + return func(v reflect.Value, rd Reader, n int) error { + v = reflect.Indirect(v) + if !v.CanSet() { + return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) + } + + kind := v.Kind() + + if n == -1 { + if kind != reflect.Slice || !v.IsNil() { + v.Set(reflect.Zero(v.Type())) + } + return nil + } + + if kind == reflect.Slice { + if v.IsNil() { + v.Set(reflect.MakeSlice(v.Type(), 0, 0)) + } else if v.Len() > 0 { + v.Set(v.Slice(0, 0)) + } + } + + p := newArrayParser(rd) + nextValue := internal.MakeSliceNextElemFunc(v) + var elemRd *pool.BytesReader + + for { + elem, err := p.NextElem() + if err != nil { + if err == errEndOfArray { + break + } + return err + } + + if elemRd == nil { + elemRd = pool.NewBytesReader(elem) + } else { + elemRd.Reset(elem) + } + + var elemN int + if elem == nil { + elemN = -1 + } else { + elemN = len(elem) + } + + elemValue := nextValue() + err = scanElem(elemValue, elemRd, elemN) + if err != nil { + return err + } + } + + return nil + } +} + +func scanStringArrayValue(v reflect.Value, rd Reader, n int) error { + v = reflect.Indirect(v) + if !v.CanSet() { + return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) + } + + strings, err := scanStringArray(rd, n) + if err != nil { + return err + } + + v.Set(reflect.ValueOf(strings)) + return nil +} + +func scanStringArray(rd Reader, n int) ([]string, error) { + if n == -1 { + return nil, nil + } + + p := newArrayParser(rd) + slice := make([]string, 0) + for { + elem, err := p.NextElem() + if err != nil { + if err == errEndOfArray { + break + } + return nil, err + } + + slice = append(slice, string(elem)) + } + + return slice, nil +} + +func scanIntArrayValue(v reflect.Value, rd Reader, n int) error { + v = reflect.Indirect(v) + if !v.CanSet() { + return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) + } + + slice, err := decodeSliceInt(rd, n) + if err != nil { + return err + } + + v.Set(reflect.ValueOf(slice)) + return nil +} + +func decodeSliceInt(rd Reader, n int) ([]int, error) { + if n == -1 { + return nil, nil + } + + p := newArrayParser(rd) + slice := make([]int, 0) + for { + elem, err := p.NextElem() + if err != nil { + if err == errEndOfArray { + break + } + return nil, err + } + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := internal.Atoi(elem) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + + return slice, nil +} + +func scanInt64ArrayValue(v reflect.Value, rd Reader, n int) error { + v = reflect.Indirect(v) + if !v.CanSet() { + return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) + } + + slice, err := scanInt64Array(rd, n) + if err != nil { + return err + } + + v.Set(reflect.ValueOf(slice)) + return nil +} + +func scanInt64Array(rd Reader, n int) ([]int64, error) { + if n == -1 { + return nil, nil + } + + p := newArrayParser(rd) + slice := make([]int64, 0) + for { + elem, err := p.NextElem() + if err != nil { + if err == errEndOfArray { + break + } + return nil, err + } + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := internal.ParseInt(elem, 10, 64) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + + return slice, nil +} + +func scanFloat64ArrayValue(v reflect.Value, rd Reader, n int) error { + v = reflect.Indirect(v) + if !v.CanSet() { + return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) + } + + slice, err := scanFloat64Array(rd, n) + if err != nil { + return err + } + + v.Set(reflect.ValueOf(slice)) + return nil +} + +func scanFloat64Array(rd Reader, n int) ([]float64, error) { + if n == -1 { + return nil, nil + } + + p := newArrayParser(rd) + slice := make([]float64, 0) + for { + elem, err := p.NextElem() + if err != nil { + if err == errEndOfArray { + break + } + return nil, err + } + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := internal.ParseFloat(elem, 64) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + + return slice, nil +} + +func scanArrayValueScannerValue(v reflect.Value, rd Reader, n int) error { + if n == -1 { + return nil + } + + scanner := v.Addr().Interface().(ArrayValueScanner) + + err := scanner.BeforeScanArrayValue(rd, n) + if err != nil { + return err + } + + p := newArrayParser(rd) + var elemRd *pool.BytesReader + for { + elem, err := p.NextElem() + if err != nil { + if err == errEndOfArray { + break + } + return err + } + + if elemRd == nil { + elemRd = pool.NewBytesReader(elem) + } else { + elemRd.Reset(elem) + } + + var elemN int + if elem == nil { + elemN = -1 + } else { + elemN = len(elem) + } + + err = scanner.ScanArrayValue(elemRd, elemN) + if err != nil { + return err + } + } + + return scanner.AfterScanArrayValue() +} diff --git a/vendor/github.com/go-pg/pg/v10/types/column.go b/vendor/github.com/go-pg/pg/v10/types/column.go new file mode 100644 index 000000000..e3470f3eb --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/column.go @@ -0,0 +1,113 @@ +package types + +import ( + "encoding/json" + + "github.com/go-pg/pg/v10/internal/pool" + "github.com/go-pg/pg/v10/pgjson" +) + +const ( + pgBool = 16 + + pgInt2 = 21 + pgInt4 = 23 + pgInt8 = 20 + + pgFloat4 = 700 + pgFloat8 = 701 + + pgText = 25 + pgVarchar = 1043 + pgBytea = 17 + pgJSON = 114 + pgJSONB = 3802 + + pgTimestamp = 1114 + pgTimestamptz = 1184 + + // pgInt2Array = 1005 + pgInt32Array = 1007 + pgInt8Array = 1016 + pgFloat8Array = 1022 + pgStringArray = 1009 + + pgUUID = 2950 +) + +type ColumnInfo = pool.ColumnInfo + +type RawValue struct { + Type int32 + Value string +} + +func (v RawValue) AppendValue(b []byte, flags int) ([]byte, error) { + return AppendString(b, v.Value, flags), nil +} + +func (v RawValue) MarshalJSON() ([]byte, error) { + return pgjson.Marshal(v.Value) +} + +func ReadColumnValue(col ColumnInfo, rd Reader, n int) (interface{}, error) { + switch col.DataType { + case pgBool: + return ScanBool(rd, n) + + case pgInt2: + n, err := scanInt64(rd, n, 16) + if err != nil { + return nil, err + } + return int16(n), nil + case pgInt4: + n, err := scanInt64(rd, n, 32) + if err != nil { + return nil, err + } + return int32(n), nil + case pgInt8: + return ScanInt64(rd, n) + + case pgFloat4: + return ScanFloat32(rd, n) + case pgFloat8: + return ScanFloat64(rd, n) + + case pgBytea: + return ScanBytes(rd, n) + case pgText, pgVarchar, pgUUID: + return ScanString(rd, n) + case pgJSON, pgJSONB: + s, err := ScanString(rd, n) + if err != nil { + return nil, err + } + return json.RawMessage(s), nil + + case pgTimestamp: + return ScanTime(rd, n) + case pgTimestamptz: + return ScanTime(rd, n) + + case pgInt32Array: + return scanInt64Array(rd, n) + case pgInt8Array: + return scanInt64Array(rd, n) + case pgFloat8Array: + return scanFloat64Array(rd, n) + case pgStringArray: + return scanStringArray(rd, n) + + default: + s, err := ScanString(rd, n) + if err != nil { + return nil, err + } + return RawValue{ + Type: col.DataType, + Value: s, + }, nil + } +} diff --git a/vendor/github.com/go-pg/pg/v10/types/doc.go b/vendor/github.com/go-pg/pg/v10/types/doc.go new file mode 100644 index 000000000..890ef3c08 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/doc.go @@ -0,0 +1,4 @@ +/* +The API in this package is not stable and may change without any notice. +*/ +package types diff --git a/vendor/github.com/go-pg/pg/v10/types/flags.go b/vendor/github.com/go-pg/pg/v10/types/flags.go new file mode 100644 index 000000000..10e415f14 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/flags.go @@ -0,0 +1,25 @@ +package types + +import "reflect" + +const ( + quoteFlag = 1 << iota + arrayFlag + subArrayFlag +) + +func hasFlag(flags, flag int) bool { + return flags&flag == flag +} + +func shouldQuoteArray(flags int) bool { + return hasFlag(flags, quoteFlag) && !hasFlag(flags, subArrayFlag) +} + +func nilable(v reflect.Value) bool { + switch v.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return true + } + return false +} diff --git a/vendor/github.com/go-pg/pg/v10/types/hex.go b/vendor/github.com/go-pg/pg/v10/types/hex.go new file mode 100644 index 000000000..8ae6469b9 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/hex.go @@ -0,0 +1,81 @@ +package types + +import ( + "bytes" + "encoding/hex" + "fmt" + "io" + + fasthex "github.com/tmthrgd/go-hex" +) + +type HexEncoder struct { + b []byte + flags int + written bool +} + +func NewHexEncoder(b []byte, flags int) *HexEncoder { + return &HexEncoder{ + b: b, + flags: flags, + } +} + +func (enc *HexEncoder) Bytes() []byte { + return enc.b +} + +func (enc *HexEncoder) Write(b []byte) (int, error) { + if !enc.written { + if hasFlag(enc.flags, arrayFlag) { + enc.b = append(enc.b, `"\`...) + } else if hasFlag(enc.flags, quoteFlag) { + enc.b = append(enc.b, '\'') + } + enc.b = append(enc.b, `\x`...) + enc.written = true + } + + i := len(enc.b) + enc.b = append(enc.b, make([]byte, fasthex.EncodedLen(len(b)))...) + fasthex.Encode(enc.b[i:], b) + + return len(b), nil +} + +func (enc *HexEncoder) Close() error { + if enc.written { + if hasFlag(enc.flags, arrayFlag) { + enc.b = append(enc.b, '"') + } else if hasFlag(enc.flags, quoteFlag) { + enc.b = append(enc.b, '\'') + } + } else { + enc.b = AppendNull(enc.b, enc.flags) + } + return nil +} + +//------------------------------------------------------------------------------ + +func NewHexDecoder(rd Reader, n int) (io.Reader, error) { + if n <= 0 { + var rd bytes.Reader + return &rd, nil + } + + if c, err := rd.ReadByte(); err != nil { + return nil, err + } else if c != '\\' { + return nil, fmt.Errorf("got %q, wanted %q", c, '\\') + } + + if c, err := rd.ReadByte(); err != nil { + return nil, err + } else if c != 'x' { + return nil, fmt.Errorf("got %q, wanted %q", c, 'x') + } + + return hex.NewDecoder(rd), nil +} diff --git a/vendor/github.com/go-pg/pg/v10/types/hstore.go b/vendor/github.com/go-pg/pg/v10/types/hstore.go new file mode 100644 index 000000000..58c214ac6 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/hstore.go @@ -0,0 +1,59 @@ +package types + +import ( + "fmt" + "reflect" +) + +type Hstore struct { + v reflect.Value + + append AppenderFunc + scan ScannerFunc +} + +var ( + _ ValueAppender = (*Hstore)(nil) + _ ValueScanner = (*Hstore)(nil) +) + +func NewHstore(vi interface{}) *Hstore { + v := reflect.ValueOf(vi) + if !v.IsValid() { + panic(fmt.Errorf("pg.Hstore(nil)")) + } + + typ := v.Type() + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + if typ.Kind() != reflect.Map { + panic(fmt.Errorf("pg.Hstore(unsupported %s)", typ)) + } + + return &Hstore{ + v: v, + + append: HstoreAppender(typ), + scan: HstoreScanner(typ), + } +} + +func (h *Hstore) Value() interface{} { + if h.v.IsValid() { + return h.v.Interface() + } + return nil +} + +func (h *Hstore) AppendValue(b []byte, flags int) ([]byte, error) { + return h.append(b, h.v, flags), nil +} + +func (h *Hstore) ScanValue(rd Reader, n int) error { + if h.v.Kind() != reflect.Ptr { + return fmt.Errorf("pg: Hstore(non-pointer %s)", h.v.Type()) + } + + return h.scan(h.v.Elem(), rd, n) +} diff --git a/vendor/github.com/go-pg/pg/v10/types/hstore_append.go b/vendor/github.com/go-pg/pg/v10/types/hstore_append.go new file mode 100644 index 000000000..e27292afa --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/hstore_append.go @@ -0,0 +1,50 @@ +package types + +import ( + "fmt" + "reflect" +) + +var mapStringStringType = reflect.TypeOf(map[string]string(nil)) + +func HstoreAppender(typ reflect.Type) AppenderFunc { + if typ.Key() == stringType && typ.Elem() == stringType { + return appendMapStringStringValue + } + + return func(b []byte, v reflect.Value, flags int) []byte { + err := fmt.Errorf("pg.Hstore(unsupported %s)", v.Type()) + return AppendError(b, err) + } +} + +func appendMapStringString(b []byte, m map[string]string, flags int) []byte { + if m == nil { + return AppendNull(b, flags) + } + + if hasFlag(flags, quoteFlag) { + b = append(b, '\'') + } + + for key, value := range m { + b = appendString2(b, key, flags) + b = append(b, '=', '>') + b = appendString2(b, value, flags) + b = append(b, ',') + } + if len(m) > 0 { + b = b[:len(b)-1] // Strip trailing comma. + } + + if hasFlag(flags, quoteFlag) { + b = append(b, '\'') + } + + return b +} + +func appendMapStringStringValue(b []byte, v reflect.Value, flags int) []byte { + m := v.Convert(mapStringStringType).Interface().(map[string]string) + return appendMapStringString(b, m, flags) +} diff --git a/vendor/github.com/go-pg/pg/v10/types/hstore_parser.go b/vendor/github.com/go-pg/pg/v10/types/hstore_parser.go new file mode 100644 index 000000000..79cd41eda --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/hstore_parser.go @@ -0,0 +1,65 @@ +package types + +import ( + "errors" + "io" + + "github.com/go-pg/pg/v10/internal/parser" +) + +var errEndOfHstore = errors.New("pg: end of hstore") + +type hstoreParser struct { + p parser.StreamingParser +} + +func newHstoreParser(rd Reader) *hstoreParser { + return &hstoreParser{ + p: parser.NewStreamingParser(rd), + } +} + +func (p *hstoreParser) NextKey() ([]byte, error) { + err := p.p.SkipByte('"') + if err != nil { + if err == io.EOF { + return nil, errEndOfHstore + } + return nil, err + } + + key, err := p.p.ReadSubstring(nil) + if err != nil { + return nil, err + } + + err = p.p.SkipByte('=') + if err != nil { + return nil, err + } + err = p.p.SkipByte('>') + if err != nil { + return nil, err + } + + return key, nil +} + +func (p *hstoreParser) NextValue() ([]byte, error) { + err := p.p.SkipByte('"') + if err != nil { + return nil, err + } + + value, err := p.p.ReadSubstring(nil) + if err != nil { + return nil, err + } + + err = p.p.SkipByte(',') + if err == nil { + _ = p.p.SkipByte(' ') + } + + return value, nil +} diff --git a/vendor/github.com/go-pg/pg/v10/types/hstore_scan.go b/vendor/github.com/go-pg/pg/v10/types/hstore_scan.go new file mode 100644 index 000000000..2061c6163 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/hstore_scan.go @@ -0,0 +1,51 @@ +package types + +import ( + "fmt" + "reflect" +) + +func HstoreScanner(typ reflect.Type) ScannerFunc { + if typ.Key() == stringType && typ.Elem() == stringType { + return scanMapStringStringValue + } + return func(v reflect.Value, rd Reader, n int) error { + return fmt.Errorf("pg.Hstore(unsupported %s)", v.Type()) + } +} + +func scanMapStringStringValue(v reflect.Value, rd Reader, n int) error { + m, err := scanMapStringString(rd, n) + if err != nil { + return err + } + + v.Set(reflect.ValueOf(m)) + return nil +} + +func scanMapStringString(rd Reader, n int) (map[string]string, error) { + if n == -1 { + return nil, nil + } + + p := newHstoreParser(rd) + m := make(map[string]string) + for { + key, err := p.NextKey() + if err != nil { + if err == errEndOfHstore { + break + } + return nil, err + } + + value, err := p.NextValue() + if err != nil { + return nil, err + } + + m[string(key)] = string(value) + } + return m, nil +} diff --git a/vendor/github.com/go-pg/pg/v10/types/in_op.go b/vendor/github.com/go-pg/pg/v10/types/in_op.go new file mode 100644 index 000000000..472b986d8 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/in_op.go @@ -0,0 +1,62 @@ +package types + +import ( + "fmt" + "reflect" +) + +type inOp struct { + slice reflect.Value + stickyErr error +} + +var _ ValueAppender = (*inOp)(nil) + +func InMulti(values ...interface{}) ValueAppender { + return &inOp{ + slice: reflect.ValueOf(values), + } +} + +func In(slice interface{}) ValueAppender { + v := reflect.ValueOf(slice) + if v.Kind() != reflect.Slice { + return &inOp{ + stickyErr: fmt.Errorf("pg: In(non-slice %T)", slice), + } + } + + return &inOp{ + slice: v, + } +} + +func (in *inOp) AppendValue(b []byte, flags int) ([]byte, error) { + if in.stickyErr != nil { + return nil, in.stickyErr + } + return appendIn(b, in.slice, flags), nil +} + +func appendIn(b []byte, slice reflect.Value, flags int) []byte { + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, ',') + } + + elem := slice.Index(i) + if elem.Kind() == reflect.Interface { + elem = elem.Elem() + } + + if elem.Kind() == reflect.Slice { + b = append(b, '(') + b = appendIn(b, elem, flags) + b = append(b, ')') + } else { + b = appendValue(b, elem, flags) + } + } + return b +} diff --git a/vendor/github.com/go-pg/pg/v10/types/null_time.go b/vendor/github.com/go-pg/pg/v10/types/null_time.go new file mode 100644 index 000000000..3c3f1f79a --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/null_time.go @@ -0,0 +1,58 @@ +package types + +import ( + "bytes" + "database/sql" + "encoding/json" + "time" +) + +var jsonNull = []byte("null") + +// NullTime is a time.Time wrapper that marshals zero time as JSON null and +// PostgreSQL NULL. +type NullTime struct { + time.Time +} + +var ( + _ json.Marshaler = (*NullTime)(nil) + _ json.Unmarshaler = (*NullTime)(nil) + _ sql.Scanner = (*NullTime)(nil) + _ ValueAppender = (*NullTime)(nil) +) + +func (tm NullTime) MarshalJSON() ([]byte, error) { + if tm.IsZero() { + return jsonNull, nil + } + return tm.Time.MarshalJSON() +} + +func (tm *NullTime) UnmarshalJSON(b []byte) error { + if bytes.Equal(b, jsonNull) { + tm.Time = time.Time{} + return nil + } + return tm.Time.UnmarshalJSON(b) +} + +func (tm NullTime) AppendValue(b []byte, flags int) ([]byte, error) { + if tm.IsZero() { + return AppendNull(b, flags), nil + } + return AppendTime(b, tm.Time, flags), nil +} + +func (tm *NullTime) Scan(b interface{}) error { + if b == nil { + tm.Time = time.Time{} + return nil + } + newtm, err := ParseTime(b.([]byte)) + if err != nil { + return err + } + tm.Time = newtm + return nil +} diff --git a/vendor/github.com/go-pg/pg/v10/types/scan.go b/vendor/github.com/go-pg/pg/v10/types/scan.go new file mode 100644 index 000000000..2e9c0cc85 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/scan.go @@ -0,0 +1,244 @@ +package types + +import ( + "errors" + "fmt" + "reflect" + "time" + + "github.com/tmthrgd/go-hex" + + "github.com/go-pg/pg/v10/internal" +) + +func Scan(v interface{}, rd Reader, n int) error { + var err error + switch v := v.(type) { + case *string: + *v, err = ScanString(rd, n) + return err + case *[]byte: + *v, err = ScanBytes(rd, n) + return err + case *int: + *v, err = ScanInt(rd, n) + return err + case *int64: + *v, err = ScanInt64(rd, n) + return err + case *float32: + *v, err = ScanFloat32(rd, n) + return err + case *float64: + *v, err = ScanFloat64(rd, n) + return err + case *time.Time: + *v, err = ScanTime(rd, n) + return err + } + + vv := reflect.ValueOf(v) + if !vv.IsValid() { + return errors.New("pg: Scan(nil)") + } + + if vv.Kind() != reflect.Ptr { + return fmt.Errorf("pg: Scan(non-pointer %T)", v) + } + if vv.IsNil() { + return fmt.Errorf("pg: Scan(non-settable %T)", v) + } + + vv = vv.Elem() + if vv.Kind() == reflect.Interface { + if vv.IsNil() { + return errors.New("pg: Scan(nil)") + } + + vv = vv.Elem() + if vv.Kind() != reflect.Ptr { + return fmt.Errorf("pg: Decode(non-pointer %s)", vv.Type().String()) + } + } + + return ScanValue(vv, rd, n) +} + +func ScanString(rd Reader, n int) (string, error) { + if n <= 0 { + return "", nil + } + + b, err := rd.ReadFull() + if err != nil { + return "", err + } + + return internal.BytesToString(b), nil +} + +func ScanBytes(rd Reader, n int) ([]byte, error) { + if n == -1 { + return nil, nil + } + if n == 0 { + return []byte{}, nil + } + + b := make([]byte, hex.DecodedLen(n-2)) + if err := ReadBytes(rd, b); err != nil { + return nil, err + } + return b, nil +} + +func ReadBytes(rd Reader, b []byte) error { + tmp, err := rd.ReadFullTemp() + if err != nil { + return err + } + + if len(tmp) < 2 { + return fmt.Errorf("pg: can't parse bytea: %q", tmp) + } + + if tmp[0] != '\\' || tmp[1] != 'x' { + return fmt.Errorf("pg: can't parse bytea: %q", tmp) + } + tmp = tmp[2:] // Trim off "\\x". + + if len(b) != hex.DecodedLen(len(tmp)) { + return fmt.Errorf("pg: too small buf to decode hex") + } + + if _, err := hex.Decode(b, tmp); err != nil { + return err + } + + return nil +} + +func ScanInt(rd Reader, n int) (int, error) { + if n <= 0 { + return 0, nil + } + + tmp, err := rd.ReadFullTemp() + if err != nil { + return 0, err + } + + num, err := internal.Atoi(tmp) + if err != nil { + return 0, err + } + + return num, nil +} + +func ScanInt64(rd Reader, n int) (int64, error) { + return scanInt64(rd, n, 64) +} + +func scanInt64(rd Reader, n int, bitSize int) (int64, error) { + if n <= 0 { + return 0, nil + } + + tmp, err := rd.ReadFullTemp() + if err != nil { + return 0, err + } + + num, err := internal.ParseInt(tmp, 10, bitSize) + if err != nil { + return 0, err + } + + return num, nil +} + +func ScanUint64(rd Reader, n int) (uint64, error) { + if n <= 0 { + return 0, nil + } + + tmp, err := rd.ReadFullTemp() + if err != nil { + return 0, err + } + + // PostgreSQL does not natively support uint64 - only int64. + // Be nice and accept negative int64. + if len(tmp) > 0 && tmp[0] == '-' { + num, err := internal.ParseInt(tmp, 10, 64) + if err != nil { + return 0, err + } + return uint64(num), nil + } + + num, err := internal.ParseUint(tmp, 10, 64) + if err != nil { + return 0, err + } + + return num, nil +} + +func ScanFloat32(rd Reader, n int) (float32, error) { + if n <= 0 { + return 0, nil + } + + tmp, err := rd.ReadFullTemp() + if err != nil { + return 0, err + } + + num, err := internal.ParseFloat(tmp, 32) + if err != nil { + return 0, err + } + + return float32(num), nil +} + +func ScanFloat64(rd Reader, n int) (float64, error) { + if n <= 0 { + return 0, nil + } + + tmp, err := rd.ReadFullTemp() + if err != nil { + return 0, err + } + + num, err := internal.ParseFloat(tmp, 64) + if err != nil { + return 0, err + } + + return num, nil +} + +func ScanTime(rd Reader, n int) (time.Time, error) { + if n <= 0 { + return time.Time{}, nil + } + + tmp, err := rd.ReadFullTemp() + if err != nil { + return time.Time{}, err + } + + return ParseTime(tmp) +} + +func ScanBool(rd Reader, n int) (bool, error) { + tmp, err := rd.ReadFullTemp() + if err != nil { + return false, err + } + return len(tmp) == 1 && (tmp[0] == 't' || tmp[0] == '1'), nil +} diff --git a/vendor/github.com/go-pg/pg/v10/types/scan_value.go b/vendor/github.com/go-pg/pg/v10/types/scan_value.go new file mode 100644 index 000000000..9f5a7bb6e --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/scan_value.go @@ -0,0 +1,418 @@ +package types + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + "net" + "reflect" + "sync" + "time" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/pgjson" +) + +var ( + valueScannerType = reflect.TypeOf((*ValueScanner)(nil)).Elem() + sqlScannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + ipType = reflect.TypeOf((*net.IP)(nil)).Elem() + ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() + jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() +) + +type ScannerFunc func(reflect.Value, Reader, int) error + +var valueScanners []ScannerFunc + +//nolint +func init() { + valueScanners = []ScannerFunc{ + reflect.Bool: scanBoolValue, + reflect.Int: scanInt64Value, + reflect.Int8: scanInt64Value, + reflect.Int16: scanInt64Value, + reflect.Int32: scanInt64Value, + reflect.Int64: scanInt64Value, + reflect.Uint: scanUint64Value, + reflect.Uint8: scanUint64Value, + reflect.Uint16: scanUint64Value, + reflect.Uint32: scanUint64Value, + reflect.Uint64: scanUint64Value, + reflect.Uintptr: nil, + reflect.Float32: scanFloat32Value, + reflect.Float64: scanFloat64Value, + reflect.Complex64: nil, + reflect.Complex128: nil, + reflect.Array: scanJSONValue, + reflect.Chan: nil, + reflect.Func: nil, + reflect.Interface: scanIfaceValue, + reflect.Map: scanJSONValue, + reflect.Ptr: nil, + reflect.Slice: scanJSONValue, + reflect.String: scanStringValue, + reflect.Struct: scanJSONValue, + reflect.UnsafePointer: nil, + } +} + +var scannersMap sync.Map + +// RegisterScanner registers an scanner func for the type. +// Expecting to be used only during initialization, it panics +// if there is already a registered scanner for the given type. +func RegisterScanner(value interface{}, fn ScannerFunc) { + registerScanner(reflect.TypeOf(value), fn) +} + +func registerScanner(typ reflect.Type, fn ScannerFunc) { + _, loaded := scannersMap.LoadOrStore(typ, fn) + if loaded { + err := fmt.Errorf("pg: scanner for the type=%s is already registered", + typ.String()) + panic(err) + } +} + +func Scanner(typ reflect.Type) ScannerFunc { + if v, ok := scannersMap.Load(typ); ok { + return v.(ScannerFunc) + } + fn := scanner(typ, false) + _, _ = scannersMap.LoadOrStore(typ, fn) + return fn +} + +func scanner(typ reflect.Type, pgArray bool) ScannerFunc { + switch typ { + case timeType: + return scanTimeValue + case ipType: + return scanIPValue + case ipNetType: + return scanIPNetValue + case jsonRawMessageType: + return scanJSONRawMessageValue + } + + if typ.Implements(valueScannerType) { + return scanValueScannerValue + } + if reflect.PtrTo(typ).Implements(valueScannerType) { + return scanValueScannerAddrValue + } + + if typ.Implements(sqlScannerType) { + return scanSQLScannerValue + } + if reflect.PtrTo(typ).Implements(sqlScannerType) { + return scanSQLScannerAddrValue + } + + kind := typ.Kind() + switch kind { + case reflect.Ptr: + return ptrScannerFunc(typ) + case reflect.Slice: + if typ.Elem().Kind() == reflect.Uint8 { + return scanBytesValue + } + if pgArray { + return ArrayScanner(typ) + } + case reflect.Array: + if typ.Elem().Kind() == reflect.Uint8 { + return scanArrayBytesValue + } + } + return valueScanners[kind] +} + +func ptrScannerFunc(typ reflect.Type) ScannerFunc { + scanner := Scanner(typ.Elem()) + return func(v reflect.Value, rd Reader, n int) error { + if scanner == nil { + return fmt.Errorf("pg: Scan(unsupported %s)", v.Type()) + } + + if n == -1 { + if v.IsNil() { + return nil + } + if !v.CanSet() { + return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) + } + v.Set(reflect.Zero(v.Type())) + return nil + } + + if v.IsNil() { + if !v.CanSet() { + return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) + } + v.Set(reflect.New(v.Type().Elem())) + } + + return scanner(v.Elem(), rd, n) + } +} + +func scanIfaceValue(v reflect.Value, rd Reader, n int) error { + if v.IsNil() { + return scanJSONValue(v, rd, n) + } + return ScanValue(v.Elem(), rd, n) +} + +func ScanValue(v reflect.Value, rd Reader, n int) error { + if !v.IsValid() { + return errors.New("pg: Scan(nil)") + } + + scanner := Scanner(v.Type()) + if scanner != nil { + return scanner(v, rd, n) + } + + if v.Kind() == reflect.Interface { + return errors.New("pg: Scan(nil)") + } + return fmt.Errorf("pg: Scan(unsupported %s)", v.Type()) +} + +func scanBoolValue(v reflect.Value, rd Reader, n int) error { + if n == -1 { + v.SetBool(false) + return nil + } + + flag, err := ScanBool(rd, n) + if err != nil { + return err + } + + v.SetBool(flag) + return nil +} + +func scanInt64Value(v reflect.Value, rd Reader, n int) error { + num, err := ScanInt64(rd, n) + if err != nil { + return err + } + + v.SetInt(num) + return nil +} + +func scanUint64Value(v reflect.Value, rd Reader, n int) error { + num, err := ScanUint64(rd, n) + if err != nil { + return err + } + + v.SetUint(num) + return nil +} + +func scanFloat32Value(v reflect.Value, rd Reader, n int) error { + num, err := ScanFloat32(rd, n) + if err != nil { + return err + } + + v.SetFloat(float64(num)) + return nil +} + +func scanFloat64Value(v reflect.Value, rd Reader, n int) error { + num, err := ScanFloat64(rd, n) + if err != nil { + return err + } + + v.SetFloat(num) + return nil +} + +func scanStringValue(v reflect.Value, rd Reader, n int) error { + s, err := ScanString(rd, n) + if err != nil { + return err + } + + v.SetString(s) + return nil +} + +func scanJSONValue(v reflect.Value, rd Reader, n int) error { + // Zero value so it works with SelectOrInsert. + // TODO: better handle slices + v.Set(reflect.New(v.Type()).Elem()) + + if n == -1 { + return nil + } + + dec := pgjson.NewDecoder(rd) + return dec.Decode(v.Addr().Interface()) +} + +func scanTimeValue(v reflect.Value, rd Reader, n int) error { + tm, err := ScanTime(rd, n) + if err != nil { + return err + } + + ptr := v.Addr().Interface().(*time.Time) + *ptr = tm + + return nil +} + +func scanIPValue(v reflect.Value, rd Reader, n int) error { + if n == -1 { + return nil + } + + tmp, err := rd.ReadFullTemp() + if err != nil { + return err + } + + ip := net.ParseIP(internal.BytesToString(tmp)) + if ip == nil { + return fmt.Errorf("pg: invalid ip=%q", tmp) + } + + ptr := v.Addr().Interface().(*net.IP) + *ptr = ip + + return nil +} + +var zeroIPNetValue = reflect.ValueOf(net.IPNet{}) + +func scanIPNetValue(v reflect.Value, rd Reader, n int) error { + if n == -1 { + v.Set(zeroIPNetValue) + return nil + } + + tmp, err := rd.ReadFullTemp() + if err != nil { + return err + } + + _, ipnet, err := net.ParseCIDR(internal.BytesToString(tmp)) + if err != nil { + return err + } + + ptr := v.Addr().Interface().(*net.IPNet) + *ptr = *ipnet + + return nil +} + +func scanJSONRawMessageValue(v reflect.Value, rd Reader, n int) error { + if n == -1 { + v.SetBytes(nil) + return nil + } + + b, err := rd.ReadFull() + if err != nil { + return err + } + + v.SetBytes(b) + return nil +} + +func scanBytesValue(v reflect.Value, rd Reader, n int) error { + if n == -1 { + v.SetBytes(nil) + return nil + } + + b, err := ScanBytes(rd, n) + if err != nil { + return err + } + + v.SetBytes(b) + return nil +} + +func scanArrayBytesValue(v reflect.Value, rd Reader, n int) error { + b := v.Slice(0, v.Len()).Bytes() + + if n == -1 { + for i := range b { + b[i] = 0 + } + return nil + } + + return ReadBytes(rd, b) +} + +func scanValueScannerValue(v reflect.Value, rd Reader, n int) error { + if n == -1 { + if v.IsNil() { + return nil + } + return v.Interface().(ValueScanner).ScanValue(rd, n) + } + + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + + return v.Interface().(ValueScanner).ScanValue(rd, n) +} + +func scanValueScannerAddrValue(v reflect.Value, rd Reader, n int) error { + if !v.CanAddr() { + return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) + } + return v.Addr().Interface().(ValueScanner).ScanValue(rd, n) +} + +func scanSQLScannerValue(v reflect.Value, rd Reader, n int) error { + if n == -1 { + if nilable(v) && v.IsNil() { + return nil + } + return scanSQLScanner(v.Interface().(sql.Scanner), rd, n) + } + + if nilable(v) && v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + + return scanSQLScanner(v.Interface().(sql.Scanner), rd, n) +} + +func scanSQLScannerAddrValue(v reflect.Value, rd Reader, n int) error { + if !v.CanAddr() { + return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) + } + return scanSQLScanner(v.Addr().Interface().(sql.Scanner), rd, n) +} + +func scanSQLScanner(scanner sql.Scanner, rd Reader, n int) error { + if n == -1 { + return scanner.Scan(nil) + } + + tmp, err := rd.ReadFullTemp() + if err != nil { + return err + } + return scanner.Scan(tmp) +} diff --git a/vendor/github.com/go-pg/pg/v10/types/time.go b/vendor/github.com/go-pg/pg/v10/types/time.go new file mode 100644 index 000000000..e68a7a19a --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/time.go @@ -0,0 +1,56 @@ +package types + +import ( + "time" + + "github.com/go-pg/pg/v10/internal" +) + +const ( + dateFormat = "2006-01-02" + timeFormat = "15:04:05.999999999" + timestampFormat = "2006-01-02 15:04:05.999999999" + timestamptzFormat = "2006-01-02 15:04:05.999999999-07:00:00" + timestamptzFormat2 = "2006-01-02 15:04:05.999999999-07:00" + timestamptzFormat3 = "2006-01-02 15:04:05.999999999-07" +) + +func ParseTime(b []byte) (time.Time, error) { + s := internal.BytesToString(b) + return ParseTimeString(s) +} + +func ParseTimeString(s string) (time.Time, error) { + switch l := len(s); { + case l <= len(timeFormat): + if s[2] == ':' { + return time.ParseInLocation(timeFormat, s, time.UTC) + } + return time.ParseInLocation(dateFormat, s, time.UTC) + default: + if s[10] == 'T' { + return time.Parse(time.RFC3339Nano, s) + } + if c := s[l-9]; c == '+' || c == '-' { + return time.Parse(timestamptzFormat, s) + } + if c := s[l-6]; c == '+' || c == '-' { + return time.Parse(timestamptzFormat2, s) + } + if c := s[l-3]; c == '+' || c == '-' { + return time.Parse(timestamptzFormat3, s) + } + return time.ParseInLocation(timestampFormat, s, time.UTC) + } +} + +func AppendTime(b []byte, tm time.Time, flags int) []byte { + if flags == 1 { + b = append(b, '\'') + } + b = tm.UTC().AppendFormat(b, timestamptzFormat) + if flags == 1 { + b = append(b, '\'') + } + return b +} diff --git a/vendor/github.com/go-pg/pg/v10/types/types.go b/vendor/github.com/go-pg/pg/v10/types/types.go new file mode 100644 index 000000000..718ac2933 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/types/types.go @@ -0,0 +1,37 @@ +package types + +import ( + "github.com/go-pg/pg/v10/internal/pool" +) + +type Reader = pool.Reader + +type ValueScanner interface { + ScanValue(rd Reader, n int) error +} + +type ValueAppender interface { + AppendValue(b []byte, flags int) ([]byte, error) +} + +//------------------------------------------------------------------------------ + +// Safe represents a safe SQL query. +type Safe string + +var _ ValueAppender = (*Safe)(nil) + +func (q Safe) AppendValue(b []byte, flags int) ([]byte, error) { + return append(b, q...), nil +} + +//------------------------------------------------------------------------------ + +// Ident represents a SQL identifier, e.g. table or column name. +type Ident string + +var _ ValueAppender = (*Ident)(nil) + +func (f Ident) AppendValue(b []byte, flags int) ([]byte, error) { + return AppendIdent(b, string(f), flags), nil +} |