diff options
Diffstat (limited to 'vendor/github.com/go-pg/pg/v10/orm/query.go')
-rw-r--r-- | vendor/github.com/go-pg/pg/v10/orm/query.go | 1680 |
1 files changed, 1680 insertions, 0 deletions
diff --git a/vendor/github.com/go-pg/pg/v10/orm/query.go b/vendor/github.com/go-pg/pg/v10/orm/query.go new file mode 100644 index 000000000..8a9231f65 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/query.go @@ -0,0 +1,1680 @@ +package orm + +import ( + "context" + "errors" + "fmt" + "io" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/types" +) + +type QueryOp string + +const ( + SelectOp QueryOp = "SELECT" + InsertOp QueryOp = "INSERT" + UpdateOp QueryOp = "UPDATE" + DeleteOp QueryOp = "DELETE" + CreateTableOp QueryOp = "CREATE TABLE" + DropTableOp QueryOp = "DROP TABLE" + CreateCompositeOp QueryOp = "CREATE COMPOSITE" + DropCompositeOp QueryOp = "DROP COMPOSITE" +) + +type queryFlag uint8 + +const ( + implicitModelFlag queryFlag = 1 << iota + deletedFlag + allWithDeletedFlag +) + +type withQuery struct { + name string + query QueryAppender +} + +type columnValue struct { + column string + value *SafeQueryAppender +} + +type union struct { + expr string + query *Query +} + +type Query struct { + ctx context.Context + db DB + stickyErr error + + model Model + tableModel TableModel + flags queryFlag + + with []withQuery + tables []QueryAppender + distinctOn []*SafeQueryAppender + columns []QueryAppender + set []QueryAppender + modelValues map[string]*SafeQueryAppender + extraValues []*columnValue + where []queryWithSepAppender + updWhere []queryWithSepAppender + group []QueryAppender + having []*SafeQueryAppender + union []*union + joins []QueryAppender + joinAppendOn func(app *condAppender) + order []QueryAppender + limit int + offset int + selFor *SafeQueryAppender + + onConflict *SafeQueryAppender + returning []*SafeQueryAppender +} + +func NewQuery(db DB, model ...interface{}) *Query { + ctx := context.Background() + if db != nil { + ctx = db.Context() + } + q := &Query{ctx: ctx} + return q.DB(db).Model(model...) +} + +func NewQueryContext(ctx context.Context, db DB, model ...interface{}) *Query { + return NewQuery(db, model...).Context(ctx) +} + +// New returns new zero Query bound to the current db. +func (q *Query) New() *Query { + clone := &Query{ + ctx: q.ctx, + db: q.db, + + model: q.model, + tableModel: cloneTableModelJoins(q.tableModel), + flags: q.flags, + } + return clone.withFlag(implicitModelFlag) +} + +// Clone clones the Query. +func (q *Query) Clone() *Query { + var modelValues map[string]*SafeQueryAppender + if len(q.modelValues) > 0 { + modelValues = make(map[string]*SafeQueryAppender, len(q.modelValues)) + for k, v := range q.modelValues { + modelValues[k] = v + } + } + + clone := &Query{ + ctx: q.ctx, + db: q.db, + stickyErr: q.stickyErr, + + model: q.model, + tableModel: cloneTableModelJoins(q.tableModel), + flags: q.flags, + + with: q.with[:len(q.with):len(q.with)], + tables: q.tables[:len(q.tables):len(q.tables)], + distinctOn: q.distinctOn[:len(q.distinctOn):len(q.distinctOn)], + columns: q.columns[:len(q.columns):len(q.columns)], + set: q.set[:len(q.set):len(q.set)], + modelValues: modelValues, + extraValues: q.extraValues[:len(q.extraValues):len(q.extraValues)], + where: q.where[:len(q.where):len(q.where)], + updWhere: q.updWhere[:len(q.updWhere):len(q.updWhere)], + joins: q.joins[:len(q.joins):len(q.joins)], + group: q.group[:len(q.group):len(q.group)], + having: q.having[:len(q.having):len(q.having)], + union: q.union[:len(q.union):len(q.union)], + order: q.order[:len(q.order):len(q.order)], + limit: q.limit, + offset: q.offset, + selFor: q.selFor, + + onConflict: q.onConflict, + returning: q.returning[:len(q.returning):len(q.returning)], + } + + return clone +} + +func cloneTableModelJoins(tm TableModel) TableModel { + switch tm := tm.(type) { + case *structTableModel: + if len(tm.joins) == 0 { + return tm + } + clone := *tm + clone.joins = clone.joins[:len(clone.joins):len(clone.joins)] + return &clone + case *sliceTableModel: + if len(tm.joins) == 0 { + return tm + } + clone := *tm + clone.joins = clone.joins[:len(clone.joins):len(clone.joins)] + return &clone + } + return tm +} + +func (q *Query) err(err error) *Query { + if q.stickyErr == nil { + q.stickyErr = err + } + return q +} + +func (q *Query) hasFlag(flag queryFlag) bool { + return hasFlag(q.flags, flag) +} + +func hasFlag(flags, flag queryFlag) bool { + return flags&flag != 0 +} + +func (q *Query) withFlag(flag queryFlag) *Query { + q.flags |= flag + return q +} + +func (q *Query) withoutFlag(flag queryFlag) *Query { + q.flags &= ^flag + return q +} + +func (q *Query) Context(c context.Context) *Query { + q.ctx = c + return q +} + +func (q *Query) DB(db DB) *Query { + q.db = db + return q +} + +func (q *Query) Model(model ...interface{}) *Query { + var err error + switch l := len(model); { + case l == 0: + q.model = nil + case l == 1: + q.model, err = NewModel(model[0]) + case l > 1: + q.model, err = NewModel(&model) + default: + panic("not reached") + } + if err != nil { + q = q.err(err) + } + + q.tableModel, _ = q.model.(TableModel) + + return q.withoutFlag(implicitModelFlag) +} + +func (q *Query) TableModel() TableModel { + return q.tableModel +} + +func (q *Query) isSoftDelete() bool { + if q.tableModel != nil { + return q.tableModel.Table().SoftDeleteField != nil && !q.hasFlag(allWithDeletedFlag) + } + return false +} + +// Deleted adds `WHERE deleted_at IS NOT NULL` clause for soft deleted models. +func (q *Query) Deleted() *Query { + if q.tableModel != nil { + if err := q.tableModel.Table().mustSoftDelete(); err != nil { + return q.err(err) + } + } + return q.withFlag(deletedFlag).withoutFlag(allWithDeletedFlag) +} + +// AllWithDeleted changes query to return all rows including soft deleted ones. +func (q *Query) AllWithDeleted() *Query { + if q.tableModel != nil { + if err := q.tableModel.Table().mustSoftDelete(); err != nil { + return q.err(err) + } + } + return q.withFlag(allWithDeletedFlag).withoutFlag(deletedFlag) +} + +// With adds subq as common table expression with the given name. +func (q *Query) With(name string, subq *Query) *Query { + return q._with(name, NewSelectQuery(subq)) +} + +func (q *Query) WithInsert(name string, subq *Query) *Query { + return q._with(name, NewInsertQuery(subq)) +} + +func (q *Query) WithUpdate(name string, subq *Query) *Query { + return q._with(name, NewUpdateQuery(subq, false)) +} + +func (q *Query) WithDelete(name string, subq *Query) *Query { + return q._with(name, NewDeleteQuery(subq)) +} + +func (q *Query) _with(name string, subq QueryAppender) *Query { + q.with = append(q.with, withQuery{ + name: name, + query: subq, + }) + return q +} + +// WrapWith creates new Query and adds to it current query as +// common table expression with the given name. +func (q *Query) WrapWith(name string) *Query { + wrapper := q.New() + wrapper.with = q.with + q.with = nil + wrapper = wrapper.With(name, q) + return wrapper +} + +func (q *Query) Table(tables ...string) *Query { + for _, table := range tables { + q.tables = append(q.tables, fieldAppender{table}) + } + return q +} + +func (q *Query) TableExpr(expr string, params ...interface{}) *Query { + q.tables = append(q.tables, SafeQuery(expr, params...)) + return q +} + +func (q *Query) Distinct() *Query { + q.distinctOn = make([]*SafeQueryAppender, 0) + return q +} + +func (q *Query) DistinctOn(expr string, params ...interface{}) *Query { + q.distinctOn = append(q.distinctOn, SafeQuery(expr, params...)) + return q +} + +// Column adds a column to the Query quoting it according to PostgreSQL rules. +// Does not expand params like ?TableAlias etc. +// ColumnExpr can be used to bypass quoting restriction or for params expansion. +// Column name can be: +// - column_name, +// - table_alias.column_name, +// - table_alias.*. +func (q *Query) Column(columns ...string) *Query { + for _, column := range columns { + if column == "_" { + if q.columns == nil { + q.columns = make([]QueryAppender, 0) + } + continue + } + + q.columns = append(q.columns, fieldAppender{column}) + } + return q +} + +// ColumnExpr adds column expression to the Query. +func (q *Query) ColumnExpr(expr string, params ...interface{}) *Query { + q.columns = append(q.columns, SafeQuery(expr, params...)) + return q +} + +// ExcludeColumn excludes a column from the list of to be selected columns. +func (q *Query) ExcludeColumn(columns ...string) *Query { + if q.columns == nil { + for _, f := range q.tableModel.Table().Fields { + q.columns = append(q.columns, fieldAppender{f.SQLName}) + } + } + + for _, col := range columns { + if !q.excludeColumn(col) { + return q.err(fmt.Errorf("pg: can't find column=%q", col)) + } + } + return q +} + +func (q *Query) excludeColumn(column string) bool { + for i := 0; i < len(q.columns); i++ { + app, ok := q.columns[i].(fieldAppender) + if ok && app.field == column { + q.columns = append(q.columns[:i], q.columns[i+1:]...) + return true + } + } + return false +} + +func (q *Query) getFields() ([]*Field, error) { + return q._getFields(false) +} + +func (q *Query) getDataFields() ([]*Field, error) { + return q._getFields(true) +} + +func (q *Query) _getFields(omitPKs bool) ([]*Field, error) { + table := q.tableModel.Table() + columns := make([]*Field, 0, len(q.columns)) + for _, col := range q.columns { + f, ok := col.(fieldAppender) + if !ok { + continue + } + + field, err := table.GetField(f.field) + if err != nil { + return nil, err + } + + if omitPKs && field.hasFlag(PrimaryKeyFlag) { + continue + } + + columns = append(columns, field) + } + return columns, nil +} + +// Relation adds a relation to the query. Relation name can be: +// - RelationName to select all columns, +// - RelationName.column_name, +// - RelationName._ to join relation without selecting relation columns. +func (q *Query) Relation(name string, apply ...func(*Query) (*Query, error)) *Query { + var fn func(*Query) (*Query, error) + if len(apply) == 1 { + fn = apply[0] + } else if len(apply) > 1 { + panic("only one apply function is supported") + } + + join := q.tableModel.Join(name, fn) + if join == nil { + return q.err(fmt.Errorf("%s does not have relation=%q", + q.tableModel.Table(), name)) + } + + if fn == nil { + return q + } + + switch join.Rel.Type { + case HasOneRelation, BelongsToRelation: + q.joinAppendOn = join.AppendOn + return q.Apply(fn) + default: + q.joinAppendOn = nil + return q + } +} + +func (q *Query) Set(set string, params ...interface{}) *Query { + q.set = append(q.set, SafeQuery(set, params...)) + return q +} + +// Value overwrites model value for the column in INSERT and UPDATE queries. +func (q *Query) Value(column string, value string, params ...interface{}) *Query { + if !q.hasTableModel() { + q.err(errModelNil) + return q + } + + table := q.tableModel.Table() + if _, ok := table.FieldsMap[column]; ok { + if q.modelValues == nil { + q.modelValues = make(map[string]*SafeQueryAppender) + } + q.modelValues[column] = SafeQuery(value, params...) + } else { + q.extraValues = append(q.extraValues, &columnValue{ + column: column, + value: SafeQuery(value, params...), + }) + } + + return q +} + +func (q *Query) Where(condition string, params ...interface{}) *Query { + q.addWhere(&condAppender{ + sep: " AND ", + cond: condition, + params: params, + }) + return q +} + +func (q *Query) WhereOr(condition string, params ...interface{}) *Query { + q.addWhere(&condAppender{ + sep: " OR ", + cond: condition, + params: params, + }) + return q +} + +// WhereGroup encloses conditions added in the function in parentheses. +// +// q.Where("TRUE"). +// WhereGroup(func(q *pg.Query) (*pg.Query, error) { +// q = q.WhereOr("FALSE").WhereOr("TRUE"). +// return q, nil +// }) +// +// generates +// +// WHERE TRUE AND (FALSE OR TRUE) +func (q *Query) WhereGroup(fn func(*Query) (*Query, error)) *Query { + return q.whereGroup(" AND ", fn) +} + +// WhereGroup encloses conditions added in the function in parentheses. +// +// q.Where("TRUE"). +// WhereNotGroup(func(q *pg.Query) (*pg.Query, error) { +// q = q.WhereOr("FALSE").WhereOr("TRUE"). +// return q, nil +// }) +// +// generates +// +// WHERE TRUE AND NOT (FALSE OR TRUE) +func (q *Query) WhereNotGroup(fn func(*Query) (*Query, error)) *Query { + return q.whereGroup(" AND NOT ", fn) +} + +// WhereOrGroup encloses conditions added in the function in parentheses. +// +// q.Where("TRUE"). +// WhereOrGroup(func(q *pg.Query) (*pg.Query, error) { +// q = q.Where("FALSE").Where("TRUE"). +// return q, nil +// }) +// +// generates +// +// WHERE TRUE OR (FALSE AND TRUE) +func (q *Query) WhereOrGroup(fn func(*Query) (*Query, error)) *Query { + return q.whereGroup(" OR ", fn) +} + +// WhereOrGroup encloses conditions added in the function in parentheses. +// +// q.Where("TRUE"). +// WhereOrGroup(func(q *pg.Query) (*pg.Query, error) { +// q = q.Where("FALSE").Where("TRUE"). +// return q, nil +// }) +// +// generates +// +// WHERE TRUE OR NOT (FALSE AND TRUE) +func (q *Query) WhereOrNotGroup(fn func(*Query) (*Query, error)) *Query { + return q.whereGroup(" OR NOT ", fn) +} + +func (q *Query) whereGroup(conj string, fn func(*Query) (*Query, error)) *Query { + saved := q.where + q.where = nil + + newq, err := fn(q) + if err != nil { + q.err(err) + return q + } + + if len(newq.where) == 0 { + newq.where = saved + return newq + } + + f := &condGroupAppender{ + sep: conj, + cond: newq.where, + } + newq.where = saved + newq.addWhere(f) + + return newq +} + +// WhereIn is a shortcut for Where and pg.In. +func (q *Query) WhereIn(where string, slice interface{}) *Query { + return q.Where(where, types.In(slice)) +} + +// WhereInMulti is a shortcut for Where and pg.InMulti. +func (q *Query) WhereInMulti(where string, values ...interface{}) *Query { + return q.Where(where, types.InMulti(values...)) +} + +func (q *Query) addWhere(f queryWithSepAppender) { + if q.onConflictDoUpdate() { + q.updWhere = append(q.updWhere, f) + } else { + q.where = append(q.where, f) + } +} + +// WherePK adds condition based on the model primary keys. +// Usually it is the same as: +// +// Where("id = ?id") +func (q *Query) WherePK() *Query { + if !q.hasTableModel() { + q.err(errModelNil) + return q + } + + if err := q.tableModel.Table().checkPKs(); err != nil { + q.err(err) + return q + } + + switch q.tableModel.Kind() { + case reflect.Struct: + q.where = append(q.where, wherePKStructQuery{q}) + return q + case reflect.Slice: + q.joins = append(q.joins, joinPKSliceQuery{q: q}) + q.where = append(q.where, wherePKSliceQuery{q: q}) + q = q.OrderExpr(`"_data"."ordering" ASC`) + return q + } + + panic("not reached") +} + +func (q *Query) Join(join string, params ...interface{}) *Query { + j := &joinQuery{ + join: SafeQuery(join, params...), + } + q.joins = append(q.joins, j) + q.joinAppendOn = j.AppendOn + return q +} + +// JoinOn appends join condition to the last join. +func (q *Query) JoinOn(condition string, params ...interface{}) *Query { + if q.joinAppendOn == nil { + q.err(errors.New("pg: no joins to apply JoinOn")) + return q + } + q.joinAppendOn(&condAppender{ + sep: " AND ", + cond: condition, + params: params, + }) + return q +} + +func (q *Query) JoinOnOr(condition string, params ...interface{}) *Query { + if q.joinAppendOn == nil { + q.err(errors.New("pg: no joins to apply JoinOn")) + return q + } + q.joinAppendOn(&condAppender{ + sep: " OR ", + cond: condition, + params: params, + }) + return q +} + +func (q *Query) Group(columns ...string) *Query { + for _, column := range columns { + q.group = append(q.group, fieldAppender{column}) + } + return q +} + +func (q *Query) GroupExpr(group string, params ...interface{}) *Query { + q.group = append(q.group, SafeQuery(group, params...)) + return q +} + +func (q *Query) Having(having string, params ...interface{}) *Query { + q.having = append(q.having, SafeQuery(having, params...)) + return q +} + +func (q *Query) Union(other *Query) *Query { + return q.addUnion(" UNION ", other) +} + +func (q *Query) UnionAll(other *Query) *Query { + return q.addUnion(" UNION ALL ", other) +} + +func (q *Query) Intersect(other *Query) *Query { + return q.addUnion(" INTERSECT ", other) +} + +func (q *Query) IntersectAll(other *Query) *Query { + return q.addUnion(" INTERSECT ALL ", other) +} + +func (q *Query) Except(other *Query) *Query { + return q.addUnion(" EXCEPT ", other) +} + +func (q *Query) ExceptAll(other *Query) *Query { + return q.addUnion(" EXCEPT ALL ", other) +} + +func (q *Query) addUnion(expr string, other *Query) *Query { + q.union = append(q.union, &union{ + expr: expr, + query: other, + }) + return q +} + +// Order adds sort order to the Query quoting column name. Does not expand params like ?TableAlias etc. +// OrderExpr can be used to bypass quoting restriction or for params expansion. +func (q *Query) Order(orders ...string) *Query { +loop: + for _, order := range orders { + if order == "" { + continue + } + ind := strings.Index(order, " ") + if ind != -1 { + field := order[:ind] + sort := order[ind+1:] + switch internal.UpperString(sort) { + case "ASC", "DESC", "ASC NULLS FIRST", "DESC NULLS FIRST", + "ASC NULLS LAST", "DESC NULLS LAST": + q = q.OrderExpr("? ?", types.Ident(field), types.Safe(sort)) + continue loop + } + } + + q.order = append(q.order, fieldAppender{order}) + } + return q +} + +// Order adds sort order to the Query. +func (q *Query) OrderExpr(order string, params ...interface{}) *Query { + if order != "" { + q.order = append(q.order, SafeQuery(order, params...)) + } + return q +} + +func (q *Query) Limit(n int) *Query { + q.limit = n + return q +} + +func (q *Query) Offset(n int) *Query { + q.offset = n + return q +} + +func (q *Query) OnConflict(s string, params ...interface{}) *Query { + q.onConflict = SafeQuery(s, params...) + return q +} + +func (q *Query) onConflictDoUpdate() bool { + return q.onConflict != nil && + strings.HasSuffix(internal.UpperString(q.onConflict.query), "DO UPDATE") +} + +// Returning adds a RETURNING clause to the query. +// +// `Returning("NULL")` can be used to suppress default returning clause +// generated by go-pg for INSERT queries to get values for null columns. +func (q *Query) Returning(s string, params ...interface{}) *Query { + q.returning = append(q.returning, SafeQuery(s, params...)) + return q +} + +func (q *Query) For(s string, params ...interface{}) *Query { + q.selFor = SafeQuery(s, params...) + return q +} + +// Apply calls the fn passing the Query as an argument. +func (q *Query) Apply(fn func(*Query) (*Query, error)) *Query { + qq, err := fn(q) + if err != nil { + q.err(err) + return q + } + return qq +} + +// Count returns number of rows matching the query using count aggregate function. +func (q *Query) Count() (int, error) { + if q.stickyErr != nil { + return 0, q.stickyErr + } + + var count int + _, err := q.db.QueryOneContext( + q.ctx, Scan(&count), q.countSelectQuery("count(*)"), q.tableModel) + return count, err +} + +func (q *Query) countSelectQuery(column string) *SelectQuery { + return &SelectQuery{ + q: q, + count: column, + } +} + +// First sorts rows by primary key and selects the first row. +// It is a shortcut for: +// +// q.OrderExpr("id ASC").Limit(1) +func (q *Query) First() error { + table := q.tableModel.Table() + + if err := table.checkPKs(); err != nil { + return err + } + + b := appendColumns(nil, table.Alias, table.PKs) + return q.OrderExpr(internal.BytesToString(b)).Limit(1).Select() +} + +// Last sorts rows by primary key and selects the last row. +// It is a shortcut for: +// +// q.OrderExpr("id DESC").Limit(1) +func (q *Query) Last() error { + table := q.tableModel.Table() + + if err := table.checkPKs(); err != nil { + return err + } + + // TODO: fix for multi columns + b := appendColumns(nil, table.Alias, table.PKs) + b = append(b, " DESC"...) + return q.OrderExpr(internal.BytesToString(b)).Limit(1).Select() +} + +// Select selects the model. +func (q *Query) Select(values ...interface{}) error { + if q.stickyErr != nil { + return q.stickyErr + } + + model, err := q.newModel(values) + if err != nil { + return err + } + + res, err := q.query(q.ctx, model, NewSelectQuery(q)) + if err != nil { + return err + } + + if res.RowsReturned() > 0 { + if q.tableModel != nil { + if err := q.selectJoins(q.tableModel.GetJoins()); err != nil { + return err + } + } + } + + if err := model.AfterSelect(q.ctx); err != nil { + return err + } + + return nil +} + +func (q *Query) newModel(values []interface{}) (Model, error) { + if len(values) > 0 { + return newScanModel(values) + } + return q.tableModel, nil +} + +func (q *Query) query(ctx context.Context, model Model, query interface{}) (Result, error) { + if _, ok := model.(useQueryOne); ok { + return q.db.QueryOneContext(ctx, model, query, q.tableModel) + } + return q.db.QueryContext(ctx, model, query, q.tableModel) +} + +// SelectAndCount runs Select and Count in two goroutines, +// waits for them to finish and returns the result. If query limit is -1 +// it does not select any data and only counts the results. +func (q *Query) SelectAndCount(values ...interface{}) (count int, firstErr error) { + if q.stickyErr != nil { + return 0, q.stickyErr + } + + var wg sync.WaitGroup + var mu sync.Mutex + + if q.limit >= 0 { + wg.Add(1) + go func() { + defer wg.Done() + err := q.Select(values...) + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + } + + wg.Add(1) + go func() { + defer wg.Done() + var err error + count, err = q.Count() + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + + wg.Wait() + return count, firstErr +} + +// SelectAndCountEstimate runs Select and CountEstimate in two goroutines, +// waits for them to finish and returns the result. If query limit is -1 +// it does not select any data and only counts the results. +func (q *Query) SelectAndCountEstimate(threshold int, values ...interface{}) (count int, firstErr error) { + if q.stickyErr != nil { + return 0, q.stickyErr + } + + var wg sync.WaitGroup + var mu sync.Mutex + + if q.limit >= 0 { + wg.Add(1) + go func() { + defer wg.Done() + err := q.Select(values...) + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + } + + wg.Add(1) + go func() { + defer wg.Done() + var err error + count, err = q.CountEstimate(threshold) + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + + wg.Wait() + return count, firstErr +} + +// ForEach calls the function for each row returned by the query +// without loading all rows into the memory. +// +// Function can accept a struct, a pointer to a struct, an orm.Model, +// or values for the columns in a row. Function must return an error. +func (q *Query) ForEach(fn interface{}) error { + m := newFuncModel(fn) + return q.Select(m) +} + +func (q *Query) forEachHasOneJoin(fn func(*join) error) error { + if q.tableModel == nil { + return nil + } + return q._forEachHasOneJoin(fn, q.tableModel.GetJoins()) +} + +func (q *Query) _forEachHasOneJoin(fn func(*join) error, joins []join) error { + for i := range joins { + j := &joins[i] + switch j.Rel.Type { + case HasOneRelation, BelongsToRelation: + err := fn(j) + if err != nil { + return err + } + + err = q._forEachHasOneJoin(fn, j.JoinModel.GetJoins()) + if err != nil { + return err + } + } + } + return nil +} + +func (q *Query) selectJoins(joins []join) error { + var err error + for i := range joins { + j := &joins[i] + if j.Rel.Type == HasOneRelation || j.Rel.Type == BelongsToRelation { + err = q.selectJoins(j.JoinModel.GetJoins()) + } else { + err = j.Select(q.db.Formatter(), q.New()) + } + if err != nil { + return err + } + } + return nil +} + +// Insert inserts the model. +func (q *Query) Insert(values ...interface{}) (Result, error) { + if q.stickyErr != nil { + return nil, q.stickyErr + } + + model, err := q.newModel(values) + if err != nil { + return nil, err + } + + ctx := q.ctx + + if q.tableModel != nil && q.tableModel.Table().hasFlag(beforeInsertHookFlag) { + ctx, err = q.tableModel.BeforeInsert(ctx) + if err != nil { + return nil, err + } + } + + query := NewInsertQuery(q) + res, err := q.returningQuery(ctx, model, query) + if err != nil { + return nil, err + } + + if q.tableModel != nil { + if err := q.tableModel.AfterInsert(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +// SelectOrInsert selects the model inserting one if it does not exist. +// It returns true when model was inserted. +func (q *Query) SelectOrInsert(values ...interface{}) (inserted bool, _ error) { + if q.stickyErr != nil { + return false, q.stickyErr + } + + var insertq *Query + var insertErr error + for i := 0; i < 5; i++ { + if i >= 2 { + dur := internal.RetryBackoff(i-2, 250*time.Millisecond, 5*time.Second) + if err := internal.Sleep(q.ctx, dur); err != nil { + return false, err + } + } + + err := q.Select(values...) + if err == nil { + return false, nil + } + if err != internal.ErrNoRows { + return false, err + } + + if insertq == nil { + insertq = q + if len(insertq.columns) > 0 { + insertq = insertq.Clone() + insertq.columns = nil + } + } + + res, err := insertq.Insert(values...) + if err != nil { + insertErr = err + if err == internal.ErrNoRows { + continue + } + if pgErr, ok := err.(internal.PGError); ok { + if pgErr.IntegrityViolation() { + continue + } + if pgErr.Field('C') == "55000" { + // Retry on "#55000 attempted to delete invisible tuple". + continue + } + } + return false, err + } + if res.RowsAffected() == 1 { + return true, nil + } + } + + err := fmt.Errorf( + "pg: SelectOrInsert: select returns no rows (insert fails with err=%q)", + insertErr) + return false, err +} + +// Update updates the model. +func (q *Query) Update(scan ...interface{}) (Result, error) { + return q.update(scan, false) +} + +// Update updates the model omitting fields with zero values such as: +// - empty string, +// - 0, +// - zero time, +// - empty map or slice, +// - byte array with all zeroes, +// - nil ptr, +// - types with method `IsZero() == true`. +func (q *Query) UpdateNotZero(scan ...interface{}) (Result, error) { + return q.update(scan, true) +} + +func (q *Query) update(values []interface{}, omitZero bool) (Result, error) { + if q.stickyErr != nil { + return nil, q.stickyErr + } + + model, err := q.newModel(values) + if err != nil { + return nil, err + } + + c := q.ctx + + if q.tableModel != nil { + c, err = q.tableModel.BeforeUpdate(c) + if err != nil { + return nil, err + } + } + + query := NewUpdateQuery(q, omitZero) + res, err := q.returningQuery(c, model, query) + if err != nil { + return nil, err + } + + if q.tableModel != nil { + err = q.tableModel.AfterUpdate(c) + if err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *Query) returningQuery(c context.Context, model Model, query interface{}) (Result, error) { + if !q.hasReturning() { + return q.db.QueryContext(c, model, query, q.tableModel) + } + if _, ok := model.(useQueryOne); ok { + return q.db.QueryOneContext(c, model, query, q.tableModel) + } + return q.db.QueryContext(c, model, query, q.tableModel) +} + +// Delete deletes the model. When model has deleted_at column the row +// is soft deleted instead. +func (q *Query) Delete(values ...interface{}) (Result, error) { + if q.tableModel == nil { + return q.ForceDelete(values...) + } + + table := q.tableModel.Table() + if table.SoftDeleteField == nil { + return q.ForceDelete(values...) + } + + clone := q.Clone() + if q.tableModel.IsNil() { + if table.SoftDeleteField.SQLType == pgTypeBigint { + clone = clone.Set("? = ?", table.SoftDeleteField.Column, time.Now().UnixNano()) + } else { + clone = clone.Set("? = ?", table.SoftDeleteField.Column, time.Now()) + } + } else { + if err := clone.tableModel.setSoftDeleteField(); err != nil { + return nil, err + } + clone = clone.Column(table.SoftDeleteField.SQLName) + } + return clone.Update(values...) +} + +// Delete forces delete of the model with deleted_at column. +func (q *Query) ForceDelete(values ...interface{}) (Result, error) { + if q.stickyErr != nil { + return nil, q.stickyErr + } + if q.tableModel == nil { + return nil, errModelNil + } + q = q.withFlag(deletedFlag) + + model, err := q.newModel(values) + if err != nil { + return nil, err + } + + ctx := q.ctx + + if q.tableModel != nil { + ctx, err = q.tableModel.BeforeDelete(ctx) + if err != nil { + return nil, err + } + } + + res, err := q.returningQuery(ctx, model, NewDeleteQuery(q)) + if err != nil { + return nil, err + } + + if q.tableModel != nil { + if err := q.tableModel.AfterDelete(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *Query) CreateTable(opt *CreateTableOptions) error { + _, err := q.db.ExecContext(q.ctx, NewCreateTableQuery(q, opt)) + return err +} + +func (q *Query) DropTable(opt *DropTableOptions) error { + _, err := q.db.ExecContext(q.ctx, NewDropTableQuery(q, opt)) + return err +} + +func (q *Query) CreateComposite(opt *CreateCompositeOptions) error { + _, err := q.db.ExecContext(q.ctx, NewCreateCompositeQuery(q, opt)) + return err +} + +func (q *Query) DropComposite(opt *DropCompositeOptions) error { + _, err := q.db.ExecContext(q.ctx, NewDropCompositeQuery(q, opt)) + return err +} + +// Exec is an alias for DB.Exec. +func (q *Query) Exec(query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.ExecContext(q.ctx, query, params...) +} + +// ExecOne is an alias for DB.ExecOne. +func (q *Query) ExecOne(query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.ExecOneContext(q.ctx, query, params...) +} + +// Query is an alias for DB.Query. +func (q *Query) Query(model, query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.QueryContext(q.ctx, model, query, params...) +} + +// QueryOne is an alias for DB.QueryOne. +func (q *Query) QueryOne(model, query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.QueryOneContext(q.ctx, model, query, params...) +} + +// CopyFrom is an alias from DB.CopyFrom. +func (q *Query) CopyFrom(r io.Reader, query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.CopyFrom(r, query, params...) +} + +// CopyTo is an alias from DB.CopyTo. +func (q *Query) CopyTo(w io.Writer, query interface{}, params ...interface{}) (Result, error) { + params = append(params, q.tableModel) + return q.db.CopyTo(w, query, params...) +} + +var _ QueryAppender = (*Query)(nil) + +func (q *Query) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + return NewSelectQuery(q).AppendQuery(fmter, b) +} + +// Exists returns true or false depending if there are any rows matching the query. +func (q *Query) Exists() (bool, error) { + q = q.Clone() // copy to not change original query + q.columns = []QueryAppender{SafeQuery("1")} + q.order = nil + q.limit = 1 + res, err := q.db.ExecContext(q.ctx, NewSelectQuery(q)) + if err != nil { + return false, err + } + return res.RowsAffected() > 0, nil +} + +func (q *Query) hasTableModel() bool { + return q.tableModel != nil && !q.tableModel.IsNil() +} + +func (q *Query) hasExplicitTableModel() bool { + return q.tableModel != nil && !q.hasFlag(implicitModelFlag) +} + +func (q *Query) modelHasTableName() bool { + return q.hasExplicitTableModel() && q.tableModel.Table().SQLName != "" +} + +func (q *Query) modelHasTableAlias() bool { + return q.hasExplicitTableModel() && q.tableModel.Table().Alias != "" +} + +func (q *Query) hasTables() bool { + return q.modelHasTableName() || len(q.tables) > 0 +} + +func (q *Query) appendFirstTable(fmter QueryFormatter, b []byte) ([]byte, error) { + if q.modelHasTableName() { + return fmter.FormatQuery(b, string(q.tableModel.Table().SQLName)), nil + } + if len(q.tables) > 0 { + return q.tables[0].AppendQuery(fmter, b) + } + return b, nil +} + +func (q *Query) appendFirstTableWithAlias(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if q.modelHasTableName() { + table := q.tableModel.Table() + b = fmter.FormatQuery(b, string(table.SQLName)) + if table.Alias != table.SQLName { + b = append(b, " AS "...) + b = append(b, table.Alias...) + } + return b, nil + } + + if len(q.tables) > 0 { + b, err = q.tables[0].AppendQuery(fmter, b) + if err != nil { + return nil, err + } + if q.modelHasTableAlias() { + table := q.tableModel.Table() + if table.Alias != table.SQLName { + b = append(b, " AS "...) + b = append(b, table.Alias...) + } + } + } + + return b, nil +} + +func (q *Query) hasMultiTables() bool { + if q.modelHasTableName() { + return len(q.tables) >= 1 + } + return len(q.tables) >= 2 +} + +func (q *Query) appendOtherTables(fmter QueryFormatter, b []byte) (_ []byte, err error) { + tables := q.tables + if !q.modelHasTableName() { + tables = tables[1:] + } + for i, f := range tables { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *Query) appendColumns(fmter QueryFormatter, b []byte) (_ []byte, err error) { + for i, f := range q.columns { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *Query) mustAppendWhere(fmter QueryFormatter, b []byte) ([]byte, error) { + if len(q.where) == 0 { + err := errors.New( + "pg: Update and Delete queries require Where clause (try WherePK)") + return nil, err + } + return q.appendWhere(fmter, b) +} + +func (q *Query) appendWhere(fmter QueryFormatter, b []byte) (_ []byte, err error) { + isSoftDelete := q.isSoftDelete() + + if len(q.where) > 0 { + if isSoftDelete { + b = append(b, '(') + } + + b, err = q._appendWhere(fmter, b, q.where) + if err != nil { + return nil, err + } + + if isSoftDelete { + b = append(b, ')') + } + } + + if isSoftDelete { + if len(q.where) > 0 { + b = append(b, " AND "...) + } + b = append(b, q.tableModel.Table().Alias...) + b = q.appendSoftDelete(b) + } + + return b, nil +} + +func (q *Query) appendSoftDelete(b []byte) []byte { + b = append(b, '.') + b = append(b, q.tableModel.Table().SoftDeleteField.Column...) + if q.hasFlag(deletedFlag) { + b = append(b, " IS NOT NULL"...) + } else { + b = append(b, " IS NULL"...) + } + return b +} + +func (q *Query) appendUpdWhere(fmter QueryFormatter, b []byte) ([]byte, error) { + return q._appendWhere(fmter, b, q.updWhere) +} + +func (q *Query) _appendWhere( + fmter QueryFormatter, b []byte, where []queryWithSepAppender, +) (_ []byte, err error) { + for i, f := range where { + start := len(b) + + if i > 0 { + b = f.AppendSep(b) + } + + before := len(b) + + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + if len(b) == before { + b = b[:start] + } + } + return b, nil +} + +func (q *Query) appendSet(fmter QueryFormatter, b []byte) (_ []byte, err error) { + b = append(b, " SET "...) + for i, f := range q.set { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *Query) hasReturning() bool { + if len(q.returning) == 0 { + return false + } + if len(q.returning) == 1 { + switch q.returning[0].query { + case "null", "NULL": + return false + } + } + return true +} + +func (q *Query) appendReturning(fmter QueryFormatter, b []byte) (_ []byte, err error) { + if !q.hasReturning() { + return b, nil + } + + b = append(b, " RETURNING "...) + for i, f := range q.returning { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *Query) appendWith(fmter QueryFormatter, b []byte) (_ []byte, err error) { + b = append(b, "WITH "...) + for i, with := range q.with { + if i > 0 { + b = append(b, ", "...) + } + b = types.AppendIdent(b, with.name, 1) + b = append(b, " AS ("...) + + b, err = with.query.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, ')') + } + b = append(b, ' ') + return b, nil +} + +func (q *Query) isSliceModelWithData() bool { + if !q.hasTableModel() { + return false + } + m, ok := q.tableModel.(*sliceTableModel) + return ok && m.sliceLen > 0 +} + +//------------------------------------------------------------------------------ + +type wherePKStructQuery struct { + q *Query +} + +var _ queryWithSepAppender = (*wherePKStructQuery)(nil) + +func (wherePKStructQuery) AppendSep(b []byte) []byte { + return append(b, " AND "...) +} + +func (q wherePKStructQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + table := q.q.tableModel.Table() + value := q.q.tableModel.Value() + return appendColumnAndValue(fmter, b, value, table.Alias, table.PKs), nil +} + +func appendColumnAndValue( + fmter QueryFormatter, b []byte, v reflect.Value, alias types.Safe, fields []*Field, +) []byte { + isPlaceholder := isTemplateFormatter(fmter) + for i, f := range fields { + if i > 0 { + b = append(b, " AND "...) + } + b = append(b, alias...) + b = append(b, '.') + b = append(b, f.Column...) + b = append(b, " = "...) + if isPlaceholder { + b = append(b, '?') + } else { + b = f.AppendValue(b, v, 1) + } + } + return b +} + +//------------------------------------------------------------------------------ + +type wherePKSliceQuery struct { + q *Query +} + +var _ queryWithSepAppender = (*wherePKSliceQuery)(nil) + +func (wherePKSliceQuery) AppendSep(b []byte) []byte { + return append(b, " AND "...) +} + +func (q wherePKSliceQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + table := q.q.tableModel.Table() + + for i, f := range table.PKs { + if i > 0 { + b = append(b, " AND "...) + } + b = append(b, table.Alias...) + b = append(b, '.') + b = append(b, f.Column...) + b = append(b, " = "...) + b = append(b, `"_data".`...) + b = append(b, f.Column...) + } + + return b, nil +} + +type joinPKSliceQuery struct { + q *Query +} + +var _ QueryAppender = (*joinPKSliceQuery)(nil) + +func (q joinPKSliceQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + table := q.q.tableModel.Table() + slice := q.q.tableModel.Value() + + b = append(b, " JOIN (VALUES "...) + + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, ", "...) + } + + el := indirect(slice.Index(i)) + + b = append(b, '(') + for i, f := range table.PKs { + if i > 0 { + b = append(b, ", "...) + } + + b = f.AppendValue(b, el, 1) + + if f.UserSQLType != "" { + b = append(b, "::"...) + b = append(b, f.SQLType...) + } + } + + b = append(b, ", "...) + b = strconv.AppendInt(b, int64(i), 10) + + b = append(b, ')') + } + + b = append(b, `) AS "_data" (`...) + + for i, f := range table.PKs { + if i > 0 { + b = append(b, ", "...) + } + b = append(b, f.Column...) + } + + b = append(b, ", "...) + b = append(b, `"ordering"`...) + b = append(b, ") ON TRUE"...) + + return b, nil +} |