diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_select.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/query_select.go | 30 |
1 files changed, 18 insertions, 12 deletions
diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go index 11761bb96..1ef7e3bb1 100644 --- a/vendor/github.com/uptrace/bun/query_select.go +++ b/vendor/github.com/uptrace/bun/query_select.go @@ -41,8 +41,7 @@ func NewSelectQuery(db *DB) *SelectQuery { return &SelectQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } @@ -73,12 +72,12 @@ func (q *SelectQuery) Apply(fns ...func(*SelectQuery) *SelectQuery) *SelectQuery return q } -func (q *SelectQuery) With(name string, query schema.QueryAppender) *SelectQuery { +func (q *SelectQuery) With(name string, query Query) *SelectQuery { q.addWith(name, query, false) return q } -func (q *SelectQuery) WithRecursive(name string, query schema.QueryAppender) *SelectQuery { +func (q *SelectQuery) WithRecursive(name string, query Query) *SelectQuery { q.addWith(name, query, true) return q } @@ -537,6 +536,13 @@ func (q *SelectQuery) appendQuery( return nil, err } + if err := q.forEachInlineRelJoin(func(j *relationJoin) error { + j.applyTo(q) + return nil + }); err != nil { + return nil, err + } + b = append(b, "SELECT "...) if len(q.distinctOn) > 0 { @@ -730,8 +736,6 @@ func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, func (q *SelectQuery) appendInlineRelColumns( fmter schema.Formatter, b []byte, join *relationJoin, ) (_ []byte, err error) { - join.applyTo(q) - if join.columns != nil { table := join.JoinModel.Table() for i, col := range join.columns { @@ -795,7 +799,7 @@ func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) { query := internal.String(queryBytes) ctx, event := q.db.beforeQuery(ctx, q, query, nil, query, q.model) - rows, err := q.conn.QueryContext(ctx, query) + rows, err := q.resolveConn(q).QueryContext(ctx, query) q.db.afterQuery(ctx, event, nil, err) return rows, err } @@ -931,7 +935,7 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) { ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model) var num int - err = q.conn.QueryRowContext(ctx, query).Scan(&num) + err = q.resolveConn(q).QueryRowContext(ctx, query).Scan(&num) q.db.afterQuery(ctx, event, nil, err) @@ -949,13 +953,15 @@ func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (in return int(n), nil } } - if _, ok := q.conn.(*DB); ok { - return q.scanAndCountConc(ctx, dest...) + if q.conn == nil { + return q.scanAndCountConcurrently(ctx, dest...) } return q.scanAndCountSeq(ctx, dest...) } -func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{}) (int, error) { +func (q *SelectQuery) scanAndCountConcurrently( + ctx context.Context, dest ...interface{}, +) (int, error) { var count int var wg sync.WaitGroup var mu sync.Mutex @@ -1033,7 +1039,7 @@ func (q *SelectQuery) selectExists(ctx context.Context) (bool, error) { ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model) var exists bool - err = q.conn.QueryRowContext(ctx, query).Scan(&exists) + err = q.resolveConn(q).QueryRowContext(ctx, query).Scan(&exists) q.db.afterQuery(ctx, event, nil, err) |