summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_select.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_select.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_select.go38
1 files changed, 33 insertions, 5 deletions
diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go
index ad4db6670..401bf1acc 100644
--- a/vendor/github.com/uptrace/bun/query_select.go
+++ b/vendor/github.com/uptrace/bun/query_select.go
@@ -34,6 +34,8 @@ type SelectQuery struct {
union []union
}
+var _ Query = (*SelectQuery)(nil)
+
func NewSelectQuery(db *DB) *SelectQuery {
return &SelectQuery{
whereBaseQuery: whereBaseQuery{
@@ -90,7 +92,7 @@ func (q *SelectQuery) TableExpr(query string, args ...interface{}) *SelectQuery
}
func (q *SelectQuery) ModelTableExpr(query string, args ...interface{}) *SelectQuery {
- q.modelTable = schema.SafeQuery(query, args)
+ q.modelTableName = schema.SafeQuery(query, args)
return q
}
@@ -342,9 +344,9 @@ func (q *SelectQuery) selectJoins(ctx context.Context, joins []relationJoin) err
case schema.HasOneRelation, schema.BelongsToRelation:
err = q.selectJoins(ctx, j.JoinModel.getJoins())
case schema.HasManyRelation:
- err = j.selectMany(ctx, q.db.NewSelect())
+ err = j.selectMany(ctx, q.db.NewSelect().Conn(q.conn))
case schema.ManyToManyRelation:
- err = j.selectM2M(ctx, q.db.NewSelect())
+ err = j.selectM2M(ctx, q.db.NewSelect().Conn(q.conn))
default:
panic("not reached")
}
@@ -369,6 +371,10 @@ func (q *SelectQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
func (q *SelectQuery) appendQuery(
fmter schema.Formatter, b []byte, count bool,
) (_ []byte, err error) {
+ if q.err != nil {
+ return nil, q.err
+ }
+
fmter = formatterWithModel(fmter, q)
cteCount := count && (len(q.group) > 0 || q.distinctOn != nil)
@@ -767,7 +773,7 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) {
}
query := internal.String(queryBytes)
- ctx, event := q.db.beforeQuery(ctx, qq, query, nil, q.model)
+ ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model)
var num int
err = q.conn.QueryRowContext(ctx, query).Scan(&num)
@@ -778,6 +784,13 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) {
}
func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (int, error) {
+ if _, ok := q.conn.(*DB); ok {
+ return q.scanAndCountConc(ctx, dest...)
+ }
+ return q.scanAndCountSeq(ctx, dest...)
+}
+
+func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{}) (int, error) {
var count int
var wg sync.WaitGroup
var mu sync.Mutex
@@ -817,6 +830,21 @@ func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (in
return count, firstErr
}
+func (q *SelectQuery) scanAndCountSeq(ctx context.Context, dest ...interface{}) (int, error) {
+ var firstErr error
+
+ if q.limit >= 0 {
+ firstErr = q.Scan(ctx, dest...)
+ }
+
+ count, err := q.Count(ctx)
+ if err != nil && firstErr == nil {
+ firstErr = err
+ }
+
+ return count, firstErr
+}
+
func (q *SelectQuery) Exists(ctx context.Context) (bool, error) {
if q.err != nil {
return false, q.err
@@ -830,7 +858,7 @@ func (q *SelectQuery) Exists(ctx context.Context) (bool, error) {
}
query := internal.String(queryBytes)
- ctx, event := q.db.beforeQuery(ctx, qq, query, nil, q.model)
+ ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model)
var exists bool
err = q.conn.QueryRowContext(ctx, query).Scan(&exists)