diff options
Diffstat (limited to 'vendor/github.com/go-pg/pg/v10/orm')
37 files changed, 8158 insertions, 0 deletions
diff --git a/vendor/github.com/go-pg/pg/v10/orm/composite.go b/vendor/github.com/go-pg/pg/v10/orm/composite.go new file mode 100644 index 000000000..d2e48a8b3 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/composite.go @@ -0,0 +1,100 @@ +package orm + +import ( + "fmt" + "reflect" + + "github.com/go-pg/pg/v10/internal/pool" + "github.com/go-pg/pg/v10/types" +) + +func compositeScanner(typ reflect.Type) types.ScannerFunc { + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + var table *Table + return func(v reflect.Value, rd types.Reader, n int) error { + if n == -1 { + v.Set(reflect.Zero(v.Type())) + return nil + } + + if table == nil { + table = GetTable(typ) + } + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + + p := newCompositeParser(rd) + var elemReader *pool.BytesReader + + var firstErr error + for i := 0; ; i++ { + elem, err := p.NextElem() + if err != nil { + if err == errEndOfComposite { + break + } + return err + } + + if i >= len(table.Fields) { + if firstErr == nil { + firstErr = fmt.Errorf( + "pg: %s has %d fields, but composite requires at least %d values", + table, len(table.Fields), i) + } + continue + } + + if elemReader == nil { + elemReader = pool.NewBytesReader(elem) + } else { + elemReader.Reset(elem) + } + + field := table.Fields[i] + if elem == nil { + err = field.ScanValue(v, elemReader, -1) + } else { + err = field.ScanValue(v, elemReader, len(elem)) + } + if err != nil && firstErr == nil { + firstErr = err + } + } + + return firstErr + } +} + +func compositeAppender(typ reflect.Type) types.AppenderFunc { + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + var table *Table + return func(b []byte, v reflect.Value, quote int) []byte { + if table == nil { + table = GetTable(typ) + } + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + b = append(b, "ROW("...) + for i, f := range table.Fields { + if i > 0 { + b = append(b, ',') + } + b = f.AppendValue(b, v, quote) + } + b = append(b, ')') + return b + } +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/composite_create.go b/vendor/github.com/go-pg/pg/v10/orm/composite_create.go new file mode 100644 index 000000000..fd60a94e4 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/composite_create.go @@ -0,0 +1,89 @@ +package orm + +import ( + "strconv" +) + +type CreateCompositeOptions struct { + Varchar int // replaces PostgreSQL data type `text` with `varchar(n)` +} + +type CreateCompositeQuery struct { + q *Query + opt *CreateCompositeOptions +} + +var ( + _ QueryAppender = (*CreateCompositeQuery)(nil) + _ QueryCommand = (*CreateCompositeQuery)(nil) +) + +func NewCreateCompositeQuery(q *Query, opt *CreateCompositeOptions) *CreateCompositeQuery { + return &CreateCompositeQuery{ + q: q, + opt: opt, + } +} + +func (q *CreateCompositeQuery) String() string { + b, err := q.AppendQuery(defaultFmter, nil) + if err != nil { + panic(err) + } + return string(b) +} + +func (q *CreateCompositeQuery) Operation() QueryOp { + return CreateCompositeOp +} + +func (q *CreateCompositeQuery) Clone() QueryCommand { + return &CreateCompositeQuery{ + q: q.q.Clone(), + opt: q.opt, + } +} + +func (q *CreateCompositeQuery) Query() *Query { + return q.q +} + +func (q *CreateCompositeQuery) AppendTemplate(b []byte) ([]byte, error) { + return q.AppendQuery(dummyFormatter{}, b) +} + +func (q *CreateCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + if q.q.stickyErr != nil { + return nil, q.q.stickyErr + } + if q.q.tableModel == nil { + return nil, errModelNil + } + + table := q.q.tableModel.Table() + + b = append(b, "CREATE TYPE "...) + b = append(b, table.Alias...) + b = append(b, " AS ("...) + + for i, field := range table.Fields { + if i > 0 { + b = append(b, ", "...) + } + + b = append(b, field.Column...) + b = append(b, " "...) + if field.UserSQLType == "" && q.opt != nil && q.opt.Varchar > 0 && + field.SQLType == "text" { + b = append(b, "varchar("...) + b = strconv.AppendInt(b, int64(q.opt.Varchar), 10) + b = append(b, ")"...) + } else { + b = append(b, field.SQLType...) + } + } + + b = append(b, ")"...) + + return b, q.q.stickyErr +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/composite_drop.go b/vendor/github.com/go-pg/pg/v10/orm/composite_drop.go new file mode 100644 index 000000000..2a169b07a --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/composite_drop.go @@ -0,0 +1,70 @@ +package orm + +type DropCompositeOptions struct { + IfExists bool + Cascade bool +} + +type DropCompositeQuery struct { + q *Query + opt *DropCompositeOptions +} + +var ( + _ QueryAppender = (*DropCompositeQuery)(nil) + _ QueryCommand = (*DropCompositeQuery)(nil) +) + +func NewDropCompositeQuery(q *Query, opt *DropCompositeOptions) *DropCompositeQuery { + return &DropCompositeQuery{ + q: q, + opt: opt, + } +} + +func (q *DropCompositeQuery) String() string { + b, err := q.AppendQuery(defaultFmter, nil) + if err != nil { + panic(err) + } + return string(b) +} + +func (q *DropCompositeQuery) Operation() QueryOp { + return DropCompositeOp +} + +func (q *DropCompositeQuery) Clone() QueryCommand { + return &DropCompositeQuery{ + q: q.q.Clone(), + opt: q.opt, + } +} + +func (q *DropCompositeQuery) Query() *Query { + return q.q +} + +func (q *DropCompositeQuery) AppendTemplate(b []byte) ([]byte, error) { + return q.AppendQuery(dummyFormatter{}, b) +} + +func (q *DropCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + if q.q.stickyErr != nil { + return nil, q.q.stickyErr + } + if q.q.tableModel == nil { + return nil, errModelNil + } + + b = append(b, "DROP TYPE "...) + if q.opt != nil && q.opt.IfExists { + b = append(b, "IF EXISTS "...) + } + b = append(b, q.q.tableModel.Table().Alias...) + if q.opt != nil && q.opt.Cascade { + b = append(b, " CASCADE"...) + } + + return b, q.q.stickyErr +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/composite_parser.go b/vendor/github.com/go-pg/pg/v10/orm/composite_parser.go new file mode 100644 index 000000000..29e500444 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/composite_parser.go @@ -0,0 +1,140 @@ +package orm + +import ( + "bufio" + "errors" + "fmt" + "io" + + "github.com/go-pg/pg/v10/internal/parser" + "github.com/go-pg/pg/v10/types" +) + +var errEndOfComposite = errors.New("pg: end of composite") + +type compositeParser struct { + p parser.StreamingParser + + stickyErr error +} + +func newCompositeParserErr(err error) *compositeParser { + return &compositeParser{ + stickyErr: err, + } +} + +func newCompositeParser(rd types.Reader) *compositeParser { + p := parser.NewStreamingParser(rd) + err := p.SkipByte('(') + if err != nil { + return newCompositeParserErr(err) + } + return &compositeParser{ + p: p, + } +} + +func (p *compositeParser) NextElem() ([]byte, error) { + if p.stickyErr != nil { + return nil, p.stickyErr + } + + c, err := p.p.ReadByte() + if err != nil { + if err == io.EOF { + return nil, errEndOfComposite + } + return nil, err + } + + switch c { + case '"': + return p.readQuoted() + case ',': + return nil, nil + case ')': + return nil, errEndOfComposite + default: + _ = p.p.UnreadByte() + } + + var b []byte + for { + tmp, err := p.p.ReadSlice(',') + if err == nil { + if b == nil { + b = tmp + } else { + b = append(b, tmp...) + } + b = b[:len(b)-1] + break + } + b = append(b, tmp...) + if err == bufio.ErrBufferFull { + continue + } + if err == io.EOF { + if b[len(b)-1] == ')' { + b = b[:len(b)-1] + break + } + } + return nil, err + } + + if len(b) == 0 { // NULL + return nil, nil + } + return b, nil +} + +func (p *compositeParser) readQuoted() ([]byte, error) { + var b []byte + + c, err := p.p.ReadByte() + if err != nil { + return nil, err + } + + for { + next, err := p.p.ReadByte() + if err != nil { + return nil, err + } + + if c == '\\' || c == '\'' { + if next == c { + b = append(b, c) + c, err = p.p.ReadByte() + if err != nil { + return nil, err + } + } else { + b = append(b, c) + c = next + } + continue + } + + if c == '"' { + switch next { + case '"': + b = append(b, '"') + c, err = p.p.ReadByte() + if err != nil { + return nil, err + } + case ',', ')': + return b, nil + default: + return nil, fmt.Errorf("pg: got %q, wanted ',' or ')'", c) + } + continue + } + + b = append(b, c) + c = next + } +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/count_estimate.go b/vendor/github.com/go-pg/pg/v10/orm/count_estimate.go new file mode 100644 index 000000000..bfa664a72 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/count_estimate.go @@ -0,0 +1,90 @@ +package orm + +import ( + "fmt" + + "github.com/go-pg/pg/v10/internal" +) + +// Placeholder that is replaced with count(*). +const placeholder = `'_go_pg_placeholder'` + +// https://wiki.postgresql.org/wiki/Count_estimate +//nolint +var pgCountEstimateFunc = fmt.Sprintf(` +CREATE OR REPLACE FUNCTION _go_pg_count_estimate_v2(query text, threshold int) +RETURNS int AS $$ +DECLARE + rec record; + nrows int; +BEGIN + FOR rec IN EXECUTE 'EXPLAIN ' || query LOOP + nrows := substring(rec."QUERY PLAN" FROM ' rows=(\d+)'); + EXIT WHEN nrows IS NOT NULL; + END LOOP; + + -- Return the estimation if there are too many rows. + IF nrows > threshold THEN + RETURN nrows; + END IF; + + -- Otherwise execute real count query. + query := replace(query, 'SELECT '%s'', 'SELECT count(*)'); + EXECUTE query INTO nrows; + + IF nrows IS NULL THEN + nrows := 0; + END IF; + + RETURN nrows; +END; +$$ LANGUAGE plpgsql; +`, placeholder) + +// CountEstimate uses EXPLAIN to get estimated number of rows returned the query. +// If that number is bigger than the threshold it returns the estimation. +// Otherwise it executes another query using count aggregate function and +// returns the result. +// +// Based on https://wiki.postgresql.org/wiki/Count_estimate +func (q *Query) CountEstimate(threshold int) (int, error) { + if q.stickyErr != nil { + return 0, q.stickyErr + } + + query, err := q.countSelectQuery(placeholder).AppendQuery(q.db.Formatter(), nil) + if err != nil { + return 0, err + } + + for i := 0; i < 3; i++ { + var count int + _, err = q.db.QueryOneContext( + q.ctx, + Scan(&count), + "SELECT _go_pg_count_estimate_v2(?, ?)", + string(query), threshold, + ) + if err != nil { + if pgerr, ok := err.(internal.PGError); ok && pgerr.Field('C') == "42883" { + // undefined_function + err = q.createCountEstimateFunc() + if err != nil { + pgerr, ok := err.(internal.PGError) + if !ok || !pgerr.IntegrityViolation() { + return 0, err + } + } + continue + } + } + return count, err + } + + return 0, err +} + +func (q *Query) createCountEstimateFunc() error { + _, err := q.db.ExecContext(q.ctx, pgCountEstimateFunc) + return err +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/delete.go b/vendor/github.com/go-pg/pg/v10/orm/delete.go new file mode 100644 index 000000000..c54cd10f8 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/delete.go @@ -0,0 +1,158 @@ +package orm + +import ( + "reflect" + + "github.com/go-pg/pg/v10/types" +) + +type DeleteQuery struct { + q *Query + placeholder bool +} + +var ( + _ QueryAppender = (*DeleteQuery)(nil) + _ QueryCommand = (*DeleteQuery)(nil) +) + +func NewDeleteQuery(q *Query) *DeleteQuery { + return &DeleteQuery{ + q: q, + } +} + +func (q *DeleteQuery) String() string { + b, err := q.AppendQuery(defaultFmter, nil) + if err != nil { + panic(err) + } + return string(b) +} + +func (q *DeleteQuery) Operation() QueryOp { + return DeleteOp +} + +func (q *DeleteQuery) Clone() QueryCommand { + return &DeleteQuery{ + q: q.q.Clone(), + placeholder: q.placeholder, + } +} + +func (q *DeleteQuery) Query() *Query { + return q.q +} + +func (q *DeleteQuery) AppendTemplate(b []byte) ([]byte, error) { + cp := q.Clone().(*DeleteQuery) + cp.placeholder = true + return cp.AppendQuery(dummyFormatter{}, b) +} + +func (q *DeleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if q.q.stickyErr != nil { + return nil, q.q.stickyErr + } + + if len(q.q.with) > 0 { + b, err = q.q.appendWith(fmter, b) + if err != nil { + return nil, err + } + } + + b = append(b, "DELETE FROM "...) + b, err = q.q.appendFirstTableWithAlias(fmter, b) + if err != nil { + return nil, err + } + + if q.q.hasMultiTables() { + b = append(b, " USING "...) + b, err = q.q.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + } + + b = append(b, " WHERE "...) + value := q.q.tableModel.Value() + + if q.q.isSliceModelWithData() { + if len(q.q.where) > 0 { + b, err = q.q.appendWhere(fmter, b) + if err != nil { + return nil, err + } + } else { + table := q.q.tableModel.Table() + err = table.checkPKs() + if err != nil { + return nil, err + } + + b = appendColumnAndSliceValue(fmter, b, value, table.Alias, table.PKs) + } + } else { + b, err = q.q.mustAppendWhere(fmter, b) + if err != nil { + return nil, err + } + } + + if len(q.q.returning) > 0 { + b, err = q.q.appendReturning(fmter, b) + if err != nil { + return nil, err + } + } + + return b, q.q.stickyErr +} + +func appendColumnAndSliceValue( + fmter QueryFormatter, b []byte, slice reflect.Value, alias types.Safe, fields []*Field, +) []byte { + if len(fields) > 1 { + b = append(b, '(') + } + b = appendColumns(b, alias, fields) + if len(fields) > 1 { + b = append(b, ')') + } + + b = append(b, " IN ("...) + + isPlaceholder := isTemplateFormatter(fmter) + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, ", "...) + } + + el := indirect(slice.Index(i)) + + if len(fields) > 1 { + b = append(b, '(') + } + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + if isPlaceholder { + b = append(b, '?') + } else { + b = f.AppendValue(b, el, 1) + } + } + if len(fields) > 1 { + b = append(b, ')') + } + } + + b = append(b, ')') + + return b +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/field.go b/vendor/github.com/go-pg/pg/v10/orm/field.go new file mode 100644 index 000000000..fe9b4abea --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/field.go @@ -0,0 +1,146 @@ +package orm + +import ( + "fmt" + "reflect" + + "github.com/go-pg/pg/v10/types" + "github.com/go-pg/zerochecker" +) + +const ( + PrimaryKeyFlag = uint8(1) << iota + ForeignKeyFlag + NotNullFlag + UseZeroFlag + UniqueFlag + ArrayFlag +) + +type Field struct { + Field reflect.StructField + Type reflect.Type + Index []int + + GoName string // struct field name, e.g. Id + SQLName string // SQL name, .e.g. id + Column types.Safe // escaped SQL name, e.g. "id" + SQLType string + UserSQLType string + Default types.Safe + OnDelete string + OnUpdate string + + flags uint8 + + append types.AppenderFunc + scan types.ScannerFunc + + isZero zerochecker.Func +} + +func indexEqual(ind1, ind2 []int) bool { + if len(ind1) != len(ind2) { + return false + } + for i, ind := range ind1 { + if ind != ind2[i] { + return false + } + } + return true +} + +func (f *Field) Clone() *Field { + cp := *f + cp.Index = cp.Index[:len(f.Index):len(f.Index)] + return &cp +} + +func (f *Field) setFlag(flag uint8) { + f.flags |= flag +} + +func (f *Field) hasFlag(flag uint8) bool { + return f.flags&flag != 0 +} + +func (f *Field) Value(strct reflect.Value) reflect.Value { + return fieldByIndexAlloc(strct, f.Index) +} + +func (f *Field) HasZeroValue(strct reflect.Value) bool { + return f.hasZeroValue(strct, f.Index) +} + +func (f *Field) hasZeroValue(v reflect.Value, index []int) bool { + for _, idx := range index { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return true + } + v = v.Elem() + } + v = v.Field(idx) + } + return f.isZero(v) +} + +func (f *Field) NullZero() bool { + return !f.hasFlag(UseZeroFlag) +} + +func (f *Field) AppendValue(b []byte, strct reflect.Value, quote int) []byte { + fv, ok := fieldByIndex(strct, f.Index) + if !ok { + return types.AppendNull(b, quote) + } + + if f.NullZero() && f.isZero(fv) { + return types.AppendNull(b, quote) + } + if f.append == nil { + panic(fmt.Errorf("pg: AppendValue(unsupported %s)", fv.Type())) + } + return f.append(b, fv, quote) +} + +func (f *Field) ScanValue(strct reflect.Value, rd types.Reader, n int) error { + if f.scan == nil { + return fmt.Errorf("pg: ScanValue(unsupported %s)", f.Type) + } + + var fv reflect.Value + if n == -1 { + var ok bool + fv, ok = fieldByIndex(strct, f.Index) + if !ok { + return nil + } + } else { + fv = fieldByIndexAlloc(strct, f.Index) + } + + return f.scan(fv, rd, n) +} + +type Method struct { + Index int + + flags int8 + + appender func([]byte, reflect.Value, int) []byte +} + +func (m *Method) Has(flag int8) bool { + return m.flags&flag != 0 +} + +func (m *Method) Value(strct reflect.Value) reflect.Value { + return strct.Method(m.Index).Call(nil)[0] +} + +func (m *Method) AppendValue(dst []byte, strct reflect.Value, quote int) []byte { + mv := m.Value(strct) + return m.appender(dst, mv, quote) +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/format.go b/vendor/github.com/go-pg/pg/v10/orm/format.go new file mode 100644 index 000000000..9945f6e1d --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/format.go @@ -0,0 +1,333 @@ +package orm + +import ( + "bytes" + "fmt" + "sort" + "strconv" + "strings" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/internal/parser" + "github.com/go-pg/pg/v10/types" +) + +var defaultFmter = NewFormatter() + +type queryWithSepAppender interface { + QueryAppender + AppendSep([]byte) []byte +} + +//------------------------------------------------------------------------------ + +type SafeQueryAppender struct { + query string + params []interface{} +} + +var ( + _ QueryAppender = (*SafeQueryAppender)(nil) + _ types.ValueAppender = (*SafeQueryAppender)(nil) +) + +//nolint +func SafeQuery(query string, params ...interface{}) *SafeQueryAppender { + return &SafeQueryAppender{query, params} +} + +func (q *SafeQueryAppender) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + return fmter.FormatQuery(b, q.query, q.params...), nil +} + +func (q *SafeQueryAppender) AppendValue(b []byte, quote int) ([]byte, error) { + return q.AppendQuery(defaultFmter, b) +} + +func (q *SafeQueryAppender) Value() types.Safe { + b, err := q.AppendValue(nil, 1) + if err != nil { + return types.Safe(err.Error()) + } + return types.Safe(internal.BytesToString(b)) +} + +//------------------------------------------------------------------------------ + +type condGroupAppender struct { + sep string + cond []queryWithSepAppender +} + +var ( + _ QueryAppender = (*condAppender)(nil) + _ queryWithSepAppender = (*condAppender)(nil) +) + +func (q *condGroupAppender) AppendSep(b []byte) []byte { + return append(b, q.sep...) +} + +func (q *condGroupAppender) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + b = append(b, '(') + for i, app := range q.cond { + if i > 0 { + b = app.AppendSep(b) + } + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + b = append(b, ')') + return b, nil +} + +//------------------------------------------------------------------------------ + +type condAppender struct { + sep string + cond string + params []interface{} +} + +var ( + _ QueryAppender = (*condAppender)(nil) + _ queryWithSepAppender = (*condAppender)(nil) +) + +func (q *condAppender) AppendSep(b []byte) []byte { + return append(b, q.sep...) +} + +func (q *condAppender) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + b = append(b, '(') + b = fmter.FormatQuery(b, q.cond, q.params...) + b = append(b, ')') + return b, nil +} + +//------------------------------------------------------------------------------ + +type fieldAppender struct { + field string +} + +var _ QueryAppender = (*fieldAppender)(nil) + +func (a fieldAppender) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + return types.AppendIdent(b, a.field, 1), nil +} + +//------------------------------------------------------------------------------ + +type dummyFormatter struct{} + +func (f dummyFormatter) FormatQuery(b []byte, query string, params ...interface{}) []byte { + return append(b, query...) +} + +func isTemplateFormatter(fmter QueryFormatter) bool { + _, ok := fmter.(dummyFormatter) + return ok +} + +//------------------------------------------------------------------------------ + +type QueryFormatter interface { + FormatQuery(b []byte, query string, params ...interface{}) []byte +} + +type Formatter struct { + namedParams map[string]interface{} + model TableModel +} + +var _ QueryFormatter = (*Formatter)(nil) + +func NewFormatter() *Formatter { + return new(Formatter) +} + +func (f *Formatter) String() string { + if len(f.namedParams) == 0 { + return "" + } + + keys := make([]string, len(f.namedParams)) + index := 0 + for k := range f.namedParams { + keys[index] = k + index++ + } + + sort.Strings(keys) + + ss := make([]string, len(keys)) + for i, k := range keys { + ss[i] = fmt.Sprintf("%s=%v", k, f.namedParams[k]) + } + return " " + strings.Join(ss, " ") +} + +func (f *Formatter) clone() *Formatter { + cp := NewFormatter() + + cp.model = f.model + if len(f.namedParams) > 0 { + cp.namedParams = make(map[string]interface{}, len(f.namedParams)) + } + for param, value := range f.namedParams { + cp.setParam(param, value) + } + + return cp +} + +func (f *Formatter) WithTableModel(model TableModel) *Formatter { + cp := f.clone() + cp.model = model + return cp +} + +func (f *Formatter) WithModel(model interface{}) *Formatter { + switch model := model.(type) { + case TableModel: + return f.WithTableModel(model) + case *Query: + return f.WithTableModel(model.tableModel) + case QueryCommand: + return f.WithTableModel(model.Query().tableModel) + default: + panic(fmt.Errorf("pg: unsupported model %T", model)) + } +} + +func (f *Formatter) setParam(param string, value interface{}) { + if f.namedParams == nil { + f.namedParams = make(map[string]interface{}) + } + f.namedParams[param] = value +} + +func (f *Formatter) WithParam(param string, value interface{}) *Formatter { + cp := f.clone() + cp.setParam(param, value) + return cp +} + +func (f *Formatter) Param(param string) interface{} { + return f.namedParams[param] +} + +func (f *Formatter) hasParams() bool { + return len(f.namedParams) > 0 || f.model != nil +} + +func (f *Formatter) FormatQueryBytes(dst, query []byte, params ...interface{}) []byte { + if (params == nil && !f.hasParams()) || bytes.IndexByte(query, '?') == -1 { + return append(dst, query...) + } + return f.append(dst, parser.New(query), params) +} + +func (f *Formatter) FormatQuery(dst []byte, query string, params ...interface{}) []byte { + if (params == nil && !f.hasParams()) || strings.IndexByte(query, '?') == -1 { + return append(dst, query...) + } + return f.append(dst, parser.NewString(query), params) +} + +func (f *Formatter) append(dst []byte, p *parser.Parser, params []interface{}) []byte { + var paramsIndex int + var namedParamsOnce bool + var tableParams *tableParams + + for p.Valid() { + b, ok := p.ReadSep('?') + if !ok { + dst = append(dst, b...) + continue + } + if len(b) > 0 && b[len(b)-1] == '\\' { + dst = append(dst, b[:len(b)-1]...) + dst = append(dst, '?') + continue + } + dst = append(dst, b...) + + id, numeric := p.ReadIdentifier() + if id != "" { + if numeric { + idx, err := strconv.Atoi(id) + if err != nil { + goto restore_param + } + + if idx >= len(params) { + goto restore_param + } + + dst = f.appendParam(dst, params[idx]) + continue + } + + if f.namedParams != nil { + param, paramOK := f.namedParams[id] + if paramOK { + dst = f.appendParam(dst, param) + continue + } + } + + if !namedParamsOnce && len(params) > 0 { + namedParamsOnce = true + tableParams, _ = newTableParams(params[len(params)-1]) + } + + if tableParams != nil { + dst, ok = tableParams.AppendParam(f, dst, id) + if ok { + continue + } + } + + if f.model != nil { + dst, ok = f.model.AppendParam(f, dst, id) + if ok { + continue + } + } + + restore_param: + dst = append(dst, '?') + dst = append(dst, id...) + continue + } + + if paramsIndex >= len(params) { + dst = append(dst, '?') + continue + } + + param := params[paramsIndex] + paramsIndex++ + + dst = f.appendParam(dst, param) + } + + return dst +} + +func (f *Formatter) appendParam(b []byte, param interface{}) []byte { + switch param := param.(type) { + case QueryAppender: + bb, err := param.AppendQuery(f, b) + if err != nil { + return types.AppendError(b, err) + } + return bb + default: + return types.Append(b, param, 1) + } +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/hook.go b/vendor/github.com/go-pg/pg/v10/orm/hook.go new file mode 100644 index 000000000..78bd10310 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/hook.go @@ -0,0 +1,248 @@ +package orm + +import ( + "context" + "reflect" +) + +type hookStubs struct{} + +var ( + _ AfterScanHook = (*hookStubs)(nil) + _ AfterSelectHook = (*hookStubs)(nil) + _ BeforeInsertHook = (*hookStubs)(nil) + _ AfterInsertHook = (*hookStubs)(nil) + _ BeforeUpdateHook = (*hookStubs)(nil) + _ AfterUpdateHook = (*hookStubs)(nil) + _ BeforeDeleteHook = (*hookStubs)(nil) + _ AfterDeleteHook = (*hookStubs)(nil) +) + +func (hookStubs) AfterScan(ctx context.Context) error { + return nil +} + +func (hookStubs) AfterSelect(ctx context.Context) error { + return nil +} + +func (hookStubs) BeforeInsert(ctx context.Context) (context.Context, error) { + return ctx, nil +} + +func (hookStubs) AfterInsert(ctx context.Context) error { + return nil +} + +func (hookStubs) BeforeUpdate(ctx context.Context) (context.Context, error) { + return ctx, nil +} + +func (hookStubs) AfterUpdate(ctx context.Context) error { + return nil +} + +func (hookStubs) BeforeDelete(ctx context.Context) (context.Context, error) { + return ctx, nil +} + +func (hookStubs) AfterDelete(ctx context.Context) error { + return nil +} + +func callHookSlice( + ctx context.Context, + slice reflect.Value, + ptr bool, + hook func(context.Context, reflect.Value) (context.Context, error), +) (context.Context, error) { + var firstErr error + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + v := slice.Index(i) + if !ptr { + v = v.Addr() + } + + var err error + ctx, err = hook(ctx, v) + if err != nil && firstErr == nil { + firstErr = err + } + } + return ctx, firstErr +} + +func callHookSlice2( + ctx context.Context, + slice reflect.Value, + ptr bool, + hook func(context.Context, reflect.Value) error, +) error { + var firstErr error + if slice.IsValid() { + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + v := slice.Index(i) + if !ptr { + v = v.Addr() + } + + err := hook(ctx, v) + if err != nil && firstErr == nil { + firstErr = err + } + } + } + return firstErr +} + +//------------------------------------------------------------------------------ + +type BeforeScanHook interface { + BeforeScan(context.Context) error +} + +var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem() + +func callBeforeScanHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(BeforeScanHook).BeforeScan(ctx) +} + +//------------------------------------------------------------------------------ + +type AfterScanHook interface { + AfterScan(context.Context) error +} + +var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem() + +func callAfterScanHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(AfterScanHook).AfterScan(ctx) +} + +//------------------------------------------------------------------------------ + +type AfterSelectHook interface { + AfterSelect(context.Context) error +} + +var afterSelectHookType = reflect.TypeOf((*AfterSelectHook)(nil)).Elem() + +func callAfterSelectHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(AfterSelectHook).AfterSelect(ctx) +} + +func callAfterSelectHookSlice( + ctx context.Context, slice reflect.Value, ptr bool, +) error { + return callHookSlice2(ctx, slice, ptr, callAfterSelectHook) +} + +//------------------------------------------------------------------------------ + +type BeforeInsertHook interface { + BeforeInsert(context.Context) (context.Context, error) +} + +var beforeInsertHookType = reflect.TypeOf((*BeforeInsertHook)(nil)).Elem() + +func callBeforeInsertHook(ctx context.Context, v reflect.Value) (context.Context, error) { + return v.Interface().(BeforeInsertHook).BeforeInsert(ctx) +} + +func callBeforeInsertHookSlice( + ctx context.Context, slice reflect.Value, ptr bool, +) (context.Context, error) { + return callHookSlice(ctx, slice, ptr, callBeforeInsertHook) +} + +//------------------------------------------------------------------------------ + +type AfterInsertHook interface { + AfterInsert(context.Context) error +} + +var afterInsertHookType = reflect.TypeOf((*AfterInsertHook)(nil)).Elem() + +func callAfterInsertHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(AfterInsertHook).AfterInsert(ctx) +} + +func callAfterInsertHookSlice( + ctx context.Context, slice reflect.Value, ptr bool, +) error { + return callHookSlice2(ctx, slice, ptr, callAfterInsertHook) +} + +//------------------------------------------------------------------------------ + +type BeforeUpdateHook interface { + BeforeUpdate(context.Context) (context.Context, error) +} + +var beforeUpdateHookType = reflect.TypeOf((*BeforeUpdateHook)(nil)).Elem() + +func callBeforeUpdateHook(ctx context.Context, v reflect.Value) (context.Context, error) { + return v.Interface().(BeforeUpdateHook).BeforeUpdate(ctx) +} + +func callBeforeUpdateHookSlice( + ctx context.Context, slice reflect.Value, ptr bool, +) (context.Context, error) { + return callHookSlice(ctx, slice, ptr, callBeforeUpdateHook) +} + +//------------------------------------------------------------------------------ + +type AfterUpdateHook interface { + AfterUpdate(context.Context) error +} + +var afterUpdateHookType = reflect.TypeOf((*AfterUpdateHook)(nil)).Elem() + +func callAfterUpdateHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(AfterUpdateHook).AfterUpdate(ctx) +} + +func callAfterUpdateHookSlice( + ctx context.Context, slice reflect.Value, ptr bool, +) error { + return callHookSlice2(ctx, slice, ptr, callAfterUpdateHook) +} + +//------------------------------------------------------------------------------ + +type BeforeDeleteHook interface { + BeforeDelete(context.Context) (context.Context, error) +} + +var beforeDeleteHookType = reflect.TypeOf((*BeforeDeleteHook)(nil)).Elem() + +func callBeforeDeleteHook(ctx context.Context, v reflect.Value) (context.Context, error) { + return v.Interface().(BeforeDeleteHook).BeforeDelete(ctx) +} + +func callBeforeDeleteHookSlice( + ctx context.Context, slice reflect.Value, ptr bool, +) (context.Context, error) { + return callHookSlice(ctx, slice, ptr, callBeforeDeleteHook) +} + +//------------------------------------------------------------------------------ + +type AfterDeleteHook interface { + AfterDelete(context.Context) error +} + +var afterDeleteHookType = reflect.TypeOf((*AfterDeleteHook)(nil)).Elem() + +func callAfterDeleteHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(AfterDeleteHook).AfterDelete(ctx) +} + +func callAfterDeleteHookSlice( + ctx context.Context, slice reflect.Value, ptr bool, +) error { + return callHookSlice2(ctx, slice, ptr, callAfterDeleteHook) +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/insert.go b/vendor/github.com/go-pg/pg/v10/orm/insert.go new file mode 100644 index 000000000..a7a543576 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/insert.go @@ -0,0 +1,345 @@ +package orm + +import ( + "fmt" + "reflect" + "sort" + + "github.com/go-pg/pg/v10/types" +) + +type InsertQuery struct { + q *Query + returningFields []*Field + placeholder bool +} + +var _ QueryCommand = (*InsertQuery)(nil) + +func NewInsertQuery(q *Query) *InsertQuery { + return &InsertQuery{ + q: q, + } +} + +func (q *InsertQuery) String() string { + b, err := q.AppendQuery(defaultFmter, nil) + if err != nil { + panic(err) + } + return string(b) +} + +func (q *InsertQuery) Operation() QueryOp { + return InsertOp +} + +func (q *InsertQuery) Clone() QueryCommand { + return &InsertQuery{ + q: q.q.Clone(), + placeholder: q.placeholder, + } +} + +func (q *InsertQuery) Query() *Query { + return q.q +} + +var _ TemplateAppender = (*InsertQuery)(nil) + +func (q *InsertQuery) AppendTemplate(b []byte) ([]byte, error) { + cp := q.Clone().(*InsertQuery) + cp.placeholder = true + return cp.AppendQuery(dummyFormatter{}, b) +} + +var _ QueryAppender = (*InsertQuery)(nil) + +func (q *InsertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if q.q.stickyErr != nil { + return nil, q.q.stickyErr + } + + if len(q.q.with) > 0 { + b, err = q.q.appendWith(fmter, b) + if err != nil { + return nil, err + } + } + + b = append(b, "INSERT INTO "...) + if q.q.onConflict != nil { + b, err = q.q.appendFirstTableWithAlias(fmter, b) + } else { + b, err = q.q.appendFirstTable(fmter, b) + } + if err != nil { + return nil, err + } + + b, err = q.appendColumnsValues(fmter, b) + if err != nil { + return nil, err + } + + if q.q.onConflict != nil { + b = append(b, " ON CONFLICT "...) + b, err = q.q.onConflict.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + if q.q.onConflictDoUpdate() { + if len(q.q.set) > 0 { + b, err = q.q.appendSet(fmter, b) + if err != nil { + return nil, err + } + } else { + fields, err := q.q.getDataFields() + if err != nil { + return nil, err + } + + if len(fields) == 0 { + fields = q.q.tableModel.Table().DataFields + } + + b = q.appendSetExcluded(b, fields) + } + + if len(q.q.updWhere) > 0 { + b = append(b, " WHERE "...) + b, err = q.q.appendUpdWhere(fmter, b) + if err != nil { + return nil, err + } + } + } + } + + if len(q.q.returning) > 0 { + b, err = q.q.appendReturning(fmter, b) + if err != nil { + return nil, err + } + } else if len(q.returningFields) > 0 { + b = appendReturningFields(b, q.returningFields) + } + + return b, q.q.stickyErr +} + +func (q *InsertQuery) appendColumnsValues(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if q.q.hasMultiTables() { + if q.q.columns != nil { + b = append(b, " ("...) + b, err = q.q.appendColumns(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ")"...) + } + + b = append(b, " SELECT * FROM "...) + b, err = q.q.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + + return b, nil + } + + if m, ok := q.q.model.(*mapModel); ok { + return q.appendMapColumnsValues(b, m.m), nil + } + + if !q.q.hasTableModel() { + return nil, errModelNil + } + + fields, err := q.q.getFields() + if err != nil { + return nil, err + } + + if len(fields) == 0 { + fields = q.q.tableModel.Table().Fields + } + value := q.q.tableModel.Value() + + b = append(b, " ("...) + b = q.appendColumns(b, fields) + b = append(b, ") VALUES ("...) + if m, ok := q.q.tableModel.(*sliceTableModel); ok { + if m.sliceLen == 0 { + err = fmt.Errorf("pg: can't bulk-insert empty slice %s", value.Type()) + return nil, err + } + b, err = q.appendSliceValues(fmter, b, fields, value) + if err != nil { + return nil, err + } + } else { + b, err = q.appendValues(fmter, b, fields, value) + if err != nil { + return nil, err + } + } + b = append(b, ")"...) + + return b, nil +} + +func (q *InsertQuery) appendMapColumnsValues(b []byte, m map[string]interface{}) []byte { + keys := make([]string, 0, len(m)) + + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + + b = append(b, " ("...) + + for i, k := range keys { + if i > 0 { + b = append(b, ", "...) + } + b = types.AppendIdent(b, k, 1) + } + + b = append(b, ") VALUES ("...) + + for i, k := range keys { + if i > 0 { + b = append(b, ", "...) + } + if q.placeholder { + b = append(b, '?') + } else { + b = types.Append(b, m[k], 1) + } + } + + b = append(b, ")"...) + + return b +} + +func (q *InsertQuery) appendValues( + fmter QueryFormatter, b []byte, fields []*Field, strct reflect.Value, +) (_ []byte, err error) { + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + app, ok := q.q.modelValues[f.SQLName] + if ok { + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + q.addReturningField(f) + continue + } + + switch { + case q.placeholder: + b = append(b, '?') + case (f.Default != "" || f.NullZero()) && f.HasZeroValue(strct): + b = append(b, "DEFAULT"...) + q.addReturningField(f) + default: + b = f.AppendValue(b, strct, 1) + } + } + + for i, v := range q.q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + + b, err = v.value.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *InsertQuery) appendSliceValues( + fmter QueryFormatter, b []byte, fields []*Field, slice reflect.Value, +) (_ []byte, err error) { + if q.placeholder { + return q.appendValues(fmter, b, fields, reflect.Value{}) + } + + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, "), ("...) + } + el := indirect(slice.Index(i)) + b, err = q.appendValues(fmter, b, fields, el) + if err != nil { + return nil, err + } + } + + for i, v := range q.q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + + b, err = v.value.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *InsertQuery) addReturningField(field *Field) { + if len(q.q.returning) > 0 { + return + } + for _, f := range q.returningFields { + if f == field { + return + } + } + q.returningFields = append(q.returningFields, field) +} + +func (q *InsertQuery) appendSetExcluded(b []byte, fields []*Field) []byte { + b = append(b, " SET "...) + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + b = append(b, f.Column...) + b = append(b, " = EXCLUDED."...) + b = append(b, f.Column...) + } + return b +} + +func (q *InsertQuery) appendColumns(b []byte, fields []*Field) []byte { + b = appendColumns(b, "", fields) + for i, v := range q.q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + b = types.AppendIdent(b, v.column, 1) + } + return b +} + +func appendReturningFields(b []byte, fields []*Field) []byte { + b = append(b, " RETURNING "...) + b = appendColumns(b, "", fields) + return b +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/join.go b/vendor/github.com/go-pg/pg/v10/orm/join.go new file mode 100644 index 000000000..2b64ba1b8 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/join.go @@ -0,0 +1,351 @@ +package orm + +import ( + "reflect" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/types" +) + +type join struct { + Parent *join + BaseModel TableModel + JoinModel TableModel + Rel *Relation + + ApplyQuery func(*Query) (*Query, error) + Columns []string + on []*condAppender +} + +func (j *join) AppendOn(app *condAppender) { + j.on = append(j.on, app) +} + +func (j *join) Select(fmter QueryFormatter, q *Query) error { + switch j.Rel.Type { + case HasManyRelation: + return j.selectMany(fmter, q) + case Many2ManyRelation: + return j.selectM2M(fmter, q) + } + panic("not reached") +} + +func (j *join) selectMany(_ QueryFormatter, q *Query) error { + q, err := j.manyQuery(q) + if err != nil { + return err + } + if q == nil { + return nil + } + return q.Select() +} + +func (j *join) manyQuery(q *Query) (*Query, error) { + manyModel := newManyModel(j) + if manyModel == nil { + return nil, nil + } + + q = q.Model(manyModel) + if j.ApplyQuery != nil { + var err error + q, err = j.ApplyQuery(q) + if err != nil { + return nil, err + } + } + + if len(q.columns) == 0 { + q.columns = append(q.columns, &hasManyColumnsAppender{j}) + } + + baseTable := j.BaseModel.Table() + var where []byte + if len(j.Rel.JoinFKs) > 1 { + where = append(where, '(') + } + where = appendColumns(where, j.JoinModel.Table().Alias, j.Rel.JoinFKs) + if len(j.Rel.JoinFKs) > 1 { + where = append(where, ')') + } + where = append(where, " IN ("...) + where = appendChildValues( + where, j.JoinModel.Root(), j.JoinModel.ParentIndex(), j.Rel.BaseFKs) + where = append(where, ")"...) + q = q.Where(internal.BytesToString(where)) + + if j.Rel.Polymorphic != nil { + q = q.Where(`? IN (?, ?)`, + j.Rel.Polymorphic.Column, + baseTable.ModelName, baseTable.TypeName) + } + + return q, nil +} + +func (j *join) selectM2M(fmter QueryFormatter, q *Query) error { + q, err := j.m2mQuery(fmter, q) + if err != nil { + return err + } + if q == nil { + return nil + } + return q.Select() +} + +func (j *join) m2mQuery(fmter QueryFormatter, q *Query) (*Query, error) { + m2mModel := newM2MModel(j) + if m2mModel == nil { + return nil, nil + } + + q = q.Model(m2mModel) + if j.ApplyQuery != nil { + var err error + q, err = j.ApplyQuery(q) + if err != nil { + return nil, err + } + } + + if len(q.columns) == 0 { + q.columns = append(q.columns, &hasManyColumnsAppender{j}) + } + + index := j.JoinModel.ParentIndex() + baseTable := j.BaseModel.Table() + + //nolint + var join []byte + join = append(join, "JOIN "...) + join = fmter.FormatQuery(join, string(j.Rel.M2MTableName)) + join = append(join, " AS "...) + join = append(join, j.Rel.M2MTableAlias...) + join = append(join, " ON ("...) + for i, col := range j.Rel.M2MBaseFKs { + if i > 0 { + join = append(join, ", "...) + } + join = append(join, j.Rel.M2MTableAlias...) + join = append(join, '.') + join = types.AppendIdent(join, col, 1) + } + join = append(join, ") IN ("...) + join = appendChildValues(join, j.BaseModel.Root(), index, baseTable.PKs) + join = append(join, ")"...) + q = q.Join(internal.BytesToString(join)) + + joinTable := j.JoinModel.Table() + for i, col := range j.Rel.M2MJoinFKs { + pk := joinTable.PKs[i] + q = q.Where("?.? = ?.?", + joinTable.Alias, pk.Column, + j.Rel.M2MTableAlias, types.Ident(col)) + } + + return q, nil +} + +func (j *join) hasParent() bool { + if j.Parent != nil { + switch j.Parent.Rel.Type { + case HasOneRelation, BelongsToRelation: + return true + } + } + return false +} + +func (j *join) appendAlias(b []byte) []byte { + b = append(b, '"') + b = appendAlias(b, j) + b = append(b, '"') + return b +} + +func (j *join) appendAliasColumn(b []byte, column string) []byte { + b = append(b, '"') + b = appendAlias(b, j) + b = append(b, "__"...) + b = append(b, column...) + b = append(b, '"') + return b +} + +func (j *join) appendBaseAlias(b []byte) []byte { + if j.hasParent() { + b = append(b, '"') + b = appendAlias(b, j.Parent) + b = append(b, '"') + return b + } + return append(b, j.BaseModel.Table().Alias...) +} + +func (j *join) appendSoftDelete(b []byte, flags queryFlag) []byte { + b = append(b, '.') + b = append(b, j.JoinModel.Table().SoftDeleteField.Column...) + if hasFlag(flags, deletedFlag) { + b = append(b, " IS NOT NULL"...) + } else { + b = append(b, " IS NULL"...) + } + return b +} + +func appendAlias(b []byte, j *join) []byte { + if j.hasParent() { + b = appendAlias(b, j.Parent) + b = append(b, "__"...) + } + b = append(b, j.Rel.Field.SQLName...) + return b +} + +func (j *join) appendHasOneColumns(b []byte) []byte { + if j.Columns == nil { + for i, f := range j.JoinModel.Table().Fields { + if i > 0 { + b = append(b, ", "...) + } + b = j.appendAlias(b) + b = append(b, '.') + b = append(b, f.Column...) + b = append(b, " AS "...) + b = j.appendAliasColumn(b, f.SQLName) + } + return b + } + + for i, column := range j.Columns { + if i > 0 { + b = append(b, ", "...) + } + b = j.appendAlias(b) + b = append(b, '.') + b = types.AppendIdent(b, column, 1) + b = append(b, " AS "...) + b = j.appendAliasColumn(b, column) + } + + return b +} + +func (j *join) appendHasOneJoin(fmter QueryFormatter, b []byte, q *Query) (_ []byte, err error) { + isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.hasFlag(allWithDeletedFlag) + + b = append(b, "LEFT JOIN "...) + b = fmter.FormatQuery(b, string(j.JoinModel.Table().SQLNameForSelects)) + b = append(b, " AS "...) + b = j.appendAlias(b) + + b = append(b, " ON "...) + + if isSoftDelete { + b = append(b, '(') + } + + if len(j.Rel.BaseFKs) > 1 { + b = append(b, '(') + } + for i, baseFK := range j.Rel.BaseFKs { + if i > 0 { + b = append(b, " AND "...) + } + b = j.appendAlias(b) + b = append(b, '.') + b = append(b, j.Rel.JoinFKs[i].Column...) + b = append(b, " = "...) + b = j.appendBaseAlias(b) + b = append(b, '.') + b = append(b, baseFK.Column...) + } + if len(j.Rel.BaseFKs) > 1 { + b = append(b, ')') + } + + for _, on := range j.on { + b = on.AppendSep(b) + b, err = on.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + if isSoftDelete { + b = append(b, ')') + } + + if isSoftDelete { + b = append(b, " AND "...) + b = j.appendAlias(b) + b = j.appendSoftDelete(b, q.flags) + } + + return b, nil +} + +type hasManyColumnsAppender struct { + *join +} + +var _ QueryAppender = (*hasManyColumnsAppender)(nil) + +func (q *hasManyColumnsAppender) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + if q.Rel.M2MTableAlias != "" { + b = append(b, q.Rel.M2MTableAlias...) + b = append(b, ".*, "...) + } + + joinTable := q.JoinModel.Table() + + if q.Columns != nil { + for i, column := range q.Columns { + if i > 0 { + b = append(b, ", "...) + } + b = append(b, joinTable.Alias...) + b = append(b, '.') + b = types.AppendIdent(b, column, 1) + } + return b, nil + } + + b = appendColumns(b, joinTable.Alias, joinTable.Fields) + return b, nil +} + +func appendChildValues(b []byte, v reflect.Value, index []int, fields []*Field) []byte { + seen := make(map[string]struct{}) + walk(v, index, func(v reflect.Value) { + start := len(b) + + if len(fields) > 1 { + b = append(b, '(') + } + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + b = f.AppendValue(b, v, 1) + } + if len(fields) > 1 { + b = append(b, ')') + } + b = append(b, ", "...) + + if _, ok := seen[string(b[start:])]; ok { + b = b[:start] + } else { + seen[string(b[start:])] = struct{}{} + } + }) + if len(seen) > 0 { + b = b[:len(b)-2] // trim ", " + } + return b +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model.go b/vendor/github.com/go-pg/pg/v10/orm/model.go new file mode 100644 index 000000000..333a90dd7 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model.go @@ -0,0 +1,150 @@ +package orm + +import ( + "database/sql" + "errors" + "fmt" + "reflect" + + "github.com/go-pg/pg/v10/types" +) + +var errModelNil = errors.New("pg: Model(nil)") + +type useQueryOne interface { + useQueryOne() bool +} + +type HooklessModel interface { + // Init is responsible to initialize/reset model state. + // It is called only once no matter how many rows were returned. + Init() error + + // NextColumnScanner returns a ColumnScanner that is used to scan columns + // from the current row. It is called once for every row. + NextColumnScanner() ColumnScanner + + // AddColumnScanner adds the ColumnScanner to the model. + AddColumnScanner(ColumnScanner) error +} + +type Model interface { + HooklessModel + + AfterScanHook + AfterSelectHook + + BeforeInsertHook + AfterInsertHook + + BeforeUpdateHook + AfterUpdateHook + + BeforeDeleteHook + AfterDeleteHook +} + +func NewModel(value interface{}) (Model, error) { + return newModel(value, false) +} + +func newScanModel(values []interface{}) (Model, error) { + if len(values) > 1 { + return Scan(values...), nil + } + return newModel(values[0], true) +} + +func newModel(value interface{}, scan bool) (Model, error) { + switch value := value.(type) { + case Model: + return value, nil + case HooklessModel: + return newModelWithHookStubs(value), nil + case types.ValueScanner, sql.Scanner: + if !scan { + return nil, fmt.Errorf("pg: Model(unsupported %T)", value) + } + return Scan(value), nil + } + + v := reflect.ValueOf(value) + if !v.IsValid() { + return nil, errModelNil + } + if v.Kind() != reflect.Ptr { + return nil, fmt.Errorf("pg: Model(non-pointer %T)", value) + } + + if v.IsNil() { + typ := v.Type().Elem() + if typ.Kind() == reflect.Struct { + return newStructTableModel(GetTable(typ)), nil + } + return nil, errModelNil + } + + v = v.Elem() + + if v.Kind() == reflect.Interface { + if !v.IsNil() { + v = v.Elem() + if v.Kind() != reflect.Ptr { + return nil, fmt.Errorf("pg: Model(non-pointer %s)", v.Type().String()) + } + } + } + + switch v.Kind() { + case reflect.Struct: + if v.Type() != timeType { + return newStructTableModelValue(v), nil + } + case reflect.Slice: + elemType := sliceElemType(v) + switch elemType.Kind() { + case reflect.Struct: + if elemType != timeType { + return newSliceTableModel(v, elemType), nil + } + case reflect.Map: + if err := validMap(elemType); err != nil { + return nil, err + } + slicePtr := v.Addr().Interface().(*[]map[string]interface{}) + return newMapSliceModel(slicePtr), nil + } + return newSliceModel(v, elemType), nil + case reflect.Map: + typ := v.Type() + if err := validMap(typ); err != nil { + return nil, err + } + mapPtr := v.Addr().Interface().(*map[string]interface{}) + return newMapModel(mapPtr), nil + } + + if !scan { + return nil, fmt.Errorf("pg: Model(unsupported %T)", value) + } + return Scan(value), nil +} + +type modelWithHookStubs struct { + hookStubs + HooklessModel +} + +func newModelWithHookStubs(m HooklessModel) Model { + return modelWithHookStubs{ + HooklessModel: m, + } +} + +func validMap(typ reflect.Type) error { + if typ.Key().Kind() != reflect.String || typ.Elem().Kind() != reflect.Interface { + return fmt.Errorf("pg: Model(unsupported %s, expected *map[string]interface{})", + typ.String()) + } + return nil +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_discard.go b/vendor/github.com/go-pg/pg/v10/orm/model_discard.go new file mode 100644 index 000000000..92e5c566c --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_discard.go @@ -0,0 +1,27 @@ +package orm + +import ( + "github.com/go-pg/pg/v10/types" +) + +type Discard struct { + hookStubs +} + +var _ Model = (*Discard)(nil) + +func (Discard) Init() error { + return nil +} + +func (m Discard) NextColumnScanner() ColumnScanner { + return m +} + +func (m Discard) AddColumnScanner(ColumnScanner) error { + return nil +} + +func (m Discard) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { + return nil +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_func.go b/vendor/github.com/go-pg/pg/v10/orm/model_func.go new file mode 100644 index 000000000..8427bdea2 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_func.go @@ -0,0 +1,89 @@ +package orm + +import ( + "fmt" + "reflect" +) + +var errorType = reflect.TypeOf((*error)(nil)).Elem() + +type funcModel struct { + Model + fnv reflect.Value + fnIn []reflect.Value +} + +var _ Model = (*funcModel)(nil) + +func newFuncModel(fn interface{}) *funcModel { + m := &funcModel{ + fnv: reflect.ValueOf(fn), + } + + fnt := m.fnv.Type() + if fnt.Kind() != reflect.Func { + panic(fmt.Errorf("ForEach expects a %s, got a %s", + reflect.Func, fnt.Kind())) + } + + if fnt.NumIn() < 1 { + panic(fmt.Errorf("ForEach expects at least 1 arg, got %d", fnt.NumIn())) + } + + if fnt.NumOut() != 1 { + panic(fmt.Errorf("ForEach must return 1 error value, got %d", fnt.NumOut())) + } + if fnt.Out(0) != errorType { + panic(fmt.Errorf("ForEach must return an error, got %T", fnt.Out(0))) + } + + if fnt.NumIn() > 1 { + initFuncModelScan(m, fnt) + return m + } + + t0 := fnt.In(0) + var v0 reflect.Value + if t0.Kind() == reflect.Ptr { + t0 = t0.Elem() + v0 = reflect.New(t0) + } else { + v0 = reflect.New(t0).Elem() + } + + m.fnIn = []reflect.Value{v0} + + model, ok := v0.Interface().(Model) + if ok { + m.Model = model + return m + } + + if v0.Kind() == reflect.Ptr { + v0 = v0.Elem() + } + if v0.Kind() != reflect.Struct { + panic(fmt.Errorf("ForEach accepts a %s, got %s", + reflect.Struct, v0.Kind())) + } + m.Model = newStructTableModelValue(v0) + + return m +} + +func initFuncModelScan(m *funcModel, fnt reflect.Type) { + m.fnIn = make([]reflect.Value, fnt.NumIn()) + for i := 0; i < fnt.NumIn(); i++ { + m.fnIn[i] = reflect.New(fnt.In(i)).Elem() + } + m.Model = scanReflectValues(m.fnIn) +} + +func (m *funcModel) AddColumnScanner(_ ColumnScanner) error { + out := m.fnv.Call(m.fnIn) + errv := out[0] + if !errv.IsNil() { + return errv.Interface().(error) + } + return nil +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_map.go b/vendor/github.com/go-pg/pg/v10/orm/model_map.go new file mode 100644 index 000000000..24533d43c --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_map.go @@ -0,0 +1,53 @@ +package orm + +import ( + "github.com/go-pg/pg/v10/types" +) + +type mapModel struct { + hookStubs + ptr *map[string]interface{} + m map[string]interface{} +} + +var _ Model = (*mapModel)(nil) + +func newMapModel(ptr *map[string]interface{}) *mapModel { + model := &mapModel{ + ptr: ptr, + } + if ptr != nil { + model.m = *ptr + } + return model +} + +func (m *mapModel) Init() error { + return nil +} + +func (m *mapModel) NextColumnScanner() ColumnScanner { + if m.m == nil { + m.m = make(map[string]interface{}) + *m.ptr = m.m + } + return m +} + +func (m mapModel) AddColumnScanner(ColumnScanner) error { + return nil +} + +func (m *mapModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { + val, err := types.ReadColumnValue(col, rd, n) + if err != nil { + return err + } + + m.m[col.Name] = val + return nil +} + +func (mapModel) useQueryOne() bool { + return true +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_map_slice.go b/vendor/github.com/go-pg/pg/v10/orm/model_map_slice.go new file mode 100644 index 000000000..ea14c9b6b --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_map_slice.go @@ -0,0 +1,45 @@ +package orm + +type mapSliceModel struct { + mapModel + slice *[]map[string]interface{} +} + +var _ Model = (*mapSliceModel)(nil) + +func newMapSliceModel(ptr *[]map[string]interface{}) *mapSliceModel { + return &mapSliceModel{ + slice: ptr, + } +} + +func (m *mapSliceModel) Init() error { + slice := *m.slice + if len(slice) > 0 { + *m.slice = slice[:0] + } + return nil +} + +func (m *mapSliceModel) NextColumnScanner() ColumnScanner { + slice := *m.slice + if len(slice) == cap(slice) { + m.mapModel.m = make(map[string]interface{}) + *m.slice = append(slice, m.mapModel.m) //nolint:gocritic + return m + } + + slice = slice[:len(slice)+1] + el := slice[len(slice)-1] + if el != nil { + m.mapModel.m = el + } else { + el = make(map[string]interface{}) + slice[len(slice)-1] = el + m.mapModel.m = el + } + *m.slice = slice + return m +} + +func (mapSliceModel) useQueryOne() {} //nolint:unused diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_scan.go b/vendor/github.com/go-pg/pg/v10/orm/model_scan.go new file mode 100644 index 000000000..08f66beba --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_scan.go @@ -0,0 +1,69 @@ +package orm + +import ( + "fmt" + "reflect" + + "github.com/go-pg/pg/v10/types" +) + +type scanValuesModel struct { + Discard + values []interface{} +} + +var _ Model = scanValuesModel{} + +//nolint +func Scan(values ...interface{}) scanValuesModel { + return scanValuesModel{ + values: values, + } +} + +func (scanValuesModel) useQueryOne() bool { + return true +} + +func (m scanValuesModel) NextColumnScanner() ColumnScanner { + return m +} + +func (m scanValuesModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { + if int(col.Index) >= len(m.values) { + return fmt.Errorf("pg: no Scan var for column index=%d name=%q", + col.Index, col.Name) + } + return types.Scan(m.values[col.Index], rd, n) +} + +//------------------------------------------------------------------------------ + +type scanReflectValuesModel struct { + Discard + values []reflect.Value +} + +var _ Model = scanReflectValuesModel{} + +func scanReflectValues(values []reflect.Value) scanReflectValuesModel { + return scanReflectValuesModel{ + values: values, + } +} + +func (scanReflectValuesModel) useQueryOne() bool { + return true +} + +func (m scanReflectValuesModel) NextColumnScanner() ColumnScanner { + return m +} + +func (m scanReflectValuesModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { + if int(col.Index) >= len(m.values) { + return fmt.Errorf("pg: no Scan var for column index=%d name=%q", + col.Index, col.Name) + } + return types.ScanValue(m.values[col.Index], rd, n) +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_slice.go b/vendor/github.com/go-pg/pg/v10/orm/model_slice.go new file mode 100644 index 000000000..1e163629e --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_slice.go @@ -0,0 +1,43 @@ +package orm + +import ( + "reflect" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/types" +) + +type sliceModel struct { + Discard + slice reflect.Value + nextElem func() reflect.Value + scan func(reflect.Value, types.Reader, int) error +} + +var _ Model = (*sliceModel)(nil) + +func newSliceModel(slice reflect.Value, elemType reflect.Type) *sliceModel { + return &sliceModel{ + slice: slice, + scan: types.Scanner(elemType), + } +} + +func (m *sliceModel) Init() error { + if m.slice.IsValid() && m.slice.Len() > 0 { + m.slice.Set(m.slice.Slice(0, 0)) + } + return nil +} + +func (m *sliceModel) NextColumnScanner() ColumnScanner { + return m +} + +func (m *sliceModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { + if m.nextElem == nil { + m.nextElem = internal.MakeSliceNextElemFunc(m.slice) + } + v := m.nextElem() + return m.scan(v, rd, n) +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_table.go b/vendor/github.com/go-pg/pg/v10/orm/model_table.go new file mode 100644 index 000000000..afdc15ccc --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_table.go @@ -0,0 +1,65 @@ +package orm + +import ( + "fmt" + "reflect" + + "github.com/go-pg/pg/v10/types" +) + +type TableModel interface { + Model + + IsNil() bool + Table() *Table + Relation() *Relation + AppendParam(QueryFormatter, []byte, string) ([]byte, bool) + + Join(string, func(*Query) (*Query, error)) *join + GetJoin(string) *join + GetJoins() []join + AddJoin(join) *join + + Root() reflect.Value + Index() []int + ParentIndex() []int + Mount(reflect.Value) + Kind() reflect.Kind + Value() reflect.Value + + setSoftDeleteField() error + scanColumn(types.ColumnInfo, types.Reader, int) (bool, error) +} + +func newTableModelIndex(typ reflect.Type, root reflect.Value, index []int, rel *Relation) (TableModel, error) { + typ = typeByIndex(typ, index) + + if typ.Kind() == reflect.Struct { + return &structTableModel{ + table: GetTable(typ), + rel: rel, + + root: root, + index: index, + }, nil + } + + if typ.Kind() == reflect.Slice { + structType := indirectType(typ.Elem()) + if structType.Kind() == reflect.Struct { + m := sliceTableModel{ + structTableModel: structTableModel{ + table: GetTable(structType), + rel: rel, + + root: root, + index: index, + }, + } + m.init(typ) + return &m, nil + } + } + + return nil, fmt.Errorf("pg: NewModel(%s)", typ) +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_table_m2m.go b/vendor/github.com/go-pg/pg/v10/orm/model_table_m2m.go new file mode 100644 index 000000000..83ac73bde --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_table_m2m.go @@ -0,0 +1,111 @@ +package orm + +import ( + "fmt" + "reflect" + + "github.com/go-pg/pg/v10/internal/pool" + "github.com/go-pg/pg/v10/types" +) + +type m2mModel struct { + *sliceTableModel + baseTable *Table + rel *Relation + + buf []byte + dstValues map[string][]reflect.Value + columns map[string]string +} + +var _ TableModel = (*m2mModel)(nil) + +func newM2MModel(j *join) *m2mModel { + baseTable := j.BaseModel.Table() + joinModel := j.JoinModel.(*sliceTableModel) + dstValues := dstValues(joinModel, baseTable.PKs) + if len(dstValues) == 0 { + return nil + } + m := &m2mModel{ + sliceTableModel: joinModel, + baseTable: baseTable, + rel: j.Rel, + + dstValues: dstValues, + columns: make(map[string]string), + } + if !m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } + return m +} + +func (m *m2mModel) NextColumnScanner() ColumnScanner { + if m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } else { + m.strct.Set(m.table.zeroStruct) + } + m.structInited = false + return m +} + +func (m *m2mModel) AddColumnScanner(_ ColumnScanner) error { + buf, err := m.modelIDMap(m.buf[:0]) + if err != nil { + return err + } + m.buf = buf + + dstValues, ok := m.dstValues[string(buf)] + if !ok { + return fmt.Errorf( + "pg: relation=%q does not have base %s with id=%q (check join conditions)", + m.rel.Field.GoName, m.baseTable, buf) + } + + for _, v := range dstValues { + if m.sliceOfPtr { + v.Set(reflect.Append(v, m.strct.Addr())) + } else { + v.Set(reflect.Append(v, m.strct)) + } + } + + return nil +} + +func (m *m2mModel) modelIDMap(b []byte) ([]byte, error) { + for i, col := range m.rel.M2MBaseFKs { + if i > 0 { + b = append(b, ',') + } + if s, ok := m.columns[col]; ok { + b = append(b, s...) + } else { + return nil, fmt.Errorf("pg: %s does not have column=%q", + m.sliceTableModel, col) + } + } + return b, nil +} + +func (m *m2mModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { + if n > 0 { + b, err := rd.ReadFullTemp() + if err != nil { + return err + } + + m.columns[col.Name] = string(b) + rd = pool.NewBytesReader(b) + } else { + m.columns[col.Name] = "" + } + + if ok, err := m.sliceTableModel.scanColumn(col, rd, n); ok { + return err + } + return nil +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_table_many.go b/vendor/github.com/go-pg/pg/v10/orm/model_table_many.go new file mode 100644 index 000000000..561384bba --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_table_many.go @@ -0,0 +1,75 @@ +package orm + +import ( + "fmt" + "reflect" +) + +type manyModel struct { + *sliceTableModel + baseTable *Table + rel *Relation + + buf []byte + dstValues map[string][]reflect.Value +} + +var _ TableModel = (*manyModel)(nil) + +func newManyModel(j *join) *manyModel { + baseTable := j.BaseModel.Table() + joinModel := j.JoinModel.(*sliceTableModel) + dstValues := dstValues(joinModel, j.Rel.BaseFKs) + if len(dstValues) == 0 { + return nil + } + m := manyModel{ + sliceTableModel: joinModel, + baseTable: baseTable, + rel: j.Rel, + + dstValues: dstValues, + } + if !m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } + return &m +} + +func (m *manyModel) NextColumnScanner() ColumnScanner { + if m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } else { + m.strct.Set(m.table.zeroStruct) + } + m.structInited = false + return m +} + +func (m *manyModel) AddColumnScanner(model ColumnScanner) error { + m.buf = modelID(m.buf[:0], m.strct, m.rel.JoinFKs) + dstValues, ok := m.dstValues[string(m.buf)] + if !ok { + return fmt.Errorf( + "pg: relation=%q does not have base %s with id=%q (check join conditions)", + m.rel.Field.GoName, m.baseTable, m.buf) + } + + for i, v := range dstValues { + if !m.sliceOfPtr { + v.Set(reflect.Append(v, m.strct)) + continue + } + + if i == 0 { + v.Set(reflect.Append(v, m.strct.Addr())) + continue + } + + clone := reflect.New(m.strct.Type()).Elem() + clone.Set(m.strct) + v.Set(reflect.Append(v, clone.Addr())) + } + + return nil +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_table_slice.go b/vendor/github.com/go-pg/pg/v10/orm/model_table_slice.go new file mode 100644 index 000000000..c50be8252 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_table_slice.go @@ -0,0 +1,156 @@ +package orm + +import ( + "context" + "reflect" + + "github.com/go-pg/pg/v10/internal" +) + +type sliceTableModel struct { + structTableModel + + slice reflect.Value + sliceLen int + sliceOfPtr bool + nextElem func() reflect.Value +} + +var _ TableModel = (*sliceTableModel)(nil) + +func newSliceTableModel(slice reflect.Value, elemType reflect.Type) *sliceTableModel { + m := &sliceTableModel{ + structTableModel: structTableModel{ + table: GetTable(elemType), + root: slice, + }, + slice: slice, + sliceLen: slice.Len(), + nextElem: internal.MakeSliceNextElemFunc(slice), + } + m.init(slice.Type()) + return m +} + +func (m *sliceTableModel) init(sliceType reflect.Type) { + switch sliceType.Elem().Kind() { + case reflect.Ptr, reflect.Interface: + m.sliceOfPtr = true + } +} + +//nolint +func (*sliceTableModel) useQueryOne() {} + +func (m *sliceTableModel) IsNil() bool { + return false +} + +func (m *sliceTableModel) AppendParam(fmter QueryFormatter, b []byte, name string) ([]byte, bool) { + if field, ok := m.table.FieldsMap[name]; ok { + b = append(b, "_data."...) + b = append(b, field.Column...) + return b, true + } + return m.structTableModel.AppendParam(fmter, b, name) +} + +func (m *sliceTableModel) Join(name string, apply func(*Query) (*Query, error)) *join { + return m.join(m.Value(), name, apply) +} + +func (m *sliceTableModel) Bind(bind reflect.Value) { + m.slice = bind.Field(m.index[len(m.index)-1]) +} + +func (m *sliceTableModel) Kind() reflect.Kind { + return reflect.Slice +} + +func (m *sliceTableModel) Value() reflect.Value { + return m.slice +} + +func (m *sliceTableModel) Init() error { + if m.slice.IsValid() && m.slice.Len() > 0 { + m.slice.Set(m.slice.Slice(0, 0)) + } + return nil +} + +func (m *sliceTableModel) NextColumnScanner() ColumnScanner { + m.strct = m.nextElem() + m.structInited = false + return m +} + +func (m *sliceTableModel) AddColumnScanner(_ ColumnScanner) error { + return nil +} + +// Inherit these hooks from structTableModel. +var ( + _ BeforeScanHook = (*sliceTableModel)(nil) + _ AfterScanHook = (*sliceTableModel)(nil) +) + +func (m *sliceTableModel) AfterSelect(ctx context.Context) error { + if m.table.hasFlag(afterSelectHookFlag) { + return callAfterSelectHookSlice(ctx, m.slice, m.sliceOfPtr) + } + return nil +} + +func (m *sliceTableModel) BeforeInsert(ctx context.Context) (context.Context, error) { + if m.table.hasFlag(beforeInsertHookFlag) { + return callBeforeInsertHookSlice(ctx, m.slice, m.sliceOfPtr) + } + return ctx, nil +} + +func (m *sliceTableModel) AfterInsert(ctx context.Context) error { + if m.table.hasFlag(afterInsertHookFlag) { + return callAfterInsertHookSlice(ctx, m.slice, m.sliceOfPtr) + } + return nil +} + +func (m *sliceTableModel) BeforeUpdate(ctx context.Context) (context.Context, error) { + if m.table.hasFlag(beforeUpdateHookFlag) && !m.IsNil() { + return callBeforeUpdateHookSlice(ctx, m.slice, m.sliceOfPtr) + } + return ctx, nil +} + +func (m *sliceTableModel) AfterUpdate(ctx context.Context) error { + if m.table.hasFlag(afterUpdateHookFlag) { + return callAfterUpdateHookSlice(ctx, m.slice, m.sliceOfPtr) + } + return nil +} + +func (m *sliceTableModel) BeforeDelete(ctx context.Context) (context.Context, error) { + if m.table.hasFlag(beforeDeleteHookFlag) && !m.IsNil() { + return callBeforeDeleteHookSlice(ctx, m.slice, m.sliceOfPtr) + } + return ctx, nil +} + +func (m *sliceTableModel) AfterDelete(ctx context.Context) error { + if m.table.hasFlag(afterDeleteHookFlag) && !m.IsNil() { + return callAfterDeleteHookSlice(ctx, m.slice, m.sliceOfPtr) + } + return nil +} + +func (m *sliceTableModel) setSoftDeleteField() error { + sliceLen := m.slice.Len() + for i := 0; i < sliceLen; i++ { + strct := indirect(m.slice.Index(i)) + fv := m.table.SoftDeleteField.Value(strct) + if err := m.table.SetSoftDeleteField(fv); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_table_struct.go b/vendor/github.com/go-pg/pg/v10/orm/model_table_struct.go new file mode 100644 index 000000000..fce7cc6b7 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/model_table_struct.go @@ -0,0 +1,399 @@ +package orm + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/go-pg/pg/v10/types" +) + +type structTableModel struct { + table *Table + rel *Relation + joins []join + + root reflect.Value + index []int + + strct reflect.Value + structInited bool + structInitErr error +} + +var _ TableModel = (*structTableModel)(nil) + +func newStructTableModel(table *Table) *structTableModel { + return &structTableModel{ + table: table, + } +} + +func newStructTableModelValue(v reflect.Value) *structTableModel { + return &structTableModel{ + table: GetTable(v.Type()), + root: v, + strct: v, + } +} + +func (*structTableModel) useQueryOne() bool { + return true +} + +func (m *structTableModel) String() string { + return m.table.String() +} + +func (m *structTableModel) IsNil() bool { + return !m.strct.IsValid() +} + +func (m *structTableModel) Table() *Table { + return m.table +} + +func (m *structTableModel) Relation() *Relation { + return m.rel +} + +func (m *structTableModel) AppendParam(fmter QueryFormatter, b []byte, name string) ([]byte, bool) { + b, ok := m.table.AppendParam(b, m.strct, name) + if ok { + return b, true + } + + switch name { + case "TableName": + b = fmter.FormatQuery(b, string(m.table.SQLName)) + return b, true + case "TableAlias": + b = append(b, m.table.Alias...) + return b, true + case "TableColumns": + b = appendColumns(b, m.table.Alias, m.table.Fields) + return b, true + case "Columns": + b = appendColumns(b, "", m.table.Fields) + return b, true + case "TablePKs": + b = appendColumns(b, m.table.Alias, m.table.PKs) + return b, true + case "PKs": + b = appendColumns(b, "", m.table.PKs) + return b, true + } + + return b, false +} + +func (m *structTableModel) Root() reflect.Value { + return m.root +} + +func (m *structTableModel) Index() []int { + return m.index +} + +func (m *structTableModel) ParentIndex() []int { + return m.index[:len(m.index)-len(m.rel.Field.Index)] +} + +func (m *structTableModel) Kind() reflect.Kind { + return reflect.Struct +} + +func (m *structTableModel) Value() reflect.Value { + return m.strct +} + +func (m *structTableModel) Mount(host reflect.Value) { + m.strct = host.FieldByIndex(m.rel.Field.Index) + m.structInited = false +} + +func (m *structTableModel) initStruct() error { + if m.structInited { + return m.structInitErr + } + m.structInited = true + + switch m.strct.Kind() { + case reflect.Invalid: + m.structInitErr = errModelNil + return m.structInitErr + case reflect.Interface: + m.strct = m.strct.Elem() + } + + if m.strct.Kind() == reflect.Ptr { + if m.strct.IsNil() { + m.strct.Set(reflect.New(m.strct.Type().Elem())) + m.strct = m.strct.Elem() + } else { + m.strct = m.strct.Elem() + } + } + + m.mountJoins() + + return nil +} + +func (m *structTableModel) mountJoins() { + for i := range m.joins { + j := &m.joins[i] + switch j.Rel.Type { + case HasOneRelation, BelongsToRelation: + j.JoinModel.Mount(m.strct) + } + } +} + +func (structTableModel) Init() error { + return nil +} + +func (m *structTableModel) NextColumnScanner() ColumnScanner { + return m +} + +func (m *structTableModel) AddColumnScanner(_ ColumnScanner) error { + return nil +} + +var _ BeforeScanHook = (*structTableModel)(nil) + +func (m *structTableModel) BeforeScan(ctx context.Context) error { + if !m.table.hasFlag(beforeScanHookFlag) { + return nil + } + return callBeforeScanHook(ctx, m.strct.Addr()) +} + +var _ AfterScanHook = (*structTableModel)(nil) + +func (m *structTableModel) AfterScan(ctx context.Context) error { + if !m.table.hasFlag(afterScanHookFlag) || !m.structInited { + return nil + } + + var firstErr error + + if err := callAfterScanHook(ctx, m.strct.Addr()); err != nil && firstErr == nil { + firstErr = err + } + + for _, j := range m.joins { + switch j.Rel.Type { + case HasOneRelation, BelongsToRelation: + if err := j.JoinModel.AfterScan(ctx); err != nil && firstErr == nil { + firstErr = err + } + } + } + + return firstErr +} + +func (m *structTableModel) AfterSelect(ctx context.Context) error { + if m.table.hasFlag(afterSelectHookFlag) { + return callAfterSelectHook(ctx, m.strct.Addr()) + } + return nil +} + +func (m *structTableModel) BeforeInsert(ctx context.Context) (context.Context, error) { + if m.table.hasFlag(beforeInsertHookFlag) { + return callBeforeInsertHook(ctx, m.strct.Addr()) + } + return ctx, nil +} + +func (m *structTableModel) AfterInsert(ctx context.Context) error { + if m.table.hasFlag(afterInsertHookFlag) { + return callAfterInsertHook(ctx, m.strct.Addr()) + } + return nil +} + +func (m *structTableModel) BeforeUpdate(ctx context.Context) (context.Context, error) { + if m.table.hasFlag(beforeUpdateHookFlag) && !m.IsNil() { + return callBeforeUpdateHook(ctx, m.strct.Addr()) + } + return ctx, nil +} + +func (m *structTableModel) AfterUpdate(ctx context.Context) error { + if m.table.hasFlag(afterUpdateHookFlag) && !m.IsNil() { + return callAfterUpdateHook(ctx, m.strct.Addr()) + } + return nil +} + +func (m *structTableModel) BeforeDelete(ctx context.Context) (context.Context, error) { + if m.table.hasFlag(beforeDeleteHookFlag) && !m.IsNil() { + return callBeforeDeleteHook(ctx, m.strct.Addr()) + } + return ctx, nil +} + +func (m *structTableModel) AfterDelete(ctx context.Context) error { + if m.table.hasFlag(afterDeleteHookFlag) && !m.IsNil() { + return callAfterDeleteHook(ctx, m.strct.Addr()) + } + return nil +} + +func (m *structTableModel) ScanColumn( + col types.ColumnInfo, rd types.Reader, n int, +) error { + ok, err := m.scanColumn(col, rd, n) + if ok { + return err + } + if m.table.hasFlag(discardUnknownColumnsFlag) || col.Name[0] == '_' { + return nil + } + return fmt.Errorf( + "pg: can't find column=%s in %s "+ + "(prefix the column with underscore or use discard_unknown_columns)", + col.Name, m.table, + ) +} + +func (m *structTableModel) scanColumn(col types.ColumnInfo, rd types.Reader, n int) (bool, error) { + // Don't init nil struct if value is NULL. + if n == -1 && + !m.structInited && + m.strct.Kind() == reflect.Ptr && + m.strct.IsNil() { + return true, nil + } + + if err := m.initStruct(); err != nil { + return true, err + } + + joinName, fieldName := splitColumn(col.Name) + if joinName != "" { + if join := m.GetJoin(joinName); join != nil { + joinCol := col + joinCol.Name = fieldName + return join.JoinModel.scanColumn(joinCol, rd, n) + } + if m.table.ModelName == joinName { + joinCol := col + joinCol.Name = fieldName + return m.scanColumn(joinCol, rd, n) + } + } + + field, ok := m.table.FieldsMap[col.Name] + if !ok { + return false, nil + } + + return true, field.ScanValue(m.strct, rd, n) +} + +func (m *structTableModel) GetJoin(name string) *join { + for i := range m.joins { + j := &m.joins[i] + if j.Rel.Field.GoName == name || j.Rel.Field.SQLName == name { + return j + } + } + return nil +} + +func (m *structTableModel) GetJoins() []join { + return m.joins +} + +func (m *structTableModel) AddJoin(j join) *join { + m.joins = append(m.joins, j) + return &m.joins[len(m.joins)-1] +} + +func (m *structTableModel) Join(name string, apply func(*Query) (*Query, error)) *join { + return m.join(m.Value(), name, apply) +} + +func (m *structTableModel) join( + bind reflect.Value, name string, apply func(*Query) (*Query, error), +) *join { + path := strings.Split(name, ".") + index := make([]int, 0, len(path)) + + currJoin := join{ + BaseModel: m, + JoinModel: m, + } + var lastJoin *join + var hasColumnName bool + + for _, name := range path { + rel, ok := currJoin.JoinModel.Table().Relations[name] + if !ok { + hasColumnName = true + break + } + + currJoin.Rel = rel + index = append(index, rel.Field.Index...) + + if j := currJoin.JoinModel.GetJoin(name); j != nil { + currJoin.BaseModel = j.BaseModel + currJoin.JoinModel = j.JoinModel + + lastJoin = j + } else { + model, err := newTableModelIndex(m.table.Type, bind, index, rel) + if err != nil { + return nil + } + + currJoin.Parent = lastJoin + currJoin.BaseModel = currJoin.JoinModel + currJoin.JoinModel = model + + lastJoin = currJoin.BaseModel.AddJoin(currJoin) + } + } + + // No joins with such name. + if lastJoin == nil { + return nil + } + if apply != nil { + lastJoin.ApplyQuery = apply + } + + if hasColumnName { + column := path[len(path)-1] + if column == "_" { + if lastJoin.Columns == nil { + lastJoin.Columns = make([]string, 0) + } + } else { + lastJoin.Columns = append(lastJoin.Columns, column) + } + } + + return lastJoin +} + +func (m *structTableModel) setSoftDeleteField() error { + fv := m.table.SoftDeleteField.Value(m.strct) + return m.table.SetSoftDeleteField(fv) +} + +func splitColumn(s string) (string, string) { + ind := strings.Index(s, "__") + if ind == -1 { + return "", s + } + return s[:ind], s[ind+2:] +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/msgpack.go b/vendor/github.com/go-pg/pg/v10/orm/msgpack.go new file mode 100644 index 000000000..56c88a23e --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/msgpack.go @@ -0,0 +1,52 @@ +package orm + +import ( + "reflect" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/go-pg/pg/v10/types" +) + +func msgpackAppender(_ reflect.Type) types.AppenderFunc { + return func(b []byte, v reflect.Value, flags int) []byte { + hexEnc := types.NewHexEncoder(b, flags) + + enc := msgpack.GetEncoder() + defer msgpack.PutEncoder(enc) + + enc.Reset(hexEnc) + if err := enc.EncodeValue(v); err != nil { + return types.AppendError(b, err) + } + + if err := hexEnc.Close(); err != nil { + return types.AppendError(b, err) + } + + return hexEnc.Bytes() + } +} + +func msgpackScanner(_ reflect.Type) types.ScannerFunc { + return func(v reflect.Value, rd types.Reader, n int) error { + if n <= 0 { + return nil + } + + hexDec, err := types.NewHexDecoder(rd, n) + if err != nil { + return err + } + + dec := msgpack.GetDecoder() + defer msgpack.PutDecoder(dec) + + dec.Reset(hexDec) + if err := dec.DecodeValue(v); err != nil { + return err + } + + return nil + } +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/orm.go b/vendor/github.com/go-pg/pg/v10/orm/orm.go new file mode 100644 index 000000000..d18993d2d --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/orm.go @@ -0,0 +1,58 @@ +/* +The API in this package is not stable and may change without any notice. +*/ +package orm + +import ( + "context" + "io" + + "github.com/go-pg/pg/v10/types" +) + +// ColumnScanner is used to scan column values. +type ColumnScanner interface { + // Scan assigns a column value from a row. + // + // An error should be returned if the value can not be stored + // without loss of information. + ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error +} + +type QueryAppender interface { + AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) +} + +type TemplateAppender interface { + AppendTemplate(b []byte) ([]byte, error) +} + +type QueryCommand interface { + QueryAppender + TemplateAppender + String() string + Operation() QueryOp + Clone() QueryCommand + Query() *Query +} + +// DB is a common interface for pg.DB and pg.Tx types. +type DB interface { + Model(model ...interface{}) *Query + ModelContext(c context.Context, model ...interface{}) *Query + + Exec(query interface{}, params ...interface{}) (Result, error) + ExecContext(c context.Context, query interface{}, params ...interface{}) (Result, error) + ExecOne(query interface{}, params ...interface{}) (Result, error) + ExecOneContext(c context.Context, query interface{}, params ...interface{}) (Result, error) + Query(model, query interface{}, params ...interface{}) (Result, error) + QueryContext(c context.Context, model, query interface{}, params ...interface{}) (Result, error) + QueryOne(model, query interface{}, params ...interface{}) (Result, error) + QueryOneContext(c context.Context, model, query interface{}, params ...interface{}) (Result, error) + + CopyFrom(r io.Reader, query interface{}, params ...interface{}) (Result, error) + CopyTo(w io.Writer, query interface{}, params ...interface{}) (Result, error) + + Context() context.Context + Formatter() QueryFormatter +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/query.go b/vendor/github.com/go-pg/pg/v10/orm/query.go new file mode 100644 index 000000000..8a9231f65 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/query.go @@ -0,0 +1,1680 @@ +package orm + +import ( + "context" + "errors" + "fmt" + "io" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/types" +) + +type QueryOp string + +const ( + SelectOp QueryOp = "SELECT" + InsertOp QueryOp = "INSERT" + UpdateOp QueryOp = "UPDATE" + DeleteOp QueryOp = "DELETE" + CreateTableOp QueryOp = "CREATE TABLE" + DropTableOp QueryOp = "DROP TABLE" + CreateCompositeOp QueryOp = "CREATE COMPOSITE" + DropCompositeOp QueryOp = "DROP COMPOSITE" +) + +type queryFlag uint8 + +const ( + implicitModelFlag queryFlag = 1 << iota + deletedFlag + allWithDeletedFlag +) + +type withQuery struct { + name string + query QueryAppender +} + +type columnValue struct { + column string + value *SafeQueryAppender +} + +type union struct { + expr string + query *Query +} + +type Query struct { + ctx context.Context + db DB + stickyErr error + + model Model + tableModel TableModel + flags queryFlag + + with []withQuery + tables []QueryAppender + distinctOn []*SafeQueryAppender + columns []QueryAppender + set []QueryAppender + modelValues map[string]*SafeQueryAppender + extraValues []*columnValue + where []queryWithSepAppender + updWhere []queryWithSepAppender + group []QueryAppender + having []*SafeQueryAppender + union []*union + joins []QueryAppender + joinAppendOn func(app *condAppender) + order []QueryAppender + limit int + offset int + selFor *SafeQueryAppender + + onConflict *SafeQueryAppender + returning []*SafeQueryAppender +} + +func NewQuery(db DB, model ...interface{}) *Query { + ctx := context.Background() + if db != nil { + ctx = db.Context() + } + q := &Query{ctx: ctx} + return q.DB(db).Model(model...) +} + +func NewQueryContext(ctx context.Context, db DB, model ...interface{}) *Query { + return NewQuery(db, model...).Context(ctx) +} + +// New returns new zero Query bound to the current db. +func (q *Query) New() *Query { + clone := &Query{ + ctx: q.ctx, + db: q.db, + + model: q.model, + tableModel: cloneTableModelJoins(q.tableModel), + flags: q.flags, + } + return clone.withFlag(implicitModelFlag) +} + +// Clone clones the Query. +func (q *Query) Clone() *Query { + var modelValues map[string]*SafeQueryAppender + if len(q.modelValues) > 0 { + modelValues = make(map[string]*SafeQueryAppender, len(q.modelValues)) + for k, v := range q.modelValues { + modelValues[k] = v + } + } + + clone := &Query{ + ctx: q.ctx, + db: q.db, + stickyErr: q.stickyErr, + + model: q.model, + tableModel: cloneTableModelJoins(q.tableModel), + flags: q.flags, + + with: q.with[:len(q.with):len(q.with)], + tables: q.tables[:len(q.tables):len(q.tables)], + distinctOn: q.distinctOn[:len(q.distinctOn):len(q.distinctOn)], + columns: q.columns[:len(q.columns):len(q.columns)], + set: q.set[:len(q.set):len(q.set)], + modelValues: modelValues, + extraValues: q.extraValues[:len(q.extraValues):len(q.extraValues)], + where: q.where[:len(q.where):len(q.where)], + updWhere: q.updWhere[:len(q.updWhere):len(q.updWhere)], + joins: q.joins[:len(q.joins):len(q.joins)], + group: q.group[:len(q.group):len(q.group)], + having: q.having[:len(q.having):len(q.having)], + union: q.union[:len(q.union):len(q.union)], + order: q.order[:len(q.order):len(q.order)], + limit: q.limit, + offset: q.offset, + selFor: q.selFor, + + onConflict: q.onConflict, + returning: q.returning[:len(q.returning):len(q.returning)], + } + + return clone +} + +func cloneTableModelJoins(tm TableModel) TableModel { + switch tm := tm.(type) { + case *structTableModel: + if len(tm.joins) == 0 { + return tm + } + clone := *tm + clone.joins = clone.joins[:len(clone.joins):len(clone.joins)] + return &clone + case *sliceTableModel: + if len(tm.joins) == 0 { + return tm + } + clone := *tm + clone.joins = clone.joins[:len(clone.joins):len(clone.joins)] + return &clone + } + return tm +} + +func (q *Query) err(err error) *Query { + if q.stickyErr == nil { + q.stickyErr = err + } + return q +} + +func (q *Query) hasFlag(flag queryFlag) bool { + return hasFlag(q.flags, flag) +} + +func hasFlag(flags, flag queryFlag) bool { + return flags&flag != 0 +} + +func (q *Query) withFlag(flag queryFlag) *Query { + q.flags |= flag + return q +} + +func (q *Query) withoutFlag(flag queryFlag) *Query { + q.flags &= ^flag + return q +} + +func (q *Query) Context(c context.Context) *Query { + q.ctx = c + return q +} + +func (q *Query) DB(db DB) *Query { + q.db = db + return q +} + +func (q *Query) Model(model ...interface{}) *Query { + var err error + switch l := len(model); { + case l == 0: + q.model = nil + case l == 1: + q.model, err = NewModel(model[0]) + case l > 1: + q.model, err = NewModel(&model) + default: + panic("not reached") + } + if err != nil { + q = q.err(err) + } + + q.tableModel, _ = q.model.(TableModel) + + return q.withoutFlag(implicitModelFlag) +} + +func (q *Query) TableModel() TableModel { + return q.tableModel +} + +func (q *Query) isSoftDelete() bool { + if q.tableModel != nil { + return q.tableModel.Table().SoftDeleteField != nil && !q.hasFlag(allWithDeletedFlag) + } + return false +} + +// Deleted adds `WHERE deleted_at IS NOT NULL` clause for soft deleted models. +func (q *Query) Deleted() *Query { + if q.tableModel != nil { + if err := q.tableModel.Table().mustSoftDelete(); err != nil { + return q.err(err) + } + } + return q.withFlag(deletedFlag).withoutFlag(allWithDeletedFlag) +} + +// AllWithDeleted changes query to return all rows including soft deleted ones. +func (q *Query) AllWithDeleted() *Query { + if q.tableModel != nil { + if err := q.tableModel.Table().mustSoftDelete(); err != nil { + return q.err(err) + } + } + return q.withFlag(allWithDeletedFlag).withoutFlag(deletedFlag) +} + +// With adds subq as common table expression with the given name. +func (q *Query) With(name string, subq *Query) *Query { + return q._with(name, NewSelectQuery(subq)) +} + +func (q *Query) WithInsert(name string, subq *Query) *Query { + return q._with(name, NewInsertQuery(subq)) +} + +func (q *Query) WithUpdate(name string, subq *Query) *Query { + return q._with(name, NewUpdateQuery(subq, false)) +} + +func (q *Query) WithDelete(name string, subq *Query) *Query { + return q._with(name, NewDeleteQuery(subq)) +} + +func (q *Query) _with(name string, subq QueryAppender) *Query { + q.with = append(q.with, withQuery{ + name: name, + query: subq, + }) + return q +} + +// WrapWith creates new Query and adds to it current query as +// common table expression with the given name. +func (q *Query) WrapWith(name string) *Query { + wrapper := q.New() + wrapper.with = q.with + q.with = nil + wrapper = wrapper.With(name, q) + return wrapper +} + +func (q *Query) Table(tables ...string) *Query { + for _, table := range tables { + q.tables = append(q.tables, fieldAppender{table}) + } + return q +} + +func (q *Query) TableExpr(expr string, params ...interface{}) *Query { + q.tables = append(q.tables, SafeQuery(expr, params...)) + return q +} + +func (q *Query) Distinct() *Query { + q.distinctOn = make([]*SafeQueryAppender, 0) + return q +} + +func (q *Query) DistinctOn(expr string, params ...interface{}) *Query { + q.distinctOn = append(q.distinctOn, SafeQuery(expr, params...)) + return q +} + +// Column adds a column to the Query quoting it according to PostgreSQL rules. +// Does not expand params like ?TableAlias etc. +// ColumnExpr can be used to bypass quoting restriction or for params expansion. +// Column name can be: +// - column_name, +// - table_alias.column_name, +// - table_alias.*. +func (q *Query) Column(columns ...string) *Query { + for _, column := range columns { + if column == "_" { + if q.columns == nil { + q.columns = make([]QueryAppender, 0) + } + continue + } + + q.columns = append(q.columns, fieldAppender{column}) + } + return q +} + +// ColumnExpr adds column expression to the Query. +func (q *Query) ColumnExpr(expr string, params ...interface{}) *Query { + q.columns = append(q.columns, SafeQuery(expr, params...)) + return q +} + +// ExcludeColumn excludes a column from the list of to be selected columns. +func (q *Query) ExcludeColumn(columns ...string) *Query { + if q.columns == nil { + for _, f := range q.tableModel.Table().Fields { + q.columns = append(q.columns, fieldAppender{f.SQLName}) + } + } + + for _, col := range columns { + if !q.excludeColumn(col) { + return q.err(fmt.Errorf("pg: can't find column=%q", col)) + } + } + return q +} + +func (q *Query) excludeColumn(column string) bool { + for i := 0; i < len(q.columns); i++ { + app, ok := q.columns[i].(fieldAppender) + if ok && app.field == column { + q.columns = append(q.columns[:i], q.columns[i+1:]...) + return true + } + } + return false +} + +func (q *Query) getFields() ([]*Field, error) { + return q._getFields(false) +} + +func (q *Query) getDataFields() ([]*Field, error) { + return q._getFields(true) +} + +func (q *Query) _getFields(omitPKs bool) ([]*Field, error) { + table := q.tableModel.Table() + columns := make([]*Field, 0, len(q.columns)) + for _, col := range q.columns { + f, ok := col.(fieldAppender) + if !ok { + continue + } + + field, err := table.GetField(f.field) + if err != nil { + return nil, err + } + + if omitPKs && field.hasFlag(PrimaryKeyFlag) { + continue + } + + columns = append(columns, field) + } + return columns, nil +} + +// Relation adds a relation to the query. Relation name can be: +// - RelationName to select all columns, +// - RelationName.column_name, +// - RelationName._ to join relation without selecting relation columns. +func (q *Query) Relation(name string, apply ...func(*Query) (*Query, error)) *Query { + var fn func(*Query) (*Query, error) + if len(apply) == 1 { + fn = apply[0] + } else if len(apply) > 1 { + panic("only one apply function is supported") + } + + join := q.tableModel.Join(name, fn) + if join == nil { + return q.err(fmt.Errorf("%s does not have relation=%q", + q.tableModel.Table(), name)) + } + + if fn == nil { + return q + } + + switch join.Rel.Type { + case HasOneRelation, BelongsToRelation: + q.joinAppendOn = join.AppendOn + return q.Apply(fn) + default: + q.joinAppendOn = nil + return q + } +} + +func (q *Query) Set(set string, params ...interface{}) *Query { + q.set = append(q.set, SafeQuery(set, params...)) + return q +} + +// Value overwrites model value for the column in INSERT and UPDATE queries. +func (q *Query) Value(column string, value string, params ...interface{}) *Query { + if !q.hasTableModel() { + q.err(errModelNil) + return q + } + + table := q.tableModel.Table() + if _, ok := table.FieldsMap[column]; ok { + if q.modelValues == nil { + q.modelValues = make(map[string]*SafeQueryAppender) + } + q.modelValues[column] = SafeQuery(value, params...) + } else { + q.extraValues = append(q.extraValues, &columnValue{ + column: column, + value: SafeQuery(value, params...), + }) + } + + return q +} + +func (q *Query) Where(condition string, params ...interface{}) *Query { + q.addWhere(&condAppender{ + sep: " AND ", + cond: condition, + params: params, + }) + return q +} + +func (q *Query) WhereOr(condition string, params ...interface{}) *Query { + q.addWhere(&condAppender{ + sep: " OR ", + cond: condition, + params: params, + }) + return q +} + +// WhereGroup encloses conditions added in the function in parentheses. +// +// q.Where("TRUE"). +// WhereGroup(func(q *pg.Query) (*pg.Query, error) { +// q = q.WhereOr("FALSE").WhereOr("TRUE"). +// return q, nil +// }) +// +// generates +// +// WHERE TRUE AND (FALSE OR TRUE) +func (q *Query) WhereGroup(fn func(*Query) (*Query, error)) *Query { + return q.whereGroup(" AND ", fn) +} + +// WhereGroup encloses conditions added in the function in parentheses. +// +// q.Where("TRUE"). +// WhereNotGroup(func(q *pg.Query) (*pg.Query, error) { +// q = q.WhereOr("FALSE").WhereOr("TRUE"). +// return q, nil +// }) +// +// generates +// +// WHERE TRUE AND NOT (FALSE OR TRUE) +func (q *Query) WhereNotGroup(fn func(*Query) (*Query, error)) *Query { + return q.whereGroup(" AND NOT ", fn) +} + +// WhereOrGroup encloses conditions added in the function in parentheses. +// +// q.Where("TRUE"). +// WhereOrGroup(func(q *pg.Query) (*pg.Query, error) { +// q = q.Where("FALSE").Where("TRUE"). +// return q, nil +// }) +// +// generates +// +// WHERE TRUE OR (FALSE AND TRUE) +func (q *Query) WhereOrGroup(fn func(*Query) (*Query, error)) *Query { + return q.whereGroup(" OR ", fn) +} + +// WhereOrGroup encloses conditions added in the function in parentheses. +// +// q.Where("TRUE"). +// WhereOrGroup(func(q *pg.Query) (*pg.Query, error) { +// q = q.Where("FALSE").Where("TRUE"). +// return q, nil +// }) +// +// generates +// +// WHERE TRUE OR NOT (FALSE AND TRUE) +func (q *Query) WhereOrNotGroup(fn func(*Query) (*Query, error)) *Query { + return q.whereGroup(" OR NOT ", fn) +} + +func (q *Query) whereGroup(conj string, fn func(*Query) (*Query, error)) *Query { + saved := q.where + q.where = nil + + newq, err := fn(q) + if err != nil { + q.err(err) + return q + } + + if len(newq.where) == 0 { + newq.where = saved + return newq + } + + f := &condGroupAppender{ + sep: conj, + cond: newq.where, + } + newq.where = saved + newq.addWhere(f) + + return newq +} + +// WhereIn is a shortcut for Where and pg.In. +func (q *Query) WhereIn(where string, slice interface{}) *Query { + return q.Where(where, types.In(slice)) +} + +// WhereInMulti is a shortcut for Where and pg.InMulti. +func (q *Query) WhereInMulti(where string, values ...interface{}) *Query { + return q.Where(where, types.InMulti(values...)) +} + +func (q *Query) addWhere(f queryWithSepAppender) { + if q.onConflictDoUpdate() { + q.updWhere = append(q.updWhere, f) + } else { + q.where = append(q.where, f) + } +} + +// WherePK adds condition based on the model primary keys. +// Usually it is the same as: +// +// Where("id = ?id") +func (q *Query) WherePK() *Query { + if !q.hasTableModel() { + q.err(errModelNil) + return q + } + + if err := q.tableModel.Table().checkPKs(); err != nil { + q.err(err) + return q + } + + switch q.tableModel.Kind() { + case reflect.Struct: + q.where = append(q.where, wherePKStructQuery{q}) + return q + case reflect.Slice: + q.joins = append(q.joins, joinPKSliceQuery{q: q}) + q.where = append(q.where, wherePKSliceQuery{q: q}) + q = q.OrderExpr(`"_data"."ordering" ASC`) + return q + } + + panic("not reached") +} + +func (q *Query) Join(join string, params ...interface{}) *Query { + j := &joinQuery{ + join: SafeQuery(join, params...), + } + q.joins = append(q.joins, j) + q.joinAppendOn = j.AppendOn + return q +} + +// JoinOn appends join condition to the last join. +func (q *Query) JoinOn(condition string, params ...interface{}) *Query { + if q.joinAppendOn == nil { + q.err(errors.New("pg: no joins to apply JoinOn")) + return q + } + q.joinAppendOn(&condAppender{ + sep: " AND ", + cond: condition, + params: params, + }) + return q +} + +func (q *Query) JoinOnOr(condition string, params ...interface{}) *Query { + if q.joinAppendOn == nil { + q.err(errors.New("pg: no joins to apply JoinOn")) + return q + } + q.joinAppendOn(&condAppender{ + sep: " OR ", + cond: condition, + params: params, + }) + return q +} + +func (q *Query) Group(columns ...string) *Query { + for _, column := range columns { + q.group = append(q.group, fieldAppender{column}) + } + return q +} + +func (q *Query) GroupExpr(group string, params ...interface{}) *Query { + q.group = append(q.group, SafeQuery(group, params...)) + return q +} + +func (q *Query) Having(having string, params ...interface{}) *Query { + q.having = append(q.having, SafeQuery(having, params...)) + return q +} + +func (q *Query) Union(other *Query) *Query { + return q.addUnion(" UNION ", other) +} + +func (q *Query) UnionAll(other *Query) *Query { + return q.addUnion(" UNION ALL ", other) +} + +func (q *Query) Intersect(other *Query) *Query { + return q.addUnion(" INTERSECT ", other) +} + +func (q *Query) IntersectAll(other *Query) *Query { + return q.addUnion(" INTERSECT ALL ", other) +} + +func (q *Query) Except(other *Query) *Query { + return q.addUnion(" EXCEPT ", other) +} + +func (q *Query) ExceptAll(other *Query) *Query { + return q.addUnion(" EXCEPT ALL ", other) +} + +func (q *Query) addUnion(expr string, other *Query) *Query { + q.union = append(q.union, &union{ + expr: expr, + query: other, + }) + return q +} + +// Order adds sort order to the Query quoting column name. Does not expand params like ?TableAlias etc. +// OrderExpr can be used to bypass quoting restriction or for params expansion. +func (q *Query) Order(orders ...string) *Query { +loop: + for _, order := range orders { + if order == "" { + continue + } + ind := strings.Index(order, " ") + if ind != -1 { + field := order[:ind] + sort := order[ind+1:] + switch internal.UpperString(sort) { + case "ASC", "DESC", "ASC NULLS FIRST", "DESC NULLS FIRST", + "ASC NULLS LAST", "DESC NULLS LAST": + q = q.OrderExpr("? ?", types.Ident(field), types.Safe(sort)) + continue loop + } + } + + q.order = append(q.order, fieldAppender{order}) + } + return q +} + +// Order adds sort order to the Query. +func (q *Query) OrderExpr(order string, params ...interface{}) *Query { + if order != "" { + q.order = append(q.order, SafeQuery(order, params...)) + } + return q +} + +func (q *Query) Limit(n int) *Query { + q.limit = n + return q +} + +func (q *Query) Offset(n int) *Query { + q.offset = n + return q +} + +func (q *Query) OnConflict(s string, params ...interface{}) *Query { + q.onConflict = SafeQuery(s, params...) + return q +} + +func (q *Query) onConflictDoUpdate() bool { + return q.onConflict != nil && + strings.HasSuffix(internal.UpperString(q.onConflict.query), "DO UPDATE") +} + +// Returning adds a RETURNING clause to the query. +// +// `Returning("NULL")` can be used to suppress default returning clause +// generated by go-pg for INSERT queries to get values for null columns. +func (q *Query) Returning(s string, params ...interface{}) *Query { + q.returning = append(q.returning, SafeQuery(s, params...)) + return q +} + +func (q *Query) For(s string, params ...interface{}) *Query { + q.selFor = SafeQuery(s, params...) + return q +} + +// Apply calls the fn passing the Query as an argument. +func (q *Query) Apply(fn func(*Query) (*Query, error)) *Query { + qq, err := fn(q) + if err != nil { + q.err(err) + return q + } + return qq +} + +// Count returns number of rows matching the query using count aggregate function. +func (q *Query) Count() (int, error) { + if q.stickyErr != nil { + return 0, q.stickyErr + } + + var count int + _, err := q.db.QueryOneContext( + q.ctx, Scan(&count), q.countSelectQuery("count(*)"), q.tableModel) + return count, err +} + +func (q *Query) countSelectQuery(column string) *SelectQuery { + return &SelectQuery{ + q: q, + count: column, + } +} + +// First sorts rows by primary key and selects the first row. +// It is a shortcut for: +// +// q.OrderExpr("id ASC").Limit(1) +func (q *Query) First() error { + table := q.tableModel.Table() + + if err := table.checkPKs(); err != nil { + return err + } + + b := appendColumns(nil, table.Alias, table.PKs) + return q.OrderExpr(internal.BytesToString(b)).Limit(1).Select() +} + +// Last sorts rows by primary key and selects the last row. +// It is a shortcut for: +// +// q.OrderExpr("id DESC").Limit(1) +func (q *Query) Last() error { + table := q.tableModel.Table() + + if err := table.checkPKs(); err != nil { + return err + } + + // TODO: fix for multi columns + b := appendColumns(nil, table.Alias, table.PKs) + b = append(b, " DESC"...) + return q.OrderExpr(internal.BytesToString(b)).Limit(1).Select() +} + +// Select selects the model. +func (q *Query) Select(values ...interface{}) error { + if q.stickyErr != nil { + return q.stickyErr + } + + model, err := q.newModel(values) + if err != nil { + return err + } + + res, err := q.query(q.ctx, model, NewSelectQuery(q)) + if err != nil { + return err + } + + if res.RowsReturned() > 0 { + if q.tableModel != nil { + if err := q.selectJoins(q.tableModel.GetJoins()); err != nil { + return err + } + } + } + + if err := model.AfterSelect(q.ctx); err != nil { + return err + } + + return nil +} + +func (q *Query) newModel(values []interface{}) (Model, error) { + if len(values) > 0 { + return newScanModel(values) + } + return q.tableModel, nil +} + +func (q *Query) query(ctx context.Context, model Model, query interface{}) (Result, error) { + if _, ok := model.(useQueryOne); ok { + return q.db.QueryOneContext(ctx, model, query, q.tableModel) + } + return q.db.QueryContext(ctx, model, query, q.tableModel) +} + +// SelectAndCount runs Select and Count in two goroutines, +// waits for them to finish and returns the result. If query limit is -1 +// it does not select any data and only counts the results. +func (q *Query) SelectAndCount(values ...interface{}) (count int, firstErr error) { + if q.stickyErr != nil { + return 0, q.stickyErr + } + + var wg sync.WaitGroup + var mu sync.Mutex + + if q.limit >= 0 { + wg.Add(1) + go func() { + defer wg.Done() + err := q.Select(values...) + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + } + + wg.Add(1) + go func() { + defer wg.Done() + var err error + count, err = q.Count() + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + + wg.Wait() + return count, firstErr +} + +// SelectAndCountEstimate runs Select and CountEstimate in two goroutines, +// waits for them to finish and returns the result. If query limit is -1 +// it does not select any data and only counts the results. +func (q *Query) SelectAndCountEstimate(threshold int, values ...interface{}) (count int, firstErr error) { + if q.stickyErr != nil { + return 0, q.stickyErr + } + + var wg sync.WaitGroup + var mu sync.Mutex + + if q.limit >= 0 { + wg.Add(1) + go func() { + defer wg.Done() + err := q.Select(values...) + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + } + + wg.Add(1) + go func() { + defer wg.Done() + var err error + count, err = q.CountEstimate(threshold) + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + + wg.Wait() + return count, firstErr +} + +// ForEach calls the function for each row returned by the query +// without loading all rows into the memory. +// +// Function can accept a struct, a pointer to a struct, an orm.Model, +// or values for the columns in a row. Function must return an error. +func (q *Query) ForEach(fn interface{}) error { + m := newFuncModel(fn) + return q.Select(m) +} + +func (q *Query) forEachHasOneJoin(fn func(*join) error) error { + if q.tableModel == nil { + return nil + } + return q._forEachHasOneJoin(fn, q.tableModel.GetJoins()) +} + +func (q *Query) _forEachHasOneJoin(fn func(*join) error, joins []join) error { + for i := range joins { + j := &joins[i] + switch j.Rel.Type { + case HasOneRelation, BelongsToRelation: + err := fn(j) + if err != nil { + return err + } + + err = q._forEachHasOneJoin(fn, j.JoinModel.GetJoins()) + if err != nil { + return err + } + } + } + return nil +} + +func (q *Query) selectJoins(joins []join) error { + var err error + for i := range joins { + j := &joins[i] + if j.Rel.Type == HasOneRelation || j.Rel.Type == BelongsToRelation { + err = q.selectJoins(j.JoinModel.GetJoins()) + } else { + err = j.Select(q.db.Formatter(), q.New()) + } + if err != nil { + return err + } + } + return nil +} + +// Insert inserts the model. +func (q *Query) Insert(values ...interface{}) (Result, error) { + if q.stickyErr != nil { + return nil, q.stickyErr + } + + model, err := q.newModel(values) + if err != nil { + return nil, err + } + + ctx := q.ctx + + if q.tableModel != nil && q.tableModel.Table().hasFlag(beforeInsertHookFlag) { + ctx, err = q.tableModel.BeforeInsert(ctx) + if err != nil { + return nil, err + } + } + + query := NewInsertQuery(q) + res, err := q.returningQuery(ctx, model, query) + if err != nil { + return nil, err + } + + if q.tableModel != nil { + if err := q.tableModel.AfterInsert(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +// SelectOrInsert selects the model inserting one if it does not exist. +// It returns true when model was inserted. +func (q *Query) SelectOrInsert(values ...interface{}) (inserted bool, _ error) { + if q.stickyErr != nil { + return false, q.stickyErr + } + + var insertq *Query + var insertErr error + for i := 0; i < 5; i++ { + if i >= 2 { + dur := internal.RetryBackoff(i-2, 250*time.Millisecond, 5*time.Second) + if err := internal.Sleep(q.ctx, dur); err != nil { + return false, err + } + } + + err := q.Select(values...) + if err == nil { + return false, nil + } + if err != internal.ErrNoRows { + return false, err + } + + if insertq == nil { + insertq = q + if len(insertq.columns) > 0 { + insertq = insertq.Clone() + insertq.columns = nil + } + } + + res, err := insertq.Insert(values...) + if err != nil { + insertErr = err + if err == internal.ErrNoRows { + continue + } + if pgErr, ok := err.(internal.PGError); ok { + if pgErr.IntegrityViolation() { + continue + } + if pgErr.Field('C') == "55000" { + // Retry on "#55000 attempted to delete invisible tuple". + continue + } + } + return false, err + } + if res.RowsAffected() == 1 { + return true, nil + } + } + + err := fmt.Errorf( + "pg: SelectOrInsert: select returns no rows (insert fails with err=%q)", + insertErr) + return false, err +} + +// Update updates the model. +func (q *Query) Update(scan ...interface{}) (Result, error) { + return q.update(scan, false) +} + +// Update updates the model omitting fields with zero values such as: +// - empty string, +// - 0, +// - zero time, +// - empty map or slice, +// - byte array with all zeroes, +// - nil ptr, +// - types with method `IsZero() == true`. +func (q *Query) UpdateNotZero(scan ...interface{}) (Result, error) { + return q.update(scan, true) +} + +func (q *Query) update(values []interface{}, omitZero bool) (Result, error) { + if q.stickyErr != nil { + return nil, q.stickyErr + } + + model, err := q.newModel(values) + if err != nil { + return nil, err + } + + c := q.ctx + + if q.tableModel != nil { + c, err = q.tableModel.BeforeUpdate(c) + if err != nil { + return nil, err + } + } + + query := NewUpdateQuery(q, omitZero) + res, err := q.returningQuery(c, model, query) + if err != nil { + return nil, err + } + + if q.tableModel != nil { + err = q.tableModel.AfterUpdate(c) + if err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *Query) returningQuery(c context.Context, model Model, query interface{}) (Result, error) { + if !q.hasReturning() { + return q.db.QueryContext(c, model, query, q.tableModel) + } + if _, ok := model.(useQueryOne); ok { + return q.db.QueryOneContext(c, model, query, q.tableModel) + } + return q.db.QueryContext(c, model, query, q.tableModel) +} + +// Delete deletes the model. When model has deleted_at column the row +// is soft deleted instead. +func (q *Query) Delete(values ...interface{}) (Result, error) { + if q.tableModel == nil { + return q.ForceDelete(values...) + } + + table := q.tableModel.Table() + if table.SoftDeleteField == nil { + return q.ForceDelete(values...) + } + + clone := q.Clone() + if q.tableModel.IsNil() { + if table.SoftDeleteField.SQLType == pgTypeBigint { + clone = clone.Set("? = ?", table.SoftDeleteField.Column, time.Now().UnixNano()) + } else { + clone = clone.Set("? = ?", table.SoftDeleteField.Column, time.Now()) + } + } else { + if err := clone.tableModel.setSoftDeleteField(); err != nil { + return nil, err + } + clone = clone.Column(table.SoftDeleteField.SQLName) + } + return clone.Update(values...) +} + +// Delete forces delete of the model with deleted_at column. +func (q *Query) ForceDelete(values ...interface{}) (Result, error) { + if q.stickyErr != nil { + return nil, q.stickyErr + } + if q.tableModel == nil { + return nil, errModelNil + } + q = q.withFlag(deletedFlag) + + model, err := q.newModel(values) + if err != nil { + return nil, err + } + + ctx := q.ctx + + if q.tableModel != nil { + ctx, err = q.tableModel.BeforeDelete(ctx) + if err != nil { + return nil, err + } + } + + res, err := q.returningQuery(ctx, model, NewDeleteQuery(q)) + if err != nil { + return nil, err + } + + if q.tableModel != nil { + if err := q.tableModel.AfterDelete(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *Query) CreateTable(opt *CreateTableOptions) error { + _, err := q.db.ExecContext(q.ctx, NewCreateTableQuery(q, opt)) + return err +} + +func (q *Query) DropTable(opt *DropTableOptions) error { + _, err := q.db.ExecContext(q.ctx, NewDropTableQuery(q, opt)) + return err +} + +func (q *Query) CreateComposite(opt *CreateCompositeOptions) error { + _, err := q.db.ExecContext(q.ctx, NewCreateCompositeQuery(q, opt)) + return err +} + +func (q *Query) DropComposite(opt *DropCompositeOptions) error { + _, err := q.db.ExecContext(q.ctx, NewDropCompositeQuery(q, opt)) + return err +} + +// Exec is an alias for DB.Exec. +func (q *Query) Exec(query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.ExecContext(q.ctx, query, params...) +} + +// ExecOne is an alias for DB.ExecOne. +func (q *Query) ExecOne(query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.ExecOneContext(q.ctx, query, params...) +} + +// Query is an alias for DB.Query. +func (q *Query) Query(model, query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.QueryContext(q.ctx, model, query, params...) +} + +// QueryOne is an alias for DB.QueryOne. +func (q *Query) QueryOne(model, query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.QueryOneContext(q.ctx, model, query, params...) +} + +// CopyFrom is an alias from DB.CopyFrom. +func (q *Query) CopyFrom(r io.Reader, query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.CopyFrom(r, query, params...) +} + +// CopyTo is an alias from DB.CopyTo. +func (q *Query) CopyTo(w io.Writer, query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.CopyTo(w, query, params...) +} + +var _ QueryAppender = (*Query)(nil) + +func (q *Query) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + return NewSelectQuery(q).AppendQuery(fmter, b) +} + +// Exists returns true or false depending if there are any rows matching the query. +func (q *Query) Exists() (bool, error) { + q = q.Clone() // copy to not change original query + q.columns = []QueryAppender{SafeQuery("1")} + q.order = nil + q.limit = 1 + res, err := q.db.ExecContext(q.ctx, NewSelectQuery(q)) + if err != nil { + return false, err + } + return res.RowsAffected() > 0, nil +} + +func (q *Query) hasTableModel() bool { + return q.tableModel != nil && !q.tableModel.IsNil() +} + +func (q *Query) hasExplicitTableModel() bool { + return q.tableModel != nil && !q.hasFlag(implicitModelFlag) +} + +func (q *Query) modelHasTableName() bool { + return q.hasExplicitTableModel() && q.tableModel.Table().SQLName != "" +} + +func (q *Query) modelHasTableAlias() bool { + return q.hasExplicitTableModel() && q.tableModel.Table().Alias != "" +} + +func (q *Query) hasTables() bool { + return q.modelHasTableName() || len(q.tables) > 0 +} + +func (q *Query) appendFirstTable(fmter QueryFormatter, b []byte) ([]byte, error) { + if q.modelHasTableName() { + return fmter.FormatQuery(b, string(q.tableModel.Table().SQLName)), nil + } + if len(q.tables) > 0 { + return q.tables[0].AppendQuery(fmter, b) + } + return b, nil +} + +func (q *Query) appendFirstTableWithAlias(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if q.modelHasTableName() { + table := q.tableModel.Table() + b = fmter.FormatQuery(b, string(table.SQLName)) + if table.Alias != table.SQLName { + b = append(b, " AS "...) + b = append(b, table.Alias...) + } + return b, nil + } + + if len(q.tables) > 0 { + b, err = q.tables[0].AppendQuery(fmter, b) + if err != nil { + return nil, err + } + if q.modelHasTableAlias() { + table := q.tableModel.Table() + if table.Alias != table.SQLName { + b = append(b, " AS "...) + b = append(b, table.Alias...) + } + } + } + + return b, nil +} + +func (q *Query) hasMultiTables() bool { + if q.modelHasTableName() { + return len(q.tables) >= 1 + } + return len(q.tables) >= 2 +} + +func (q *Query) appendOtherTables(fmter QueryFormatter, b []byte) (_ []byte, err error) { + tables := q.tables + if !q.modelHasTableName() { + tables = tables[1:] + } + for i, f := range tables { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *Query) appendColumns(fmter QueryFormatter, b []byte) (_ []byte, err error) { + for i, f := range q.columns { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *Query) mustAppendWhere(fmter QueryFormatter, b []byte) ([]byte, error) { + if len(q.where) == 0 { + err := errors.New( + "pg: Update and Delete queries require Where clause (try WherePK)") + return nil, err + } + return q.appendWhere(fmter, b) +} + +func (q *Query) appendWhere(fmter QueryFormatter, b []byte) (_ []byte, err error) { + isSoftDelete := q.isSoftDelete() + + if len(q.where) > 0 { + if isSoftDelete { + b = append(b, '(') + } + + b, err = q._appendWhere(fmter, b, q.where) + if err != nil { + return nil, err + } + + if isSoftDelete { + b = append(b, ')') + } + } + + if isSoftDelete { + if len(q.where) > 0 { + b = append(b, " AND "...) + } + b = append(b, q.tableModel.Table().Alias...) + b = q.appendSoftDelete(b) + } + + return b, nil +} + +func (q *Query) appendSoftDelete(b []byte) []byte { + b = append(b, '.') + b = append(b, q.tableModel.Table().SoftDeleteField.Column...) + if q.hasFlag(deletedFlag) { + b = append(b, " IS NOT NULL"...) + } else { + b = append(b, " IS NULL"...) + } + return b +} + +func (q *Query) appendUpdWhere(fmter QueryFormatter, b []byte) ([]byte, error) { + return q._appendWhere(fmter, b, q.updWhere) +} + +func (q *Query) _appendWhere( + fmter QueryFormatter, b []byte, where []queryWithSepAppender, +) (_ []byte, err error) { + for i, f := range where { + start := len(b) + + if i > 0 { + b = f.AppendSep(b) + } + + before := len(b) + + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + if len(b) == before { + b = b[:start] + } + } + return b, nil +} + +func (q *Query) appendSet(fmter QueryFormatter, b []byte) (_ []byte, err error) { + b = append(b, " SET "...) + for i, f := range q.set { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *Query) hasReturning() bool { + if len(q.returning) == 0 { + return false + } + if len(q.returning) == 1 { + switch q.returning[0].query { + case "null", "NULL": + return false + } + } + return true +} + +func (q *Query) appendReturning(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if !q.hasReturning() { + return b, nil + } + + b = append(b, " RETURNING "...) + for i, f := range q.returning { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *Query) appendWith(fmter QueryFormatter, b []byte) (_ []byte, err error) { + b = append(b, "WITH "...) + for i, with := range q.with { + if i > 0 { + b = append(b, ", "...) + } + b = types.AppendIdent(b, with.name, 1) + b = append(b, " AS ("...) + + b, err = with.query.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, ')') + } + b = append(b, ' ') + return b, nil +} + +func (q *Query) isSliceModelWithData() bool { + if !q.hasTableModel() { + return false + } + m, ok := q.tableModel.(*sliceTableModel) + return ok && m.sliceLen > 0 +} + +//------------------------------------------------------------------------------ + +type wherePKStructQuery struct { + q *Query +} + +var _ queryWithSepAppender = (*wherePKStructQuery)(nil) + +func (wherePKStructQuery) AppendSep(b []byte) []byte { + return append(b, " AND "...) +} + +func (q wherePKStructQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + table := q.q.tableModel.Table() + value := q.q.tableModel.Value() + return appendColumnAndValue(fmter, b, value, table.Alias, table.PKs), nil +} + +func appendColumnAndValue( + fmter QueryFormatter, b []byte, v reflect.Value, alias types.Safe, fields []*Field, +) []byte { + isPlaceholder := isTemplateFormatter(fmter) + for i, f := range fields { + if i > 0 { + b = append(b, " AND "...) + } + b = append(b, alias...) + b = append(b, '.') + b = append(b, f.Column...) + b = append(b, " = "...) + if isPlaceholder { + b = append(b, '?') + } else { + b = f.AppendValue(b, v, 1) + } + } + return b +} + +//------------------------------------------------------------------------------ + +type wherePKSliceQuery struct { + q *Query +} + +var _ queryWithSepAppender = (*wherePKSliceQuery)(nil) + +func (wherePKSliceQuery) AppendSep(b []byte) []byte { + return append(b, " AND "...) +} + +func (q wherePKSliceQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + table := q.q.tableModel.Table() + + for i, f := range table.PKs { + if i > 0 { + b = append(b, " AND "...) + } + b = append(b, table.Alias...) + b = append(b, '.') + b = append(b, f.Column...) + b = append(b, " = "...) + b = append(b, `"_data".`...) + b = append(b, f.Column...) + } + + return b, nil +} + +type joinPKSliceQuery struct { + q *Query +} + +var _ QueryAppender = (*joinPKSliceQuery)(nil) + +func (q joinPKSliceQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + table := q.q.tableModel.Table() + slice := q.q.tableModel.Value() + + b = append(b, " JOIN (VALUES "...) + + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, ", "...) + } + + el := indirect(slice.Index(i)) + + b = append(b, '(') + for i, f := range table.PKs { + if i > 0 { + b = append(b, ", "...) + } + + b = f.AppendValue(b, el, 1) + + if f.UserSQLType != "" { + b = append(b, "::"...) + b = append(b, f.SQLType...) + } + } + + b = append(b, ", "...) + b = strconv.AppendInt(b, int64(i), 10) + + b = append(b, ')') + } + + b = append(b, `) AS "_data" (`...) + + for i, f := range table.PKs { + if i > 0 { + b = append(b, ", "...) + } + b = append(b, f.Column...) + } + + b = append(b, ", "...) + b = append(b, `"ordering"`...) + b = append(b, ") ON TRUE"...) + + return b, nil +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/relation.go b/vendor/github.com/go-pg/pg/v10/orm/relation.go new file mode 100644 index 000000000..28d915bcd --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/relation.go @@ -0,0 +1,33 @@ +package orm + +import ( + "fmt" + + "github.com/go-pg/pg/v10/types" +) + +const ( + InvalidRelation = iota + HasOneRelation + BelongsToRelation + HasManyRelation + Many2ManyRelation +) + +type Relation struct { + Type int + Field *Field + JoinTable *Table + BaseFKs []*Field + JoinFKs []*Field + Polymorphic *Field + + M2MTableName types.Safe + M2MTableAlias types.Safe + M2MBaseFKs []string + M2MJoinFKs []string +} + +func (r *Relation) String() string { + return fmt.Sprintf("relation=%s", r.Field.GoName) +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/result.go b/vendor/github.com/go-pg/pg/v10/orm/result.go new file mode 100644 index 000000000..9d82815ef --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/result.go @@ -0,0 +1,14 @@ +package orm + +// Result summarizes an executed SQL command. +type Result interface { + Model() Model + + // RowsAffected returns the number of rows affected by SELECT, INSERT, UPDATE, + // or DELETE queries. It returns -1 if query can't possibly affect any rows, + // e.g. in case of CREATE or SHOW queries. + RowsAffected() int + + // RowsReturned returns the number of rows returned by the query. + RowsReturned() int +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/select.go b/vendor/github.com/go-pg/pg/v10/orm/select.go new file mode 100644 index 000000000..d3b38742d --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/select.go @@ -0,0 +1,346 @@ +package orm + +import ( + "bytes" + "fmt" + "strconv" + "strings" + + "github.com/go-pg/pg/v10/types" +) + +type SelectQuery struct { + q *Query + count string +} + +var ( + _ QueryAppender = (*SelectQuery)(nil) + _ QueryCommand = (*SelectQuery)(nil) +) + +func NewSelectQuery(q *Query) *SelectQuery { + return &SelectQuery{ + q: q, + } +} + +func (q *SelectQuery) String() string { + b, err := q.AppendQuery(defaultFmter, nil) + if err != nil { + panic(err) + } + return string(b) +} + +func (q *SelectQuery) Operation() QueryOp { + return SelectOp +} + +func (q *SelectQuery) Clone() QueryCommand { + return &SelectQuery{ + q: q.q.Clone(), + count: q.count, + } +} + +func (q *SelectQuery) Query() *Query { + return q.q +} + +func (q *SelectQuery) AppendTemplate(b []byte) ([]byte, error) { + return q.AppendQuery(dummyFormatter{}, b) +} + +func (q *SelectQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { //nolint:gocyclo + if q.q.stickyErr != nil { + return nil, q.q.stickyErr + } + + cteCount := q.count != "" && (len(q.q.group) > 0 || q.isDistinct()) + if cteCount { + b = append(b, `WITH "_count_wrapper" AS (`...) + } + + if len(q.q.with) > 0 { + b, err = q.q.appendWith(fmter, b) + if err != nil { + return nil, err + } + } + + if len(q.q.union) > 0 { + b = append(b, '(') + } + + b = append(b, "SELECT "...) + + if len(q.q.distinctOn) > 0 { + b = append(b, "DISTINCT ON ("...) + for i, app := range q.q.distinctOn { + if i > 0 { + b = append(b, ", "...) + } + b, err = app.AppendQuery(fmter, b) + } + b = append(b, ") "...) + } else if q.q.distinctOn != nil { + b = append(b, "DISTINCT "...) + } + + if q.count != "" && !cteCount { + b = append(b, q.count...) + } else { + b, err = q.appendColumns(fmter, b) + if err != nil { + return nil, err + } + } + + if q.q.hasTables() { + b = append(b, " FROM "...) + b, err = q.appendTables(fmter, b) + if err != nil { + return nil, err + } + } + + err = q.q.forEachHasOneJoin(func(j *join) error { + b = append(b, ' ') + b, err = j.appendHasOneJoin(fmter, b, q.q) + return err + }) + if err != nil { + return nil, err + } + + for _, j := range q.q.joins { + b, err = j.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + if len(q.q.where) > 0 || q.q.isSoftDelete() { + b = append(b, " WHERE "...) + b, err = q.q.appendWhere(fmter, b) + if err != nil { + return nil, err + } + } + + if len(q.q.group) > 0 { + b = append(b, " GROUP BY "...) + for i, f := range q.q.group { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + } + + if len(q.q.having) > 0 { + b = append(b, " HAVING "...) + for i, f := range q.q.having { + if i > 0 { + b = append(b, " AND "...) + } + b = append(b, '(') + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ')') + } + } + + if q.count == "" { + if len(q.q.order) > 0 { + b = append(b, " ORDER BY "...) + for i, f := range q.q.order { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + } + + if q.q.limit != 0 { + b = append(b, " LIMIT "...) + b = strconv.AppendInt(b, int64(q.q.limit), 10) + } + + if q.q.offset != 0 { + b = append(b, " OFFSET "...) + b = strconv.AppendInt(b, int64(q.q.offset), 10) + } + + if q.q.selFor != nil { + b = append(b, " FOR "...) + b, err = q.q.selFor.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + } else if cteCount { + b = append(b, `) SELECT `...) + b = append(b, q.count...) + b = append(b, ` FROM "_count_wrapper"`...) + } + + if len(q.q.union) > 0 { + b = append(b, ")"...) + + for _, u := range q.q.union { + b = append(b, u.expr...) + b = append(b, '(') + b, err = u.query.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ")"...) + } + } + + return b, q.q.stickyErr +} + +func (q SelectQuery) appendColumns(fmter QueryFormatter, b []byte) (_ []byte, err error) { + start := len(b) + + switch { + case q.q.columns != nil: + b, err = q.q.appendColumns(fmter, b) + if err != nil { + return nil, err + } + case q.q.hasExplicitTableModel(): + table := q.q.tableModel.Table() + if len(table.Fields) > 10 && isTemplateFormatter(fmter) { + b = append(b, table.Alias...) + b = append(b, '.') + b = types.AppendString(b, fmt.Sprintf("%d columns", len(table.Fields)), 2) + } else { + b = appendColumns(b, table.Alias, table.Fields) + } + default: + b = append(b, '*') + } + + err = q.q.forEachHasOneJoin(func(j *join) error { + if len(b) != start { + b = append(b, ", "...) + start = len(b) + } + + b = j.appendHasOneColumns(b) + return nil + }) + if err != nil { + return nil, err + } + + b = bytes.TrimSuffix(b, []byte(", ")) + + return b, nil +} + +func (q *SelectQuery) isDistinct() bool { + if q.q.distinctOn != nil { + return true + } + for _, column := range q.q.columns { + column, ok := column.(*SafeQueryAppender) + if ok { + if strings.Contains(column.query, "DISTINCT") || + strings.Contains(column.query, "distinct") { + return true + } + } + } + return false +} + +func (q *SelectQuery) appendTables(fmter QueryFormatter, b []byte) (_ []byte, err error) { + tables := q.q.tables + + if q.q.modelHasTableName() { + table := q.q.tableModel.Table() + b = fmter.FormatQuery(b, string(table.SQLNameForSelects)) + if table.Alias != "" { + b = append(b, " AS "...) + b = append(b, table.Alias...) + } + + if len(tables) > 0 { + b = append(b, ", "...) + } + } else if len(tables) > 0 { + b, err = tables[0].AppendQuery(fmter, b) + if err != nil { + return nil, err + } + if q.q.modelHasTableAlias() { + b = append(b, " AS "...) + b = append(b, q.q.tableModel.Table().Alias...) + } + + tables = tables[1:] + if len(tables) > 0 { + b = append(b, ", "...) + } + } + + for i, f := range tables { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +//------------------------------------------------------------------------------ + +type joinQuery struct { + join *SafeQueryAppender + on []*condAppender +} + +func (j *joinQuery) AppendOn(app *condAppender) { + j.on = append(j.on, app) +} + +func (j *joinQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + b = append(b, ' ') + + b, err = j.join.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + if len(j.on) > 0 { + b = append(b, " ON "...) + for i, on := range j.on { + if i > 0 { + b = on.AppendSep(b) + } + b, err = on.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + } + + return b, nil +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/table.go b/vendor/github.com/go-pg/pg/v10/orm/table.go new file mode 100644 index 000000000..8b57bbfc0 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/table.go @@ -0,0 +1,1560 @@ +package orm + +import ( + "database/sql" + "encoding/json" + "fmt" + "net" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/jinzhu/inflection" + "github.com/vmihailenco/tagparser" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/internal/pool" + "github.com/go-pg/pg/v10/pgjson" + "github.com/go-pg/pg/v10/types" + "github.com/go-pg/zerochecker" +) + +const ( + beforeScanHookFlag = uint16(1) << iota + afterScanHookFlag + afterSelectHookFlag + beforeInsertHookFlag + afterInsertHookFlag + beforeUpdateHookFlag + afterUpdateHookFlag + beforeDeleteHookFlag + afterDeleteHookFlag + discardUnknownColumnsFlag +) + +var ( + timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + nullTimeType = reflect.TypeOf((*types.NullTime)(nil)).Elem() + sqlNullTimeType = reflect.TypeOf((*sql.NullTime)(nil)).Elem() + ipType = reflect.TypeOf((*net.IP)(nil)).Elem() + ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() + scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + nullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem() + nullFloatType = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem() + nullIntType = reflect.TypeOf((*sql.NullInt64)(nil)).Elem() + nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem() + jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() +) + +var tableNameInflector = inflection.Plural + +// SetTableNameInflector overrides the default func that pluralizes +// model name to get table name, e.g. my_article becomes my_articles. +func SetTableNameInflector(fn func(string) string) { + tableNameInflector = fn +} + +// Table represents a SQL table created from Go struct. +type Table struct { + Type reflect.Type + zeroStruct reflect.Value + + TypeName string + Alias types.Safe + ModelName string + + SQLName types.Safe + SQLNameForSelects types.Safe + + Tablespace types.Safe + + PartitionBy string + + allFields []*Field // read only + skippedFields []*Field + + Fields []*Field // PKs + DataFields + PKs []*Field + DataFields []*Field + fieldsMapMu sync.RWMutex + FieldsMap map[string]*Field + + Methods map[string]*Method + Relations map[string]*Relation + Unique map[string][]*Field + + SoftDeleteField *Field + SetSoftDeleteField func(fv reflect.Value) error + + flags uint16 +} + +func newTable(typ reflect.Type) *Table { + t := new(Table) + t.Type = typ + t.zeroStruct = reflect.New(t.Type).Elem() + t.TypeName = internal.ToExported(t.Type.Name()) + t.ModelName = internal.Underscore(t.Type.Name()) + tableName := tableNameInflector(t.ModelName) + t.setName(quoteIdent(tableName)) + t.Alias = quoteIdent(t.ModelName) + + typ = reflect.PtrTo(t.Type) + if typ.Implements(beforeScanHookType) { + t.setFlag(beforeScanHookFlag) + } + if typ.Implements(afterScanHookType) { + t.setFlag(afterScanHookFlag) + } + if typ.Implements(afterSelectHookType) { + t.setFlag(afterSelectHookFlag) + } + if typ.Implements(beforeInsertHookType) { + t.setFlag(beforeInsertHookFlag) + } + if typ.Implements(afterInsertHookType) { + t.setFlag(afterInsertHookFlag) + } + if typ.Implements(beforeUpdateHookType) { + t.setFlag(beforeUpdateHookFlag) + } + if typ.Implements(afterUpdateHookType) { + t.setFlag(afterUpdateHookFlag) + } + if typ.Implements(beforeDeleteHookType) { + t.setFlag(beforeDeleteHookFlag) + } + if typ.Implements(afterDeleteHookType) { + t.setFlag(afterDeleteHookFlag) + } + + return t +} + +func (t *Table) init1() { + t.initFields() + t.initMethods() +} + +func (t *Table) init2() { + t.initInlines() + t.initRelations() + t.skippedFields = nil +} + +func (t *Table) setName(name types.Safe) { + t.SQLName = name + t.SQLNameForSelects = name + if t.Alias == "" { + t.Alias = name + } +} + +func (t *Table) String() string { + return "model=" + t.TypeName +} + +func (t *Table) setFlag(flag uint16) { + t.flags |= flag +} + +func (t *Table) hasFlag(flag uint16) bool { + if t == nil { + return false + } + return t.flags&flag != 0 +} + +func (t *Table) checkPKs() error { + if len(t.PKs) == 0 { + return fmt.Errorf("pg: %s does not have primary keys", t) + } + return nil +} + +func (t *Table) mustSoftDelete() error { + if t.SoftDeleteField == nil { + return fmt.Errorf("pg: %s does not support soft deletes", t) + } + return nil +} + +func (t *Table) AddField(field *Field) { + t.Fields = append(t.Fields, field) + if field.hasFlag(PrimaryKeyFlag) { + t.PKs = append(t.PKs, field) + } else { + t.DataFields = append(t.DataFields, field) + } + t.FieldsMap[field.SQLName] = field +} + +func (t *Table) RemoveField(field *Field) { + t.Fields = removeField(t.Fields, field) + if field.hasFlag(PrimaryKeyFlag) { + t.PKs = removeField(t.PKs, field) + } else { + t.DataFields = removeField(t.DataFields, field) + } + delete(t.FieldsMap, field.SQLName) +} + +func removeField(fields []*Field, field *Field) []*Field { + for i, f := range fields { + if f == field { + fields = append(fields[:i], fields[i+1:]...) + } + } + return fields +} + +func (t *Table) getField(name string) *Field { + t.fieldsMapMu.RLock() + field := t.FieldsMap[name] + t.fieldsMapMu.RUnlock() + return field +} + +func (t *Table) HasField(name string) bool { + _, ok := t.FieldsMap[name] + return ok +} + +func (t *Table) GetField(name string) (*Field, error) { + field, ok := t.FieldsMap[name] + if !ok { + return nil, fmt.Errorf("pg: %s does not have column=%s", t, name) + } + return field, nil +} + +func (t *Table) AppendParam(b []byte, strct reflect.Value, name string) ([]byte, bool) { + field, ok := t.FieldsMap[name] + if ok { + b = field.AppendValue(b, strct, 1) + return b, true + } + + method, ok := t.Methods[name] + if ok { + b = method.AppendValue(b, strct.Addr(), 1) + return b, true + } + + return b, false +} + +func (t *Table) initFields() { + t.Fields = make([]*Field, 0, t.Type.NumField()) + t.FieldsMap = make(map[string]*Field, t.Type.NumField()) + t.addFields(t.Type, nil) +} + +func (t *Table) addFields(typ reflect.Type, baseIndex []int) { + for i := 0; i < typ.NumField(); i++ { + f := typ.Field(i) + + // Make a copy so slice is not shared between fields. + index := make([]int, len(baseIndex)) + copy(index, baseIndex) + + if f.Anonymous { + if f.Tag.Get("sql") == "-" || f.Tag.Get("pg") == "-" { + continue + } + + fieldType := indirectType(f.Type) + if fieldType.Kind() != reflect.Struct { + continue + } + t.addFields(fieldType, append(index, f.Index...)) + + pgTag := tagparser.Parse(f.Tag.Get("pg")) + if _, inherit := pgTag.Options["inherit"]; inherit { + embeddedTable := _tables.get(fieldType, true) + t.TypeName = embeddedTable.TypeName + t.SQLName = embeddedTable.SQLName + t.SQLNameForSelects = embeddedTable.SQLNameForSelects + t.Alias = embeddedTable.Alias + t.ModelName = embeddedTable.ModelName + } + + continue + } + + field := t.newField(f, index) + if field != nil { + t.AddField(field) + } + } +} + +//nolint +func (t *Table) newField(f reflect.StructField, index []int) *Field { + pgTag := tagparser.Parse(f.Tag.Get("pg")) + + switch f.Name { + case "tableName": + if len(index) > 0 { + return nil + } + + if isKnownTableOption(pgTag.Name) { + internal.Warn.Printf( + "%s.%s tag name %q is also an option name; is it a mistake?", + t.TypeName, f.Name, pgTag.Name, + ) + } + + for name := range pgTag.Options { + if !isKnownTableOption(name) { + internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) + } + } + + if tableSpace, ok := pgTag.Options["tablespace"]; ok { + s, _ := tagparser.Unquote(tableSpace) + t.Tablespace = quoteIdent(s) + } + + partitionBy, ok := pgTag.Options["partition_by"] + if !ok { + partitionBy, ok = pgTag.Options["partitionBy"] + if ok { + internal.Deprecated.Printf("partitionBy is renamed to partition_by") + } + } + if ok { + s, _ := tagparser.Unquote(partitionBy) + t.PartitionBy = s + } + + if pgTag.Name == "_" { + t.setName("") + } else if pgTag.Name != "" { + s, _ := tagparser.Unquote(pgTag.Name) + t.setName(types.Safe(quoteTableName(s))) + } + + if s, ok := pgTag.Options["select"]; ok { + s, _ = tagparser.Unquote(s) + t.SQLNameForSelects = types.Safe(quoteTableName(s)) + } + + if v, ok := pgTag.Options["alias"]; ok { + v, _ = tagparser.Unquote(v) + t.Alias = quoteIdent(v) + } + + pgTag := tagparser.Parse(f.Tag.Get("pg")) + if _, ok := pgTag.Options["discard_unknown_columns"]; ok { + t.setFlag(discardUnknownColumnsFlag) + } + + return nil + } + + if f.PkgPath != "" { + return nil + } + + sqlName := internal.Underscore(f.Name) + + if pgTag.Name != sqlName && isKnownFieldOption(pgTag.Name) { + internal.Warn.Printf( + "%s.%s tag name %q is also an option name; is it a mistake?", + t.TypeName, f.Name, pgTag.Name, + ) + } + + for name := range pgTag.Options { + if !isKnownFieldOption(name) { + internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) + } + } + + skip := pgTag.Name == "-" + if !skip && pgTag.Name != "" { + sqlName = pgTag.Name + } + + index = append(index, f.Index...) + if field := t.getField(sqlName); field != nil { + if indexEqual(field.Index, index) { + return field + } + t.RemoveField(field) + } + + field := &Field{ + Field: f, + Type: indirectType(f.Type), + + GoName: f.Name, + SQLName: sqlName, + Column: quoteIdent(sqlName), + + Index: index, + } + + if _, ok := pgTag.Options["notnull"]; ok { + field.setFlag(NotNullFlag) + } + if v, ok := pgTag.Options["unique"]; ok { + if v == "" { + field.setFlag(UniqueFlag) + } + // Split the value by comma, this will allow multiple names to be specified. + // We can use this to create multiple named unique constraints where a single column + // might be included in multiple constraints. + v, _ = tagparser.Unquote(v) + for _, uniqueName := range strings.Split(v, ",") { + if t.Unique == nil { + t.Unique = make(map[string][]*Field) + } + t.Unique[uniqueName] = append(t.Unique[uniqueName], field) + } + } + if v, ok := pgTag.Options["default"]; ok { + v, ok = tagparser.Unquote(v) + if ok { + field.Default = types.Safe(types.AppendString(nil, v, 1)) + } else { + field.Default = types.Safe(v) + } + } + + //nolint + if _, ok := pgTag.Options["pk"]; ok { + field.setFlag(PrimaryKeyFlag) + } else if strings.HasSuffix(field.SQLName, "_id") || + strings.HasSuffix(field.SQLName, "_uuid") { + field.setFlag(ForeignKeyFlag) + } else if strings.HasPrefix(field.SQLName, "fk_") { + field.setFlag(ForeignKeyFlag) + } else if len(t.PKs) == 0 && !pgTag.HasOption("nopk") { + switch field.SQLName { + case "id", "uuid", "pk_" + t.ModelName: + field.setFlag(PrimaryKeyFlag) + } + } + + if _, ok := pgTag.Options["use_zero"]; ok { + field.setFlag(UseZeroFlag) + } + if _, ok := pgTag.Options["array"]; ok { + field.setFlag(ArrayFlag) + } + + field.SQLType = fieldSQLType(field, pgTag) + if strings.HasSuffix(field.SQLType, "[]") { + field.setFlag(ArrayFlag) + } + + if v, ok := pgTag.Options["on_delete"]; ok { + field.OnDelete = v + } + + if v, ok := pgTag.Options["on_update"]; ok { + field.OnUpdate = v + } + + if _, ok := pgTag.Options["composite"]; ok { + field.append = compositeAppender(f.Type) + field.scan = compositeScanner(f.Type) + } else if _, ok := pgTag.Options["json_use_number"]; ok { + field.append = types.Appender(f.Type) + field.scan = scanJSONValue + } else if field.hasFlag(ArrayFlag) { + field.append = types.ArrayAppender(f.Type) + field.scan = types.ArrayScanner(f.Type) + } else if _, ok := pgTag.Options["hstore"]; ok { + field.append = types.HstoreAppender(f.Type) + field.scan = types.HstoreScanner(f.Type) + } else if field.SQLType == pgTypeBigint && field.Type.Kind() == reflect.Uint64 { + if f.Type.Kind() == reflect.Ptr { + field.append = appendUintPtrAsInt + } else { + field.append = appendUintAsInt + } + field.scan = types.Scanner(f.Type) + } else if _, ok := pgTag.Options["msgpack"]; ok { + field.append = msgpackAppender(f.Type) + field.scan = msgpackScanner(f.Type) + } else { + field.append = types.Appender(f.Type) + field.scan = types.Scanner(f.Type) + } + field.isZero = zerochecker.Checker(f.Type) + + if v, ok := pgTag.Options["alias"]; ok { + v, _ = tagparser.Unquote(v) + t.FieldsMap[v] = field + } + + t.allFields = append(t.allFields, field) + if skip { + t.skippedFields = append(t.skippedFields, field) + t.FieldsMap[field.SQLName] = field + return nil + } + + if _, ok := pgTag.Options["soft_delete"]; ok { + t.SetSoftDeleteField = setSoftDeleteFieldFunc(f.Type) + if t.SetSoftDeleteField == nil { + err := fmt.Errorf( + "pg: soft_delete is only supported for time.Time, pg.NullTime, sql.NullInt64, and int64 (or implement ValueScanner that scans time)") + panic(err) + } + t.SoftDeleteField = field + } + + return field +} + +func (t *Table) initMethods() { + t.Methods = make(map[string]*Method) + typ := reflect.PtrTo(t.Type) + for i := 0; i < typ.NumMethod(); i++ { + m := typ.Method(i) + if m.PkgPath != "" { + continue + } + if m.Type.NumIn() > 1 { + continue + } + if m.Type.NumOut() != 1 { + continue + } + + retType := m.Type.Out(0) + t.Methods[m.Name] = &Method{ + Index: m.Index, + + appender: types.Appender(retType), + } + } +} + +func (t *Table) initInlines() { + for _, f := range t.skippedFields { + if f.Type.Kind() == reflect.Struct { + t.inlineFields(f, nil) + } + } +} + +func (t *Table) initRelations() { + for i := 0; i < len(t.Fields); { + f := t.Fields[i] + if t.tryRelation(f) { + t.Fields = removeField(t.Fields, f) + t.DataFields = removeField(t.DataFields, f) + } else { + i++ + } + + if f.Type.Kind() == reflect.Struct { + t.inlineFields(f, nil) + } + } +} + +func (t *Table) tryRelation(field *Field) bool { + pgTag := tagparser.Parse(field.Field.Tag.Get("pg")) + + if rel, ok := pgTag.Options["rel"]; ok { + return t.tryRelationType(field, rel, pgTag) + } + if _, ok := pgTag.Options["many2many"]; ok { + return t.tryRelationType(field, "many2many", pgTag) + } + + if field.UserSQLType != "" || isScanner(field.Type) { + return false + } + + switch field.Type.Kind() { + case reflect.Slice: + return t.tryRelationSlice(field, pgTag) + case reflect.Struct: + return t.tryRelationStruct(field, pgTag) + } + return false +} + +func (t *Table) tryRelationType(field *Field, rel string, pgTag *tagparser.Tag) bool { + switch rel { + case "has-one": + return t.mustHasOneRelation(field, pgTag) + case "belongs-to": + return t.mustBelongsToRelation(field, pgTag) + case "has-many": + return t.mustHasManyRelation(field, pgTag) + case "many2many": + return t.mustM2MRelation(field, pgTag) + default: + panic(fmt.Errorf("pg: unknown relation=%s on field=%s", rel, field.GoName)) + } +} + +func (t *Table) mustHasOneRelation(field *Field, pgTag *tagparser.Tag) bool { + joinTable := _tables.get(field.Type, true) + if err := joinTable.checkPKs(); err != nil { + panic(err) + } + fkPrefix, fkOK := pgTag.Options["fk"] + + if fkOK && len(joinTable.PKs) == 1 { + fk := t.getField(fkPrefix) + if fk == nil { + panic(fmt.Errorf( + "pg: %s has-one %s: %s must have column %s "+ + "(use fk:custom_column tag on %s field to specify custom column)", + t.TypeName, field.GoName, t.TypeName, fkPrefix, field.GoName, + )) + } + + t.addRelation(&Relation{ + Type: HasOneRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: []*Field{fk}, + JoinFKs: joinTable.PKs, + }) + return true + } + + if !fkOK { + fkPrefix = internal.Underscore(field.GoName) + "_" + } + fks := make([]*Field, 0, len(joinTable.PKs)) + + for _, joinPK := range joinTable.PKs { + fkName := fkPrefix + joinPK.SQLName + if fk := t.getField(fkName); fk != nil { + fks = append(fks, fk) + continue + } + + if fk := t.getField(joinPK.SQLName); fk != nil { + fks = append(fks, fk) + continue + } + + panic(fmt.Errorf( + "pg: %s has-one %s: %s must have column %s "+ + "(use fk:custom_column tag on %s field to specify custom column)", + t.TypeName, field.GoName, t.TypeName, fkName, field.GoName, + )) + } + + t.addRelation(&Relation{ + Type: HasOneRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: fks, + JoinFKs: joinTable.PKs, + }) + return true +} + +func (t *Table) mustBelongsToRelation(field *Field, pgTag *tagparser.Tag) bool { + if err := t.checkPKs(); err != nil { + panic(err) + } + joinTable := _tables.get(field.Type, true) + fkPrefix, fkOK := pgTag.Options["join_fk"] + + if fkOK && len(t.PKs) == 1 { + fk := joinTable.getField(fkPrefix) + if fk == nil { + panic(fmt.Errorf( + "pg: %s belongs-to %s: %s must have column %s "+ + "(use join_fk:custom_column tag on %s field to specify custom column)", + field.GoName, t.TypeName, joinTable.TypeName, fkPrefix, field.GoName, + )) + } + + t.addRelation(&Relation{ + Type: BelongsToRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: t.PKs, + JoinFKs: []*Field{fk}, + }) + return true + } + + if !fkOK { + fkPrefix = internal.Underscore(t.ModelName) + "_" + } + fks := make([]*Field, 0, len(t.PKs)) + + for _, pk := range t.PKs { + fkName := fkPrefix + pk.SQLName + if fk := joinTable.getField(fkName); fk != nil { + fks = append(fks, fk) + continue + } + + if fk := joinTable.getField(pk.SQLName); fk != nil { + fks = append(fks, fk) + continue + } + + panic(fmt.Errorf( + "pg: %s belongs-to %s: %s must have column %s "+ + "(use join_fk:custom_column tag on %s field to specify custom column)", + field.GoName, t.TypeName, joinTable.TypeName, fkName, field.GoName, + )) + } + + t.addRelation(&Relation{ + Type: BelongsToRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: t.PKs, + JoinFKs: fks, + }) + return true +} + +func (t *Table) mustHasManyRelation(field *Field, pgTag *tagparser.Tag) bool { + if err := t.checkPKs(); err != nil { + panic(err) + } + if field.Type.Kind() != reflect.Slice { + panic(fmt.Errorf( + "pg: %s.%s has-many relation requires slice, got %q", + t.TypeName, field.GoName, field.Type.Kind(), + )) + } + + joinTable := _tables.get(indirectType(field.Type.Elem()), true) + fkPrefix, fkOK := pgTag.Options["join_fk"] + _, polymorphic := pgTag.Options["polymorphic"] + + if fkOK && !polymorphic && len(t.PKs) == 1 { + fk := joinTable.getField(fkPrefix) + if fk == nil { + panic(fmt.Errorf( + "pg: %s has-many %s: %s must have column %s "+ + "(use join_fk:custom_column tag on %s field to specify custom column)", + t.TypeName, field.GoName, joinTable.TypeName, fkPrefix, field.GoName, + )) + } + + t.addRelation(&Relation{ + Type: HasManyRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: t.PKs, + JoinFKs: []*Field{fk}, + }) + return true + } + + if !fkOK { + fkPrefix = internal.Underscore(t.ModelName) + "_" + } + fks := make([]*Field, 0, len(t.PKs)) + + for _, pk := range t.PKs { + fkName := fkPrefix + pk.SQLName + if fk := joinTable.getField(fkName); fk != nil { + fks = append(fks, fk) + continue + } + + if fk := joinTable.getField(pk.SQLName); fk != nil { + fks = append(fks, fk) + continue + } + + panic(fmt.Errorf( + "pg: %s has-many %s: %s must have column %s "+ + "(use join_fk:custom_column tag on %s field to specify custom column)", + t.TypeName, field.GoName, joinTable.TypeName, fkName, field.GoName, + )) + } + + var typeField *Field + + if polymorphic { + typeFieldName := fkPrefix + "type" + typeField = joinTable.getField(typeFieldName) + if typeField == nil { + panic(fmt.Errorf( + "pg: %s has-many %s: %s must have polymorphic column %s", + t.TypeName, field.GoName, joinTable.TypeName, typeFieldName, + )) + } + } + + t.addRelation(&Relation{ + Type: HasManyRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: t.PKs, + JoinFKs: fks, + Polymorphic: typeField, + }) + return true +} + +func (t *Table) mustM2MRelation(field *Field, pgTag *tagparser.Tag) bool { + if field.Type.Kind() != reflect.Slice { + panic(fmt.Errorf( + "pg: %s.%s many2many relation requires slice, got %q", + t.TypeName, field.GoName, field.Type.Kind(), + )) + } + joinTable := _tables.get(indirectType(field.Type.Elem()), true) + + if err := t.checkPKs(); err != nil { + panic(err) + } + if err := joinTable.checkPKs(); err != nil { + panic(err) + } + + m2mTableNameString, ok := pgTag.Options["many2many"] + if !ok { + panic(fmt.Errorf("pg: %s must have many2many tag option", field.GoName)) + } + m2mTableName := quoteTableName(m2mTableNameString) + + m2mTable := _tables.getByName(m2mTableName) + if m2mTable == nil { + panic(fmt.Errorf( + "pg: can't find %s table (use orm.RegisterTable to register the model)", + m2mTableName, + )) + } + + var baseFKs []string + var joinFKs []string + + { + fkPrefix, ok := pgTag.Options["fk"] + if !ok { + fkPrefix = internal.Underscore(t.ModelName) + "_" + } + + if ok && len(t.PKs) == 1 { + if m2mTable.getField(fkPrefix) == nil { + panic(fmt.Errorf( + "pg: %s many2many %s: %s must have column %s "+ + "(use fk:custom_column tag on %s field to specify custom column)", + t.TypeName, field.GoName, m2mTable.TypeName, fkPrefix, field.GoName, + )) + } + baseFKs = []string{fkPrefix} + } else { + for _, pk := range t.PKs { + fkName := fkPrefix + pk.SQLName + if m2mTable.getField(fkName) != nil { + baseFKs = append(baseFKs, fkName) + continue + } + + if m2mTable.getField(pk.SQLName) != nil { + baseFKs = append(baseFKs, pk.SQLName) + continue + } + + panic(fmt.Errorf( + "pg: %s many2many %s: %s must have column %s "+ + "(use fk:custom_column tag on %s field to specify custom column)", + t.TypeName, field.GoName, m2mTable.TypeName, fkName, field.GoName, + )) + } + } + } + + { + joinFKPrefix, ok := pgTag.Options["join_fk"] + if !ok { + joinFKPrefix = internal.Underscore(joinTable.ModelName) + "_" + } + + if ok && len(joinTable.PKs) == 1 { + if m2mTable.getField(joinFKPrefix) == nil { + panic(fmt.Errorf( + "pg: %s many2many %s: %s must have column %s "+ + "(use join_fk:custom_column tag on %s field to specify custom column)", + joinTable.TypeName, field.GoName, m2mTable.TypeName, joinFKPrefix, field.GoName, + )) + } + joinFKs = []string{joinFKPrefix} + } else { + for _, joinPK := range joinTable.PKs { + fkName := joinFKPrefix + joinPK.SQLName + if m2mTable.getField(fkName) != nil { + joinFKs = append(joinFKs, fkName) + continue + } + + if m2mTable.getField(joinPK.SQLName) != nil { + joinFKs = append(joinFKs, joinPK.SQLName) + continue + } + + panic(fmt.Errorf( + "pg: %s many2many %s: %s must have column %s "+ + "(use join_fk:custom_column tag on %s field to specify custom column)", + t.TypeName, field.GoName, m2mTable.TypeName, fkName, field.GoName, + )) + } + } + } + + t.addRelation(&Relation{ + Type: Many2ManyRelation, + Field: field, + JoinTable: joinTable, + M2MTableName: m2mTableName, + M2MTableAlias: m2mTable.Alias, + M2MBaseFKs: baseFKs, + M2MJoinFKs: joinFKs, + }) + return true +} + +//nolint +func (t *Table) tryRelationSlice(field *Field, pgTag *tagparser.Tag) bool { + if t.tryM2MRelation(field, pgTag) { + internal.Deprecated.Printf( + `add pg:"rel:many2many" to %s.%s field tag`, t.TypeName, field.GoName) + return true + } + if t.tryHasManyRelation(field, pgTag) { + internal.Deprecated.Printf( + `add pg:"rel:has-many" to %s.%s field tag`, t.TypeName, field.GoName) + return true + } + return false +} + +func (t *Table) tryM2MRelation(field *Field, pgTag *tagparser.Tag) bool { + elemType := indirectType(field.Type.Elem()) + if elemType.Kind() != reflect.Struct { + return false + } + + joinTable := _tables.get(elemType, true) + + fk, fkOK := pgTag.Options["fk"] + if fkOK { + if fk == "-" { + return false + } + fk = tryUnderscorePrefix(fk) + } + + m2mTableName := pgTag.Options["many2many"] + if m2mTableName == "" { + return false + } + + m2mTable := _tables.getByName(quoteIdent(m2mTableName)) + + var m2mTableAlias types.Safe + if m2mTable != nil { + m2mTableAlias = m2mTable.Alias + } else if ind := strings.IndexByte(m2mTableName, '.'); ind >= 0 { + m2mTableAlias = quoteIdent(m2mTableName[ind+1:]) + } else { + m2mTableAlias = quoteIdent(m2mTableName) + } + + var fks []string + if !fkOK { + fk = t.ModelName + "_" + } + if m2mTable != nil { + keys := foreignKeys(t, m2mTable, fk, fkOK) + if len(keys) == 0 { + return false + } + for _, fk := range keys { + fks = append(fks, fk.SQLName) + } + } else { + if fkOK && len(t.PKs) == 1 { + fks = append(fks, fk) + } else { + for _, pk := range t.PKs { + fks = append(fks, fk+pk.SQLName) + } + } + } + + joinFK, joinFKOk := pgTag.Options["join_fk"] + if !joinFKOk { + joinFK, joinFKOk = pgTag.Options["joinFK"] + if joinFKOk { + internal.Deprecated.Printf("joinFK is renamed to join_fk") + } + } + if joinFKOk { + joinFK = tryUnderscorePrefix(joinFK) + } else { + joinFK = joinTable.ModelName + "_" + } + + var joinFKs []string + if m2mTable != nil { + keys := foreignKeys(joinTable, m2mTable, joinFK, joinFKOk) + if len(keys) == 0 { + return false + } + for _, fk := range keys { + joinFKs = append(joinFKs, fk.SQLName) + } + } else { + if joinFKOk && len(joinTable.PKs) == 1 { + joinFKs = append(joinFKs, joinFK) + } else { + for _, pk := range joinTable.PKs { + joinFKs = append(joinFKs, joinFK+pk.SQLName) + } + } + } + + t.addRelation(&Relation{ + Type: Many2ManyRelation, + Field: field, + JoinTable: joinTable, + M2MTableName: quoteIdent(m2mTableName), + M2MTableAlias: m2mTableAlias, + M2MBaseFKs: fks, + M2MJoinFKs: joinFKs, + }) + return true +} + +func (t *Table) tryHasManyRelation(field *Field, pgTag *tagparser.Tag) bool { + elemType := indirectType(field.Type.Elem()) + if elemType.Kind() != reflect.Struct { + return false + } + + joinTable := _tables.get(elemType, true) + + fk, fkOK := pgTag.Options["fk"] + if fkOK { + if fk == "-" { + return false + } + fk = tryUnderscorePrefix(fk) + } + + s, polymorphic := pgTag.Options["polymorphic"] + var typeField *Field + if polymorphic { + fk = tryUnderscorePrefix(s) + + typeField = joinTable.getField(fk + "type") + if typeField == nil { + return false + } + } else if !fkOK { + fk = t.ModelName + "_" + } + + fks := foreignKeys(t, joinTable, fk, fkOK || polymorphic) + if len(fks) == 0 { + return false + } + + var fkValues []*Field + fkValue, ok := pgTag.Options["fk_value"] + if ok { + if len(fks) > 1 { + panic(fmt.Errorf("got fk_value, but there are %d fks", len(fks))) + } + + f := t.getField(fkValue) + if f == nil { + panic(fmt.Errorf("fk_value=%q not found in %s", fkValue, t)) + } + fkValues = append(fkValues, f) + } else { + fkValues = t.PKs + } + + if len(fks) != len(fkValues) { + panic("len(fks) != len(fkValues)") + } + + if len(fks) > 0 { + t.addRelation(&Relation{ + Type: HasManyRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: fkValues, + JoinFKs: fks, + Polymorphic: typeField, + }) + return true + } + + return false +} + +func (t *Table) tryRelationStruct(field *Field, pgTag *tagparser.Tag) bool { + joinTable := _tables.get(field.Type, true) + + if len(joinTable.allFields) == 0 { + return false + } + + if t.tryHasOne(joinTable, field, pgTag) { + internal.Deprecated.Printf( + `add pg:"rel:has-one" to %s.%s field tag`, t.TypeName, field.GoName) + t.inlineFields(field, nil) + return true + } + + if t.tryBelongsToOne(joinTable, field, pgTag) { + internal.Deprecated.Printf( + `add pg:"rel:belongs-to" to %s.%s field tag`, t.TypeName, field.GoName) + t.inlineFields(field, nil) + return true + } + + t.inlineFields(field, nil) + return false +} + +func (t *Table) inlineFields(strct *Field, path map[reflect.Type]struct{}) { + if path == nil { + path = map[reflect.Type]struct{}{ + t.Type: {}, + } + } + + if _, ok := path[strct.Type]; ok { + return + } + path[strct.Type] = struct{}{} + + joinTable := _tables.get(strct.Type, true) + for _, f := range joinTable.allFields { + f = f.Clone() + f.GoName = strct.GoName + "_" + f.GoName + f.SQLName = strct.SQLName + "__" + f.SQLName + f.Column = quoteIdent(f.SQLName) + f.Index = appendNew(strct.Index, f.Index...) + + t.fieldsMapMu.Lock() + if _, ok := t.FieldsMap[f.SQLName]; !ok { + t.FieldsMap[f.SQLName] = f + } + t.fieldsMapMu.Unlock() + + if f.Type.Kind() != reflect.Struct { + continue + } + + if _, ok := path[f.Type]; !ok { + t.inlineFields(f, path) + } + } +} + +func appendNew(dst []int, src ...int) []int { + cp := make([]int, len(dst)+len(src)) + copy(cp, dst) + copy(cp[len(dst):], src) + return cp +} + +func isScanner(typ reflect.Type) bool { + return typ.Implements(scannerType) || reflect.PtrTo(typ).Implements(scannerType) +} + +func fieldSQLType(field *Field, pgTag *tagparser.Tag) string { + if typ, ok := pgTag.Options["type"]; ok { + typ, _ = tagparser.Unquote(typ) + field.UserSQLType = typ + typ = normalizeSQLType(typ) + return typ + } + + if typ, ok := pgTag.Options["composite"]; ok { + typ, _ = tagparser.Unquote(typ) + return typ + } + + if _, ok := pgTag.Options["hstore"]; ok { + return "hstore" + } else if _, ok := pgTag.Options["hstore"]; ok { + return "hstore" + } + + if field.hasFlag(ArrayFlag) { + switch field.Type.Kind() { + case reflect.Slice, reflect.Array: + sqlType := sqlType(field.Type.Elem()) + return sqlType + "[]" + } + } + + sqlType := sqlType(field.Type) + return sqlType +} + +func sqlType(typ reflect.Type) string { + switch typ { + case timeType, nullTimeType, sqlNullTimeType: + return pgTypeTimestampTz + case ipType: + return pgTypeInet + case ipNetType: + return pgTypeCidr + case nullBoolType: + return pgTypeBoolean + case nullFloatType: + return pgTypeDoublePrecision + case nullIntType: + return pgTypeBigint + case nullStringType: + return pgTypeText + case jsonRawMessageType: + return pgTypeJSONB + } + + switch typ.Kind() { + case reflect.Int8, reflect.Uint8, reflect.Int16: + return pgTypeSmallint + case reflect.Uint16, reflect.Int32: + return pgTypeInteger + case reflect.Uint32, reflect.Int64, reflect.Int: + return pgTypeBigint + case reflect.Uint, reflect.Uint64: + // Unsigned bigint is not supported - use bigint. + return pgTypeBigint + case reflect.Float32: + return pgTypeReal + case reflect.Float64: + return pgTypeDoublePrecision + case reflect.Bool: + return pgTypeBoolean + case reflect.String: + return pgTypeText + case reflect.Map, reflect.Struct: + return pgTypeJSONB + case reflect.Array, reflect.Slice: + if typ.Elem().Kind() == reflect.Uint8 { + return pgTypeBytea + } + return pgTypeJSONB + default: + return typ.Kind().String() + } +} + +func normalizeSQLType(s string) string { + switch s { + case "int2": + return pgTypeSmallint + case "int4", "int", "serial": + return pgTypeInteger + case "int8", pgTypeBigserial: + return pgTypeBigint + case "float4": + return pgTypeReal + case "float8": + return pgTypeDoublePrecision + } + return s +} + +func sqlTypeEqual(a, b string) bool { + return a == b +} + +func (t *Table) tryHasOne(joinTable *Table, field *Field, pgTag *tagparser.Tag) bool { + fk, fkOK := pgTag.Options["fk"] + if fkOK { + if fk == "-" { + return false + } + fk = tryUnderscorePrefix(fk) + } else { + fk = internal.Underscore(field.GoName) + "_" + } + + fks := foreignKeys(joinTable, t, fk, fkOK) + if len(fks) > 0 { + t.addRelation(&Relation{ + Type: HasOneRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: fks, + JoinFKs: joinTable.PKs, + }) + return true + } + return false +} + +func (t *Table) tryBelongsToOne(joinTable *Table, field *Field, pgTag *tagparser.Tag) bool { + fk, fkOK := pgTag.Options["fk"] + if fkOK { + if fk == "-" { + return false + } + fk = tryUnderscorePrefix(fk) + } else { + fk = internal.Underscore(t.TypeName) + "_" + } + + fks := foreignKeys(t, joinTable, fk, fkOK) + if len(fks) > 0 { + t.addRelation(&Relation{ + Type: BelongsToRelation, + Field: field, + JoinTable: joinTable, + BaseFKs: t.PKs, + JoinFKs: fks, + }) + return true + } + return false +} + +func (t *Table) addRelation(rel *Relation) { + if t.Relations == nil { + t.Relations = make(map[string]*Relation) + } + _, ok := t.Relations[rel.Field.GoName] + if ok { + panic(fmt.Errorf("%s already has %s", t, rel)) + } + t.Relations[rel.Field.GoName] = rel +} + +func foreignKeys(base, join *Table, fk string, tryFK bool) []*Field { + var fks []*Field + + for _, pk := range base.PKs { + fkName := fk + pk.SQLName + f := join.getField(fkName) + if f != nil && sqlTypeEqual(pk.SQLType, f.SQLType) { + fks = append(fks, f) + continue + } + + if strings.IndexByte(pk.SQLName, '_') == -1 { + continue + } + + f = join.getField(pk.SQLName) + if f != nil && sqlTypeEqual(pk.SQLType, f.SQLType) { + fks = append(fks, f) + continue + } + } + if len(fks) > 0 && len(fks) == len(base.PKs) { + return fks + } + + fks = nil + for _, pk := range base.PKs { + if !strings.HasPrefix(pk.SQLName, "pk_") { + continue + } + fkName := "fk_" + pk.SQLName[3:] + f := join.getField(fkName) + if f != nil && sqlTypeEqual(pk.SQLType, f.SQLType) { + fks = append(fks, f) + } + } + if len(fks) > 0 && len(fks) == len(base.PKs) { + return fks + } + + if fk == "" || len(base.PKs) != 1 { + return nil + } + + if tryFK { + f := join.getField(fk) + if f != nil && sqlTypeEqual(base.PKs[0].SQLType, f.SQLType) { + return []*Field{f} + } + } + + for _, suffix := range []string{"id", "uuid"} { + f := join.getField(fk + suffix) + if f != nil && sqlTypeEqual(base.PKs[0].SQLType, f.SQLType) { + return []*Field{f} + } + } + + return nil +} + +func scanJSONValue(v reflect.Value, rd types.Reader, n int) error { + // Zero value so it works with SelectOrInsert. + // TODO: better handle slices + v.Set(reflect.New(v.Type()).Elem()) + + if n == -1 { + return nil + } + + dec := pgjson.NewDecoder(rd) + dec.UseNumber() + return dec.Decode(v.Addr().Interface()) +} + +func appendUintAsInt(b []byte, v reflect.Value, _ int) []byte { + return strconv.AppendInt(b, int64(v.Uint()), 10) +} + +func appendUintPtrAsInt(b []byte, v reflect.Value, _ int) []byte { + return strconv.AppendInt(b, int64(v.Elem().Uint()), 10) +} + +func tryUnderscorePrefix(s string) string { + if s == "" { + return s + } + if c := s[0]; internal.IsUpper(c) { + return internal.Underscore(s) + "_" + } + return s +} + +func quoteTableName(s string) types.Safe { + // Don't quote if table name contains placeholder (?) or parentheses. + if strings.IndexByte(s, '?') >= 0 || + strings.IndexByte(s, '(') >= 0 && strings.IndexByte(s, ')') >= 0 { + return types.Safe(s) + } + return quoteIdent(s) +} + +func quoteIdent(s string) types.Safe { + return types.Safe(types.AppendIdent(nil, s, 1)) +} + +func setSoftDeleteFieldFunc(typ reflect.Type) func(fv reflect.Value) error { + switch typ { + case timeType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*time.Time) + *ptr = time.Now() + return nil + } + case nullTimeType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*types.NullTime) + *ptr = types.NullTime{Time: time.Now()} + return nil + } + case nullIntType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*sql.NullInt64) + *ptr = sql.NullInt64{Int64: time.Now().UnixNano()} + return nil + } + } + + switch typ.Kind() { + case reflect.Int64: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*int64) + *ptr = time.Now().UnixNano() + return nil + } + case reflect.Ptr: + break + default: + return setSoftDeleteFallbackFunc(typ) + } + + originalType := typ + typ = typ.Elem() + + switch typ { //nolint:gocritic + case timeType: + return func(fv reflect.Value) error { + now := time.Now() + fv.Set(reflect.ValueOf(&now)) + return nil + } + } + + switch typ.Kind() { //nolint:gocritic + case reflect.Int64: + return func(fv reflect.Value) error { + utime := time.Now().UnixNano() + fv.Set(reflect.ValueOf(&utime)) + return nil + } + } + + return setSoftDeleteFallbackFunc(originalType) +} + +func setSoftDeleteFallbackFunc(typ reflect.Type) func(fv reflect.Value) error { + scanner := types.Scanner(typ) + if scanner == nil { + return nil + } + + return func(fv reflect.Value) error { + var flags int + b := types.AppendTime(nil, time.Now(), flags) + return scanner(fv, pool.NewBytesReader(b), len(b)) + } +} + +func isKnownTableOption(name string) bool { + switch name { + case "alias", + "select", + "tablespace", + "partition_by", + "discard_unknown_columns": + return true + } + return false +} + +func isKnownFieldOption(name string) bool { + switch name { + case "alias", + "type", + "array", + "hstore", + "composite", + "json_use_number", + "msgpack", + "notnull", + "use_zero", + "default", + "unique", + "soft_delete", + "on_delete", + "on_update", + + "pk", + "nopk", + "rel", + "fk", + "join_fk", + "many2many", + "polymorphic": + return true + } + return false +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/table_create.go b/vendor/github.com/go-pg/pg/v10/orm/table_create.go new file mode 100644 index 000000000..384c729de --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/table_create.go @@ -0,0 +1,248 @@ +package orm + +import ( + "sort" + "strconv" + + "github.com/go-pg/pg/v10/types" +) + +type CreateTableOptions struct { + Varchar int // replaces PostgreSQL data type `text` with `varchar(n)` + Temp bool + IfNotExists bool + + // FKConstraints causes CreateTable to create foreign key constraints + // for has one relations. ON DELETE hook can be added using tag + // `pg:"on_delete:RESTRICT"` on foreign key field. ON UPDATE hook can be added using tag + // `pg:"on_update:CASCADE"` + FKConstraints bool +} + +type CreateTableQuery struct { + q *Query + opt *CreateTableOptions +} + +var ( + _ QueryAppender = (*CreateTableQuery)(nil) + _ QueryCommand = (*CreateTableQuery)(nil) +) + +func NewCreateTableQuery(q *Query, opt *CreateTableOptions) *CreateTableQuery { + return &CreateTableQuery{ + q: q, + opt: opt, + } +} + +func (q *CreateTableQuery) String() string { + b, err := q.AppendQuery(defaultFmter, nil) + if err != nil { + panic(err) + } + return string(b) +} + +func (q *CreateTableQuery) Operation() QueryOp { + return CreateTableOp +} + +func (q *CreateTableQuery) Clone() QueryCommand { + return &CreateTableQuery{ + q: q.q.Clone(), + opt: q.opt, + } +} + +func (q *CreateTableQuery) Query() *Query { + return q.q +} + +func (q *CreateTableQuery) AppendTemplate(b []byte) ([]byte, error) { + return q.AppendQuery(dummyFormatter{}, b) +} + +func (q *CreateTableQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if q.q.stickyErr != nil { + return nil, q.q.stickyErr + } + if q.q.tableModel == nil { + return nil, errModelNil + } + + table := q.q.tableModel.Table() + + b = append(b, "CREATE "...) + if q.opt != nil && q.opt.Temp { + b = append(b, "TEMP "...) + } + b = append(b, "TABLE "...) + if q.opt != nil && q.opt.IfNotExists { + b = append(b, "IF NOT EXISTS "...) + } + b, err = q.q.appendFirstTable(fmter, b) + if err != nil { + return nil, err + } + b = append(b, " ("...) + + for i, field := range table.Fields { + if i > 0 { + b = append(b, ", "...) + } + + b = append(b, field.Column...) + b = append(b, " "...) + b = q.appendSQLType(b, field) + if field.hasFlag(NotNullFlag) { + b = append(b, " NOT NULL"...) + } + if field.hasFlag(UniqueFlag) { + b = append(b, " UNIQUE"...) + } + if field.Default != "" { + b = append(b, " DEFAULT "...) + b = append(b, field.Default...) + } + } + + b = appendPKConstraint(b, table.PKs) + b = appendUniqueConstraints(b, table) + + if q.opt != nil && q.opt.FKConstraints { + for _, rel := range table.Relations { + b = q.appendFKConstraint(fmter, b, rel) + } + } + + b = append(b, ")"...) + + if table.PartitionBy != "" { + b = append(b, " PARTITION BY "...) + b = append(b, table.PartitionBy...) + } + + if table.Tablespace != "" { + b = q.appendTablespace(b, table.Tablespace) + } + + return b, q.q.stickyErr +} + +func (q *CreateTableQuery) appendSQLType(b []byte, field *Field) []byte { + if field.UserSQLType != "" { + return append(b, field.UserSQLType...) + } + if q.opt != nil && q.opt.Varchar > 0 && + field.SQLType == "text" { + b = append(b, "varchar("...) + b = strconv.AppendInt(b, int64(q.opt.Varchar), 10) + b = append(b, ")"...) + return b + } + if field.hasFlag(PrimaryKeyFlag) { + return append(b, pkSQLType(field.SQLType)...) + } + return append(b, field.SQLType...) +} + +func pkSQLType(s string) string { + switch s { + case pgTypeSmallint: + return pgTypeSmallserial + case pgTypeInteger: + return pgTypeSerial + case pgTypeBigint: + return pgTypeBigserial + } + return s +} + +func appendPKConstraint(b []byte, pks []*Field) []byte { + if len(pks) == 0 { + return b + } + + b = append(b, ", PRIMARY KEY ("...) + b = appendColumns(b, "", pks) + b = append(b, ")"...) + return b +} + +func appendUniqueConstraints(b []byte, table *Table) []byte { + keys := make([]string, 0, len(table.Unique)) + for key := range table.Unique { + keys = append(keys, key) + } + sort.Strings(keys) + + for _, key := range keys { + b = appendUnique(b, table.Unique[key]) + } + + return b +} + +func appendUnique(b []byte, fields []*Field) []byte { + b = append(b, ", UNIQUE ("...) + b = appendColumns(b, "", fields) + b = append(b, ")"...) + return b +} + +func (q *CreateTableQuery) appendFKConstraint(fmter QueryFormatter, b []byte, rel *Relation) []byte { + if rel.Type != HasOneRelation { + return b + } + + b = append(b, ", FOREIGN KEY ("...) + b = appendColumns(b, "", rel.BaseFKs) + b = append(b, ")"...) + + b = append(b, " REFERENCES "...) + b = fmter.FormatQuery(b, string(rel.JoinTable.SQLName)) + b = append(b, " ("...) + b = appendColumns(b, "", rel.JoinFKs) + b = append(b, ")"...) + + if s := onDelete(rel.BaseFKs); s != "" { + b = append(b, " ON DELETE "...) + b = append(b, s...) + } + + if s := onUpdate(rel.BaseFKs); s != "" { + b = append(b, " ON UPDATE "...) + b = append(b, s...) + } + + return b +} + +func (q *CreateTableQuery) appendTablespace(b []byte, tableSpace types.Safe) []byte { + b = append(b, " TABLESPACE "...) + b = append(b, tableSpace...) + return b +} + +func onDelete(fks []*Field) string { + var onDelete string + for _, f := range fks { + if f.OnDelete != "" { + onDelete = f.OnDelete + break + } + } + return onDelete +} + +func onUpdate(fks []*Field) string { + var onUpdate string + for _, f := range fks { + if f.OnUpdate != "" { + onUpdate = f.OnUpdate + break + } + } + return onUpdate +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/table_drop.go b/vendor/github.com/go-pg/pg/v10/orm/table_drop.go new file mode 100644 index 000000000..599ac3952 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/table_drop.go @@ -0,0 +1,73 @@ +package orm + +type DropTableOptions struct { + IfExists bool + Cascade bool +} + +type DropTableQuery struct { + q *Query + opt *DropTableOptions +} + +var ( + _ QueryAppender = (*DropTableQuery)(nil) + _ QueryCommand = (*DropTableQuery)(nil) +) + +func NewDropTableQuery(q *Query, opt *DropTableOptions) *DropTableQuery { + return &DropTableQuery{ + q: q, + opt: opt, + } +} + +func (q *DropTableQuery) String() string { + b, err := q.AppendQuery(defaultFmter, nil) + if err != nil { + panic(err) + } + return string(b) +} + +func (q *DropTableQuery) Operation() QueryOp { + return DropTableOp +} + +func (q *DropTableQuery) Clone() QueryCommand { + return &DropTableQuery{ + q: q.q.Clone(), + opt: q.opt, + } +} + +func (q *DropTableQuery) Query() *Query { + return q.q +} + +func (q *DropTableQuery) AppendTemplate(b []byte) ([]byte, error) { + return q.AppendQuery(dummyFormatter{}, b) +} + +func (q *DropTableQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if q.q.stickyErr != nil { + return nil, q.q.stickyErr + } + if q.q.tableModel == nil { + return nil, errModelNil + } + + b = append(b, "DROP TABLE "...) + if q.opt != nil && q.opt.IfExists { + b = append(b, "IF EXISTS "...) + } + b, err = q.q.appendFirstTable(fmter, b) + if err != nil { + return nil, err + } + if q.opt != nil && q.opt.Cascade { + b = append(b, " CASCADE"...) + } + + return b, q.q.stickyErr +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/table_params.go b/vendor/github.com/go-pg/pg/v10/orm/table_params.go new file mode 100644 index 000000000..46d8e064a --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/table_params.go @@ -0,0 +1,29 @@ +package orm + +import "reflect" + +type tableParams struct { + table *Table + strct reflect.Value +} + +func newTableParams(strct interface{}) (*tableParams, bool) { + v := reflect.ValueOf(strct) + if !v.IsValid() { + return nil, false + } + + v = reflect.Indirect(v) + if v.Kind() != reflect.Struct { + return nil, false + } + + return &tableParams{ + table: GetTable(v.Type()), + strct: v, + }, true +} + +func (m *tableParams) AppendParam(fmter QueryFormatter, b []byte, name string) ([]byte, bool) { + return m.table.AppendParam(b, m.strct, name) +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/tables.go b/vendor/github.com/go-pg/pg/v10/orm/tables.go new file mode 100644 index 000000000..fa937a54e --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/tables.go @@ -0,0 +1,136 @@ +package orm + +import ( + "fmt" + "reflect" + "sync" + + "github.com/go-pg/pg/v10/types" +) + +var _tables = newTables() + +type tableInProgress struct { + table *Table + + init1Once sync.Once + init2Once sync.Once +} + +func newTableInProgress(table *Table) *tableInProgress { + return &tableInProgress{ + table: table, + } +} + +func (inp *tableInProgress) init1() bool { + var inited bool + inp.init1Once.Do(func() { + inp.table.init1() + inited = true + }) + return inited +} + +func (inp *tableInProgress) init2() bool { + var inited bool + inp.init2Once.Do(func() { + inp.table.init2() + inited = true + }) + return inited +} + +// GetTable returns a Table for a struct type. +func GetTable(typ reflect.Type) *Table { + return _tables.Get(typ) +} + +// RegisterTable registers a struct as SQL table. +// It is usually used to register intermediate table +// in many to many relationship. +func RegisterTable(strct interface{}) { + _tables.Register(strct) +} + +type tables struct { + tables sync.Map + + mu sync.RWMutex + inProgress map[reflect.Type]*tableInProgress +} + +func newTables() *tables { + return &tables{ + inProgress: make(map[reflect.Type]*tableInProgress), + } +} + +func (t *tables) Register(strct interface{}) { + typ := reflect.TypeOf(strct) + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + _ = t.Get(typ) +} + +func (t *tables) get(typ reflect.Type, allowInProgress bool) *Table { + if typ.Kind() != reflect.Struct { + panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct)) + } + + if v, ok := t.tables.Load(typ); ok { + return v.(*Table) + } + + t.mu.Lock() + + if v, ok := t.tables.Load(typ); ok { + t.mu.Unlock() + return v.(*Table) + } + + var table *Table + + inProgress := t.inProgress[typ] + if inProgress == nil { + table = newTable(typ) + inProgress = newTableInProgress(table) + t.inProgress[typ] = inProgress + } else { + table = inProgress.table + } + + t.mu.Unlock() + + inProgress.init1() + if allowInProgress { + return table + } + + if inProgress.init2() { + t.mu.Lock() + delete(t.inProgress, typ) + t.tables.Store(typ, table) + t.mu.Unlock() + } + + return table +} + +func (t *tables) Get(typ reflect.Type) *Table { + return t.get(typ, false) +} + +func (t *tables) getByName(name types.Safe) *Table { + var found *Table + t.tables.Range(func(key, value interface{}) bool { + t := value.(*Table) + if t.SQLName == name { + found = t + return false + } + return true + }) + return found +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/types.go b/vendor/github.com/go-pg/pg/v10/orm/types.go new file mode 100644 index 000000000..c8e9ec375 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/types.go @@ -0,0 +1,48 @@ +package orm + +//nolint +const ( + // Date / Time + pgTypeTimestamp = "timestamp" // Timestamp without a time zone + pgTypeTimestampTz = "timestamptz" // Timestamp with a time zone + pgTypeDate = "date" // Date + pgTypeTime = "time" // Time without a time zone + pgTypeTimeTz = "time with time zone" // Time with a time zone + pgTypeInterval = "interval" // Time Interval + + // Network Addresses + pgTypeInet = "inet" // IPv4 or IPv6 hosts and networks + pgTypeCidr = "cidr" // IPv4 or IPv6 networks + pgTypeMacaddr = "macaddr" // MAC addresses + + // Boolean + pgTypeBoolean = "boolean" + + // Numeric Types + + // Floating Point Types + pgTypeReal = "real" // 4 byte floating point (6 digit precision) + pgTypeDoublePrecision = "double precision" // 8 byte floating point (15 digit precision) + + // Integer Types + pgTypeSmallint = "smallint" // 2 byte integer + pgTypeInteger = "integer" // 4 byte integer + pgTypeBigint = "bigint" // 8 byte integer + + // Serial Types + pgTypeSmallserial = "smallserial" // 2 byte autoincrementing integer + pgTypeSerial = "serial" // 4 byte autoincrementing integer + pgTypeBigserial = "bigserial" // 8 byte autoincrementing integer + + // Character Types + pgTypeVarchar = "varchar" // variable length string with limit + pgTypeChar = "char" // fixed length string (blank padded) + pgTypeText = "text" // variable length string without limit + + // JSON Types + pgTypeJSON = "json" // text representation of json data + pgTypeJSONB = "jsonb" // binary representation of json data + + // Binary Data Types + pgTypeBytea = "bytea" // binary string +) diff --git a/vendor/github.com/go-pg/pg/v10/orm/update.go b/vendor/github.com/go-pg/pg/v10/orm/update.go new file mode 100644 index 000000000..ce6396fd3 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/update.go @@ -0,0 +1,378 @@ +package orm + +import ( + "fmt" + "reflect" + "sort" + + "github.com/go-pg/pg/v10/types" +) + +type UpdateQuery struct { + q *Query + omitZero bool + placeholder bool +} + +var ( + _ QueryAppender = (*UpdateQuery)(nil) + _ QueryCommand = (*UpdateQuery)(nil) +) + +func NewUpdateQuery(q *Query, omitZero bool) *UpdateQuery { + return &UpdateQuery{ + q: q, + omitZero: omitZero, + } +} + +func (q *UpdateQuery) String() string { + b, err := q.AppendQuery(defaultFmter, nil) + if err != nil { + panic(err) + } + return string(b) +} + +func (q *UpdateQuery) Operation() QueryOp { + return UpdateOp +} + +func (q *UpdateQuery) Clone() QueryCommand { + return &UpdateQuery{ + q: q.q.Clone(), + omitZero: q.omitZero, + placeholder: q.placeholder, + } +} + +func (q *UpdateQuery) Query() *Query { + return q.q +} + +func (q *UpdateQuery) AppendTemplate(b []byte) ([]byte, error) { + cp := q.Clone().(*UpdateQuery) + cp.placeholder = true + return cp.AppendQuery(dummyFormatter{}, b) +} + +func (q *UpdateQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if q.q.stickyErr != nil { + return nil, q.q.stickyErr + } + + if len(q.q.with) > 0 { + b, err = q.q.appendWith(fmter, b) + if err != nil { + return nil, err + } + } + + b = append(b, "UPDATE "...) + + b, err = q.q.appendFirstTableWithAlias(fmter, b) + if err != nil { + return nil, err + } + + b, err = q.mustAppendSet(fmter, b) + if err != nil { + return nil, err + } + + isSliceModelWithData := q.q.isSliceModelWithData() + if isSliceModelWithData || q.q.hasMultiTables() { + b = append(b, " FROM "...) + b, err = q.q.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + + if isSliceModelWithData { + b, err = q.appendSliceModelData(fmter, b) + if err != nil { + return nil, err + } + } + } + + b, err = q.mustAppendWhere(fmter, b, isSliceModelWithData) + if err != nil { + return nil, err + } + + if len(q.q.returning) > 0 { + b, err = q.q.appendReturning(fmter, b) + if err != nil { + return nil, err + } + } + + return b, q.q.stickyErr +} + +func (q *UpdateQuery) mustAppendWhere( + fmter QueryFormatter, b []byte, isSliceModelWithData bool, +) (_ []byte, err error) { + b = append(b, " WHERE "...) + + if !isSliceModelWithData { + return q.q.mustAppendWhere(fmter, b) + } + + if len(q.q.where) > 0 { + return q.q.appendWhere(fmter, b) + } + + table := q.q.tableModel.Table() + err = table.checkPKs() + if err != nil { + return nil, err + } + + b = appendWhereColumnAndColumn(b, table.Alias, table.PKs) + return b, nil +} + +func (q *UpdateQuery) mustAppendSet(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if len(q.q.set) > 0 { + return q.q.appendSet(fmter, b) + } + + b = append(b, " SET "...) + + if m, ok := q.q.model.(*mapModel); ok { + return q.appendMapSet(b, m.m), nil + } + + if !q.q.hasTableModel() { + return nil, errModelNil + } + + value := q.q.tableModel.Value() + if value.Kind() == reflect.Struct { + b, err = q.appendSetStruct(fmter, b, value) + } else { + if value.Len() > 0 { + b, err = q.appendSetSlice(b) + } else { + err = fmt.Errorf("pg: can't bulk-update empty slice %s", value.Type()) + } + } + if err != nil { + return nil, err + } + + return b, nil +} + +func (q *UpdateQuery) appendMapSet(b []byte, m map[string]interface{}) []byte { + keys := make([]string, 0, len(m)) + + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + + for i, k := range keys { + if i > 0 { + b = append(b, ", "...) + } + + b = types.AppendIdent(b, k, 1) + b = append(b, " = "...) + if q.placeholder { + b = append(b, '?') + } else { + b = types.Append(b, m[k], 1) + } + } + + return b +} + +func (q *UpdateQuery) appendSetStruct(fmter QueryFormatter, b []byte, strct reflect.Value) ([]byte, error) { + fields, err := q.q.getFields() + if err != nil { + return nil, err + } + + if len(fields) == 0 { + fields = q.q.tableModel.Table().DataFields + } + + pos := len(b) + for _, f := range fields { + if q.omitZero && f.NullZero() && f.HasZeroValue(strct) { + continue + } + + if len(b) != pos { + b = append(b, ", "...) + pos = len(b) + } + + b = append(b, f.Column...) + b = append(b, " = "...) + + if q.placeholder { + b = append(b, '?') + continue + } + + app, ok := q.q.modelValues[f.SQLName] + if ok { + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } else { + b = f.AppendValue(b, strct, 1) + } + } + + for i, v := range q.q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + + b = append(b, v.column...) + b = append(b, " = "...) + + b, err = v.value.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *UpdateQuery) appendSetSlice(b []byte) ([]byte, error) { + fields, err := q.q.getFields() + if err != nil { + return nil, err + } + + if len(fields) == 0 { + fields = q.q.tableModel.Table().DataFields + } + + var table *Table + if q.omitZero { + table = q.q.tableModel.Table() + } + + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + b = append(b, f.Column...) + b = append(b, " = "...) + if q.omitZero && table != nil { + b = append(b, "COALESCE("...) + } + b = append(b, "_data."...) + b = append(b, f.Column...) + if q.omitZero && table != nil { + b = append(b, ", "...) + if table.Alias != table.SQLName { + b = append(b, table.Alias...) + b = append(b, '.') + } + b = append(b, f.Column...) + b = append(b, ")"...) + } + } + + return b, nil +} + +func (q *UpdateQuery) appendSliceModelData(fmter QueryFormatter, b []byte) ([]byte, error) { + columns, err := q.q.getDataFields() + if err != nil { + return nil, err + } + + if len(columns) > 0 { + columns = append(columns, q.q.tableModel.Table().PKs...) + } else { + columns = q.q.tableModel.Table().Fields + } + + return q.appendSliceValues(fmter, b, columns, q.q.tableModel.Value()) +} + +func (q *UpdateQuery) appendSliceValues( + fmter QueryFormatter, b []byte, fields []*Field, slice reflect.Value, +) (_ []byte, err error) { + b = append(b, "(VALUES ("...) + + if q.placeholder { + b, err = q.appendValues(fmter, b, fields, reflect.Value{}) + if err != nil { + return nil, err + } + } else { + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, "), ("...) + } + b, err = q.appendValues(fmter, b, fields, slice.Index(i)) + if err != nil { + return nil, err + } + } + } + + b = append(b, ")) AS _data("...) + b = appendColumns(b, "", fields) + b = append(b, ")"...) + + return b, nil +} + +func (q *UpdateQuery) appendValues( + fmter QueryFormatter, b []byte, fields []*Field, strct reflect.Value, +) (_ []byte, err error) { + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + app, ok := q.q.modelValues[f.SQLName] + if ok { + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + continue + } + + if q.placeholder { + b = append(b, '?') + } else { + b = f.AppendValue(b, indirect(strct), 1) + } + + b = append(b, "::"...) + b = append(b, f.SQLType...) + } + return b, nil +} + +func appendWhereColumnAndColumn(b []byte, alias types.Safe, fields []*Field) []byte { + for i, f := range fields { + if i > 0 { + b = append(b, " AND "...) + } + b = append(b, alias...) + b = append(b, '.') + b = append(b, f.Column...) + b = append(b, " = _data."...) + b = append(b, f.Column...) + } + return b +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/util.go b/vendor/github.com/go-pg/pg/v10/orm/util.go new file mode 100644 index 000000000..b7963ba0b --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/util.go @@ -0,0 +1,151 @@ +package orm + +import ( + "reflect" + + "github.com/go-pg/pg/v10/types" +) + +func indirect(v reflect.Value) reflect.Value { + switch v.Kind() { + case reflect.Interface: + return indirect(v.Elem()) + case reflect.Ptr: + return v.Elem() + default: + return v + } +} + +func indirectType(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +func sliceElemType(v reflect.Value) reflect.Type { + elemType := v.Type().Elem() + if elemType.Kind() == reflect.Interface && v.Len() > 0 { + return indirect(v.Index(0).Elem()).Type() + } + return indirectType(elemType) +} + +func typeByIndex(t reflect.Type, index []int) reflect.Type { + for _, x := range index { + switch t.Kind() { + case reflect.Ptr: + t = t.Elem() + case reflect.Slice: + t = indirectType(t.Elem()) + } + t = t.Field(x).Type + } + return indirectType(t) +} + +func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, ok bool) { + if len(index) == 1 { + return v.Field(index[0]), true + } + + for i, idx := range index { + if i > 0 { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return v, false + } + v = v.Elem() + } + } + v = v.Field(idx) + } + return v, true +} + +func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value { + if len(index) == 1 { + return v.Field(index[0]) + } + + for i, idx := range index { + if i > 0 { + v = indirectNil(v) + } + v = v.Field(idx) + } + return v +} + +func indirectNil(v reflect.Value) reflect.Value { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + return v +} + +func walk(v reflect.Value, index []int, fn func(reflect.Value)) { + v = reflect.Indirect(v) + switch v.Kind() { + case reflect.Slice: + sliceLen := v.Len() + for i := 0; i < sliceLen; i++ { + visitField(v.Index(i), index, fn) + } + default: + visitField(v, index, fn) + } +} + +func visitField(v reflect.Value, index []int, fn func(reflect.Value)) { + v = reflect.Indirect(v) + if len(index) > 0 { + v = v.Field(index[0]) + if v.Kind() == reflect.Ptr && v.IsNil() { + return + } + walk(v, index[1:], fn) + } else { + fn(v) + } +} + +func dstValues(model TableModel, fields []*Field) map[string][]reflect.Value { + fieldIndex := model.Relation().Field.Index + m := make(map[string][]reflect.Value) + var id []byte + walk(model.Root(), model.ParentIndex(), func(v reflect.Value) { + id = modelID(id[:0], v, fields) + m[string(id)] = append(m[string(id)], v.FieldByIndex(fieldIndex)) + }) + return m +} + +func modelID(b []byte, v reflect.Value, fields []*Field) []byte { + for i, f := range fields { + if i > 0 { + b = append(b, ',') + } + b = f.AppendValue(b, v, 0) + } + return b +} + +func appendColumns(b []byte, table types.Safe, fields []*Field) []byte { + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + if len(table) > 0 { + b = append(b, table...) + b = append(b, '.') + } + b = append(b, f.Column...) + } + return b +} |