diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_select.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/query_select.go | 48 |
1 files changed, 37 insertions, 11 deletions
diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go index c0e145110..5bb329143 100644 --- a/vendor/github.com/uptrace/bun/query_select.go +++ b/vendor/github.com/uptrace/bun/query_select.go @@ -538,6 +538,11 @@ func (q *SelectQuery) appendQuery( if count && !cteCount { b = append(b, "count(*)"...) } else { + // MSSQL: allows Limit() without Order() as per https://stackoverflow.com/a/36156953 + if q.limit > 0 && len(q.order) == 0 && fmter.Dialect().Name() == dialect.MSSQL { + b = append(b, "0 AS _temp_sort, "...) + } + b, err = q.appendColumns(fmter, b) if err != nil { return nil, err @@ -564,8 +569,8 @@ func (q *SelectQuery) appendQuery( return nil, err } - for _, j := range q.joins { - b, err = j.AppendQuery(fmter, b) + for _, join := range q.joins { + b, err = join.AppendQuery(fmter, b) if err != nil { return nil, err } @@ -793,6 +798,12 @@ func (q *SelectQuery) appendOrder(fmter schema.Formatter, b []byte) (_ []byte, e return b, nil } + + // MSSQL: allows Limit() without Order() as per https://stackoverflow.com/a/36156953 + if q.limit > 0 && fmter.Dialect().Name() == dialect.MSSQL { + return append(b, " ORDER BY _temp_sort"...), nil + } + return b, nil } @@ -856,52 +867,57 @@ func (q *SelectQuery) Exec(ctx context.Context, dest ...interface{}) (res sql.Re } func (q *SelectQuery) Scan(ctx context.Context, dest ...interface{}) error { + _, err := q.scanResult(ctx, dest...) + return err +} + +func (q *SelectQuery) scanResult(ctx context.Context, dest ...interface{}) (sql.Result, error) { if q.err != nil { - return q.err + return nil, q.err } model, err := q.getModel(dest) if err != nil { - return err + return nil, err } if q.table != nil { if err := q.beforeSelectHook(ctx); err != nil { - return err + return nil, err } } if err := q.beforeAppendModel(ctx, q); err != nil { - return err + return nil, err } queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) if err != nil { - return err + return nil, err } query := internal.String(queryBytes) res, err := q.scan(ctx, q, query, model, true) if err != nil { - return err + return nil, err } if n, _ := res.RowsAffected(); n > 0 { if tableModel, ok := model.(TableModel); ok { if err := q.selectJoins(ctx, tableModel.getJoins()); err != nil { - return err + return nil, err } } } if q.table != nil { if err := q.afterSelectHook(ctx); err != nil { - return err + return nil, err } } - return nil + return res, nil } func (q *SelectQuery) beforeSelectHook(ctx context.Context) error { @@ -946,6 +962,16 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) { } func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (int, error) { + if q.offset == 0 && q.limit == 0 { + // If there is no limit and offset, we can use a single query to get the count and scan + if res, err := q.scanResult(ctx, dest...); err != nil { + return 0, err + } else if n, err := res.RowsAffected(); err != nil { + return 0, err + } else { + return int(n), nil + } + } if _, ok := q.conn.(*DB); ok { return q.scanAndCountConc(ctx, dest...) } |