diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_select.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/query_select.go | 38 |
1 files changed, 33 insertions, 5 deletions
diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go index ad4db6670..401bf1acc 100644 --- a/vendor/github.com/uptrace/bun/query_select.go +++ b/vendor/github.com/uptrace/bun/query_select.go @@ -34,6 +34,8 @@ type SelectQuery struct { union []union } +var _ Query = (*SelectQuery)(nil) + func NewSelectQuery(db *DB) *SelectQuery { return &SelectQuery{ whereBaseQuery: whereBaseQuery{ @@ -90,7 +92,7 @@ func (q *SelectQuery) TableExpr(query string, args ...interface{}) *SelectQuery } func (q *SelectQuery) ModelTableExpr(query string, args ...interface{}) *SelectQuery { - q.modelTable = schema.SafeQuery(query, args) + q.modelTableName = schema.SafeQuery(query, args) return q } @@ -342,9 +344,9 @@ func (q *SelectQuery) selectJoins(ctx context.Context, joins []relationJoin) err case schema.HasOneRelation, schema.BelongsToRelation: err = q.selectJoins(ctx, j.JoinModel.getJoins()) case schema.HasManyRelation: - err = j.selectMany(ctx, q.db.NewSelect()) + err = j.selectMany(ctx, q.db.NewSelect().Conn(q.conn)) case schema.ManyToManyRelation: - err = j.selectM2M(ctx, q.db.NewSelect()) + err = j.selectM2M(ctx, q.db.NewSelect().Conn(q.conn)) default: panic("not reached") } @@ -369,6 +371,10 @@ func (q *SelectQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e func (q *SelectQuery) appendQuery( fmter schema.Formatter, b []byte, count bool, ) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + fmter = formatterWithModel(fmter, q) cteCount := count && (len(q.group) > 0 || q.distinctOn != nil) @@ -767,7 +773,7 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) { } query := internal.String(queryBytes) - ctx, event := q.db.beforeQuery(ctx, qq, query, nil, q.model) + ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model) var num int err = q.conn.QueryRowContext(ctx, query).Scan(&num) @@ -778,6 +784,13 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) { } func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (int, error) { + if _, ok := q.conn.(*DB); ok { + return q.scanAndCountConc(ctx, dest...) + } + return q.scanAndCountSeq(ctx, dest...) +} + +func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{}) (int, error) { var count int var wg sync.WaitGroup var mu sync.Mutex @@ -817,6 +830,21 @@ func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (in return count, firstErr } +func (q *SelectQuery) scanAndCountSeq(ctx context.Context, dest ...interface{}) (int, error) { + var firstErr error + + if q.limit >= 0 { + firstErr = q.Scan(ctx, dest...) + } + + count, err := q.Count(ctx) + if err != nil && firstErr == nil { + firstErr = err + } + + return count, firstErr +} + func (q *SelectQuery) Exists(ctx context.Context) (bool, error) { if q.err != nil { return false, q.err @@ -830,7 +858,7 @@ func (q *SelectQuery) Exists(ctx context.Context) (bool, error) { } query := internal.String(queryBytes) - ctx, event := q.db.beforeQuery(ctx, qq, query, nil, q.model) + ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model) var exists bool err = q.conn.QueryRowContext(ctx, query).Scan(&exists) |