summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_select.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_select.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_select.go162
1 files changed, 147 insertions, 15 deletions
diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go
index 1ef7e3bb1..db7f42df1 100644
--- a/vendor/github.com/uptrace/bun/query_select.go
+++ b/vendor/github.com/uptrace/bun/query_select.go
@@ -354,7 +354,7 @@ func (q *SelectQuery) JoinOnOr(cond string, args ...interface{}) *SelectQuery {
func (q *SelectQuery) joinOn(cond string, args []interface{}, sep string) *SelectQuery {
if len(q.joins) == 0 {
- q.err = errors.New("bun: query has no joins")
+ q.setErr(errors.New("bun: query has no joins"))
return q
}
j := &q.joins[len(q.joins)-1]
@@ -791,6 +791,9 @@ func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) {
return nil, err
}
+ // if a comment is propagated via the context, use it
+ setCommentFromContext(ctx, q)
+
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
@@ -812,6 +815,9 @@ func (q *SelectQuery) Exec(ctx context.Context, dest ...interface{}) (res sql.Re
return nil, err
}
+ // if a comment is propagated via the context, use it
+ setCommentFromContext(ctx, q)
+
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
@@ -872,6 +878,9 @@ func (q *SelectQuery) scanResult(ctx context.Context, dest ...interface{}) (sql.
return nil, err
}
+ // if a comment is propagated via the context, use it
+ setCommentFromContext(ctx, q)
+
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
return nil, err
@@ -924,6 +933,9 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) {
return 0, q.err
}
+ // if a comment is propagated via the context, use it
+ setCommentFromContext(ctx, q)
+
qq := countQuery{q}
queryBytes, err := qq.AppendQuery(q.db.fmter, nil)
@@ -967,27 +979,27 @@ func (q *SelectQuery) scanAndCountConcurrently(
var mu sync.Mutex
var firstErr error
- if q.limit >= 0 {
- wg.Add(1)
- go func() {
- defer wg.Done()
-
- if err := q.Scan(ctx, dest...); err != nil {
- mu.Lock()
- if firstErr == nil {
- firstErr = err
- }
- mu.Unlock()
+ countQuery := q.Clone()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ if err := q.Scan(ctx, dest...); err != nil {
+ mu.Lock()
+ if firstErr == nil {
+ firstErr = err
}
- }()
- }
+ mu.Unlock()
+ }
+ }()
wg.Add(1)
go func() {
defer wg.Done()
var err error
- count, err = q.Count(ctx)
+ count, err = countQuery.Count(ctx)
if err != nil {
mu.Lock()
if firstErr == nil {
@@ -1028,6 +1040,9 @@ func (q *SelectQuery) Exists(ctx context.Context) (bool, error) {
}
func (q *SelectQuery) selectExists(ctx context.Context) (bool, error) {
+ // if a comment is propagated via the context, use it
+ setCommentFromContext(ctx, q)
+
qq := selectExistsQuery{q}
queryBytes, err := qq.AppendQuery(q.db.fmter, nil)
@@ -1047,6 +1062,9 @@ func (q *SelectQuery) selectExists(ctx context.Context) (bool, error) {
}
func (q *SelectQuery) whereExists(ctx context.Context) (bool, error) {
+ // if a comment is propagated via the context, use it
+ setCommentFromContext(ctx, q)
+
qq := whereExistsQuery{q}
queryBytes, err := qq.AppendQuery(q.db.fmter, nil)
@@ -1077,6 +1095,120 @@ func (q *SelectQuery) String() string {
return string(buf)
}
+func (q *SelectQuery) Clone() *SelectQuery {
+ if q == nil {
+ return nil
+ }
+
+ cloneArgs := func(args []schema.QueryWithArgs) []schema.QueryWithArgs {
+ if len(args) == 0 {
+ return nil
+ }
+ clone := make([]schema.QueryWithArgs, len(args))
+ copy(clone, args)
+ return clone
+ }
+ cloneHints := func(hints *indexHints) *indexHints {
+ if hints == nil {
+ return nil
+ }
+ return &indexHints{
+ names: cloneArgs(hints.names),
+ forJoin: cloneArgs(hints.forJoin),
+ forOrderBy: cloneArgs(hints.forOrderBy),
+ forGroupBy: cloneArgs(hints.forGroupBy),
+ }
+ }
+
+ clone := &SelectQuery{
+ whereBaseQuery: whereBaseQuery{
+ baseQuery: baseQuery{
+ db: q.db,
+ table: q.table,
+ model: q.model,
+ tableModel: q.tableModel,
+ with: make([]withQuery, len(q.with)),
+ tables: cloneArgs(q.tables),
+ columns: cloneArgs(q.columns),
+ modelTableName: q.modelTableName,
+ },
+ where: make([]schema.QueryWithSep, len(q.where)),
+ },
+
+ idxHintsQuery: idxHintsQuery{
+ use: cloneHints(q.idxHintsQuery.use),
+ ignore: cloneHints(q.idxHintsQuery.ignore),
+ force: cloneHints(q.idxHintsQuery.force),
+ },
+
+ orderLimitOffsetQuery: orderLimitOffsetQuery{
+ order: cloneArgs(q.order),
+ limit: q.limit,
+ offset: q.offset,
+ },
+
+ distinctOn: cloneArgs(q.distinctOn),
+ joins: make([]joinQuery, len(q.joins)),
+ group: cloneArgs(q.group),
+ having: cloneArgs(q.having),
+ union: make([]union, len(q.union)),
+ comment: q.comment,
+ }
+
+ for i, w := range q.with {
+ clone.with[i] = withQuery{
+ name: w.name,
+ recursive: w.recursive,
+ query: w.query, // TODO: maybe clone is need
+ }
+ }
+
+ if !q.modelTableName.IsZero() {
+ clone.modelTableName = schema.SafeQuery(
+ q.modelTableName.Query,
+ append([]any(nil), q.modelTableName.Args...),
+ )
+ }
+
+ for i, w := range q.where {
+ clone.where[i] = schema.SafeQueryWithSep(
+ w.Query,
+ append([]any(nil), w.Args...),
+ w.Sep,
+ )
+ }
+
+ for i, j := range q.joins {
+ clone.joins[i] = joinQuery{
+ join: schema.SafeQuery(j.join.Query, append([]any(nil), j.join.Args...)),
+ on: make([]schema.QueryWithSep, len(j.on)),
+ }
+ for k, on := range j.on {
+ clone.joins[i].on[k] = schema.SafeQueryWithSep(
+ on.Query,
+ append([]any(nil), on.Args...),
+ on.Sep,
+ )
+ }
+ }
+
+ for i, u := range q.union {
+ clone.union[i] = union{
+ expr: u.expr,
+ query: u.query.Clone(),
+ }
+ }
+
+ if !q.selFor.IsZero() {
+ clone.selFor = schema.SafeQuery(
+ q.selFor.Query,
+ append([]any(nil), q.selFor.Args...),
+ )
+ }
+
+ return clone
+}
+
//------------------------------------------------------------------------------
func (q *SelectQuery) QueryBuilder() QueryBuilder {