summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_select.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_select.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_select.go48
1 files changed, 37 insertions, 11 deletions
diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go
index c0e145110..5bb329143 100644
--- a/vendor/github.com/uptrace/bun/query_select.go
+++ b/vendor/github.com/uptrace/bun/query_select.go
@@ -538,6 +538,11 @@ func (q *SelectQuery) appendQuery(
if count && !cteCount {
b = append(b, "count(*)"...)
} else {
+ // MSSQL: allows Limit() without Order() as per https://stackoverflow.com/a/36156953
+ if q.limit > 0 && len(q.order) == 0 && fmter.Dialect().Name() == dialect.MSSQL {
+ b = append(b, "0 AS _temp_sort, "...)
+ }
+
b, err = q.appendColumns(fmter, b)
if err != nil {
return nil, err
@@ -564,8 +569,8 @@ func (q *SelectQuery) appendQuery(
return nil, err
}
- for _, j := range q.joins {
- b, err = j.AppendQuery(fmter, b)
+ for _, join := range q.joins {
+ b, err = join.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
@@ -793,6 +798,12 @@ func (q *SelectQuery) appendOrder(fmter schema.Formatter, b []byte) (_ []byte, e
return b, nil
}
+
+ // MSSQL: allows Limit() without Order() as per https://stackoverflow.com/a/36156953
+ if q.limit > 0 && fmter.Dialect().Name() == dialect.MSSQL {
+ return append(b, " ORDER BY _temp_sort"...), nil
+ }
+
return b, nil
}
@@ -856,52 +867,57 @@ func (q *SelectQuery) Exec(ctx context.Context, dest ...interface{}) (res sql.Re
}
func (q *SelectQuery) Scan(ctx context.Context, dest ...interface{}) error {
+ _, err := q.scanResult(ctx, dest...)
+ return err
+}
+
+func (q *SelectQuery) scanResult(ctx context.Context, dest ...interface{}) (sql.Result, error) {
if q.err != nil {
- return q.err
+ return nil, q.err
}
model, err := q.getModel(dest)
if err != nil {
- return err
+ return nil, err
}
if q.table != nil {
if err := q.beforeSelectHook(ctx); err != nil {
- return err
+ return nil, err
}
}
if err := q.beforeAppendModel(ctx, q); err != nil {
- return err
+ return nil, err
}
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil {
- return err
+ return nil, err
}
query := internal.String(queryBytes)
res, err := q.scan(ctx, q, query, model, true)
if err != nil {
- return err
+ return nil, err
}
if n, _ := res.RowsAffected(); n > 0 {
if tableModel, ok := model.(TableModel); ok {
if err := q.selectJoins(ctx, tableModel.getJoins()); err != nil {
- return err
+ return nil, err
}
}
}
if q.table != nil {
if err := q.afterSelectHook(ctx); err != nil {
- return err
+ return nil, err
}
}
- return nil
+ return res, nil
}
func (q *SelectQuery) beforeSelectHook(ctx context.Context) error {
@@ -946,6 +962,16 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) {
}
func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (int, error) {
+ if q.offset == 0 && q.limit == 0 {
+ // If there is no limit and offset, we can use a single query to get the count and scan
+ if res, err := q.scanResult(ctx, dest...); err != nil {
+ return 0, err
+ } else if n, err := res.RowsAffected(); err != nil {
+ return 0, err
+ } else {
+ return int(n), nil
+ }
+ }
if _, ok := q.conn.(*DB); ok {
return q.scanAndCountConc(ctx, dest...)
}