diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_select.go')
| -rw-r--r-- | vendor/github.com/uptrace/bun/query_select.go | 162 |
1 files changed, 147 insertions, 15 deletions
diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go index 1ef7e3bb1..db7f42df1 100644 --- a/vendor/github.com/uptrace/bun/query_select.go +++ b/vendor/github.com/uptrace/bun/query_select.go @@ -354,7 +354,7 @@ func (q *SelectQuery) JoinOnOr(cond string, args ...interface{}) *SelectQuery { func (q *SelectQuery) joinOn(cond string, args []interface{}, sep string) *SelectQuery { if len(q.joins) == 0 { - q.err = errors.New("bun: query has no joins") + q.setErr(errors.New("bun: query has no joins")) return q } j := &q.joins[len(q.joins)-1] @@ -791,6 +791,9 @@ func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) { return nil, err } + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { return nil, err @@ -812,6 +815,9 @@ func (q *SelectQuery) Exec(ctx context.Context, dest ...interface{}) (res sql.Re return nil, err } + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { return nil, err @@ -872,6 +878,9 @@ func (q *SelectQuery) scanResult(ctx context.Context, dest ...interface{}) (sql. return nil, err } + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { return nil, err @@ -924,6 +933,9 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) { return 0, q.err } + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + qq := countQuery{q} queryBytes, err := qq.AppendQuery(q.db.fmter, nil) @@ -967,27 +979,27 @@ func (q *SelectQuery) scanAndCountConcurrently( var mu sync.Mutex var firstErr error - if q.limit >= 0 { - wg.Add(1) - go func() { - defer wg.Done() - - if err := q.Scan(ctx, dest...); err != nil { - mu.Lock() - if firstErr == nil { - firstErr = err - } - mu.Unlock() + countQuery := q.Clone() + + wg.Add(1) + go func() { + defer wg.Done() + + if err := q.Scan(ctx, dest...); err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err } - }() - } + mu.Unlock() + } + }() wg.Add(1) go func() { defer wg.Done() var err error - count, err = q.Count(ctx) + count, err = countQuery.Count(ctx) if err != nil { mu.Lock() if firstErr == nil { @@ -1028,6 +1040,9 @@ func (q *SelectQuery) Exists(ctx context.Context) (bool, error) { } func (q *SelectQuery) selectExists(ctx context.Context) (bool, error) { + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + qq := selectExistsQuery{q} queryBytes, err := qq.AppendQuery(q.db.fmter, nil) @@ -1047,6 +1062,9 @@ func (q *SelectQuery) selectExists(ctx context.Context) (bool, error) { } func (q *SelectQuery) whereExists(ctx context.Context) (bool, error) { + // if a comment is propagated via the context, use it + setCommentFromContext(ctx, q) + qq := whereExistsQuery{q} queryBytes, err := qq.AppendQuery(q.db.fmter, nil) @@ -1077,6 +1095,120 @@ func (q *SelectQuery) String() string { return string(buf) } +func (q *SelectQuery) Clone() *SelectQuery { + if q == nil { + return nil + } + + cloneArgs := func(args []schema.QueryWithArgs) []schema.QueryWithArgs { + if len(args) == 0 { + return nil + } + clone := make([]schema.QueryWithArgs, len(args)) + copy(clone, args) + return clone + } + cloneHints := func(hints *indexHints) *indexHints { + if hints == nil { + return nil + } + return &indexHints{ + names: cloneArgs(hints.names), + forJoin: cloneArgs(hints.forJoin), + forOrderBy: cloneArgs(hints.forOrderBy), + forGroupBy: cloneArgs(hints.forGroupBy), + } + } + + clone := &SelectQuery{ + whereBaseQuery: whereBaseQuery{ + baseQuery: baseQuery{ + db: q.db, + table: q.table, + model: q.model, + tableModel: q.tableModel, + with: make([]withQuery, len(q.with)), + tables: cloneArgs(q.tables), + columns: cloneArgs(q.columns), + modelTableName: q.modelTableName, + }, + where: make([]schema.QueryWithSep, len(q.where)), + }, + + idxHintsQuery: idxHintsQuery{ + use: cloneHints(q.idxHintsQuery.use), + ignore: cloneHints(q.idxHintsQuery.ignore), + force: cloneHints(q.idxHintsQuery.force), + }, + + orderLimitOffsetQuery: orderLimitOffsetQuery{ + order: cloneArgs(q.order), + limit: q.limit, + offset: q.offset, + }, + + distinctOn: cloneArgs(q.distinctOn), + joins: make([]joinQuery, len(q.joins)), + group: cloneArgs(q.group), + having: cloneArgs(q.having), + union: make([]union, len(q.union)), + comment: q.comment, + } + + for i, w := range q.with { + clone.with[i] = withQuery{ + name: w.name, + recursive: w.recursive, + query: w.query, // TODO: maybe clone is need + } + } + + if !q.modelTableName.IsZero() { + clone.modelTableName = schema.SafeQuery( + q.modelTableName.Query, + append([]any(nil), q.modelTableName.Args...), + ) + } + + for i, w := range q.where { + clone.where[i] = schema.SafeQueryWithSep( + w.Query, + append([]any(nil), w.Args...), + w.Sep, + ) + } + + for i, j := range q.joins { + clone.joins[i] = joinQuery{ + join: schema.SafeQuery(j.join.Query, append([]any(nil), j.join.Args...)), + on: make([]schema.QueryWithSep, len(j.on)), + } + for k, on := range j.on { + clone.joins[i].on[k] = schema.SafeQueryWithSep( + on.Query, + append([]any(nil), on.Args...), + on.Sep, + ) + } + } + + for i, u := range q.union { + clone.union[i] = union{ + expr: u.expr, + query: u.query.Clone(), + } + } + + if !q.selFor.IsZero() { + clone.selFor = schema.SafeQuery( + q.selFor.Query, + append([]any(nil), q.selFor.Args...), + ) + } + + return clone +} + //------------------------------------------------------------------------------ func (q *SelectQuery) QueryBuilder() QueryBuilder { |
