summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_select.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_select.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_select.go30
1 files changed, 18 insertions, 12 deletions
diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go
index 11761bb96..1ef7e3bb1 100644
--- a/vendor/github.com/uptrace/bun/query_select.go
+++ b/vendor/github.com/uptrace/bun/query_select.go
@@ -41,8 +41,7 @@ func NewSelectQuery(db *DB) *SelectQuery {
return &SelectQuery{
whereBaseQuery: whereBaseQuery{
baseQuery: baseQuery{
- db: db,
- conn: db.DB,
+ db: db,
},
},
}
@@ -73,12 +72,12 @@ func (q *SelectQuery) Apply(fns ...func(*SelectQuery) *SelectQuery) *SelectQuery
return q
}
-func (q *SelectQuery) With(name string, query schema.QueryAppender) *SelectQuery {
+func (q *SelectQuery) With(name string, query Query) *SelectQuery {
q.addWith(name, query, false)
return q
}
-func (q *SelectQuery) WithRecursive(name string, query schema.QueryAppender) *SelectQuery {
+func (q *SelectQuery) WithRecursive(name string, query Query) *SelectQuery {
q.addWith(name, query, true)
return q
}
@@ -537,6 +536,13 @@ func (q *SelectQuery) appendQuery(
return nil, err
}
+ if err := q.forEachInlineRelJoin(func(j *relationJoin) error {
+ j.applyTo(q)
+ return nil
+ }); err != nil {
+ return nil, err
+ }
+
b = append(b, "SELECT "...)
if len(q.distinctOn) > 0 {
@@ -730,8 +736,6 @@ func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte,
func (q *SelectQuery) appendInlineRelColumns(
fmter schema.Formatter, b []byte, join *relationJoin,
) (_ []byte, err error) {
- join.applyTo(q)
-
if join.columns != nil {
table := join.JoinModel.Table()
for i, col := range join.columns {
@@ -795,7 +799,7 @@ func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) {
query := internal.String(queryBytes)
ctx, event := q.db.beforeQuery(ctx, q, query, nil, query, q.model)
- rows, err := q.conn.QueryContext(ctx, query)
+ rows, err := q.resolveConn(q).QueryContext(ctx, query)
q.db.afterQuery(ctx, event, nil, err)
return rows, err
}
@@ -931,7 +935,7 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) {
ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model)
var num int
- err = q.conn.QueryRowContext(ctx, query).Scan(&num)
+ err = q.resolveConn(q).QueryRowContext(ctx, query).Scan(&num)
q.db.afterQuery(ctx, event, nil, err)
@@ -949,13 +953,15 @@ func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (in
return int(n), nil
}
}
- if _, ok := q.conn.(*DB); ok {
- return q.scanAndCountConc(ctx, dest...)
+ if q.conn == nil {
+ return q.scanAndCountConcurrently(ctx, dest...)
}
return q.scanAndCountSeq(ctx, dest...)
}
-func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{}) (int, error) {
+func (q *SelectQuery) scanAndCountConcurrently(
+ ctx context.Context, dest ...interface{},
+) (int, error) {
var count int
var wg sync.WaitGroup
var mu sync.Mutex
@@ -1033,7 +1039,7 @@ func (q *SelectQuery) selectExists(ctx context.Context) (bool, error) {
ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model)
var exists bool
- err = q.conn.QueryRowContext(ctx, query).Scan(&exists)
+ err = q.resolveConn(q).QueryRowContext(ctx, query).Scan(&exists)
q.db.afterQuery(ctx, event, nil, err)