summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_select.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_select.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_select.go161
1 files changed, 152 insertions, 9 deletions
diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go
index 56ba310e5..b61bcfaf0 100644
--- a/vendor/github.com/uptrace/bun/query_select.go
+++ b/vendor/github.com/uptrace/bun/query_select.go
@@ -10,6 +10,8 @@ import (
"strings"
"sync"
+ "github.com/uptrace/bun/dialect"
+
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
@@ -22,6 +24,7 @@ type union struct {
type SelectQuery struct {
whereBaseQuery
+ idxHintsQuery
distinctOn []schema.QueryWithArgs
joins []joinQuery
@@ -159,6 +162,92 @@ func (q *SelectQuery) WhereAllWithDeleted() *SelectQuery {
//------------------------------------------------------------------------------
+func (q *SelectQuery) UseIndex(indexes ...string) *SelectQuery {
+ if q.db.dialect.Name() == dialect.MySQL {
+ q.addUseIndex(indexes...)
+ }
+ return q
+}
+
+func (q *SelectQuery) UseIndexForJoin(indexes ...string) *SelectQuery {
+ if q.db.dialect.Name() == dialect.MySQL {
+ q.addUseIndexForJoin(indexes...)
+ }
+ return q
+}
+
+func (q *SelectQuery) UseIndexForOrderBy(indexes ...string) *SelectQuery {
+ if q.db.dialect.Name() == dialect.MySQL {
+ q.addUseIndexForOrderBy(indexes...)
+ }
+ return q
+}
+
+func (q *SelectQuery) UseIndexForGroupBy(indexes ...string) *SelectQuery {
+ if q.db.dialect.Name() == dialect.MySQL {
+ q.addUseIndexForGroupBy(indexes...)
+ }
+ return q
+}
+
+func (q *SelectQuery) IgnoreIndex(indexes ...string) *SelectQuery {
+ if q.db.dialect.Name() == dialect.MySQL {
+ q.addIgnoreIndex(indexes...)
+ }
+ return q
+}
+
+func (q *SelectQuery) IgnoreIndexForJoin(indexes ...string) *SelectQuery {
+ if q.db.dialect.Name() == dialect.MySQL {
+ q.addIgnoreIndexForJoin(indexes...)
+ }
+ return q
+}
+
+func (q *SelectQuery) IgnoreIndexForOrderBy(indexes ...string) *SelectQuery {
+ if q.db.dialect.Name() == dialect.MySQL {
+ q.addIgnoreIndexForOrderBy(indexes...)
+ }
+ return q
+}
+
+func (q *SelectQuery) IgnoreIndexForGroupBy(indexes ...string) *SelectQuery {
+ if q.db.dialect.Name() == dialect.MySQL {
+ q.addIgnoreIndexForGroupBy(indexes...)
+ }
+ return q
+}
+
+func (q *SelectQuery) ForceIndex(indexes ...string) *SelectQuery {
+ if q.db.dialect.Name() == dialect.MySQL {
+ q.addForceIndex(indexes...)
+ }
+ return q
+}
+
+func (q *SelectQuery) ForceIndexForJoin(indexes ...string) *SelectQuery {
+ if q.db.dialect.Name() == dialect.MySQL {
+ q.addForceIndexForJoin(indexes...)
+ }
+ return q
+}
+
+func (q *SelectQuery) ForceIndexForOrderBy(indexes ...string) *SelectQuery {
+ if q.db.dialect.Name() == dialect.MySQL {
+ q.addForceIndexForOrderBy(indexes...)
+ }
+ return q
+}
+
+func (q *SelectQuery) ForceIndexForGroupBy(indexes ...string) *SelectQuery {
+ if q.db.dialect.Name() == dialect.MySQL {
+ q.addForceIndexForGroupBy(indexes...)
+ }
+ return q
+}
+
+//------------------------------------------------------------------------------
+
func (q *SelectQuery) Group(columns ...string) *SelectQuery {
for _, column := range columns {
q.group = append(q.group, schema.UnsafeIdent(column))
@@ -305,8 +394,31 @@ func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQ
return q
}
+ var apply1, apply2 func(*SelectQuery) *SelectQuery
+
+ if len(join.Relation.Condition) > 0 {
+ apply1 = func(q *SelectQuery) *SelectQuery {
+ for _, opt := range join.Relation.Condition {
+ q.addWhere(schema.SafeQueryWithSep(opt, nil, " AND "))
+ }
+
+ return q
+ }
+ }
+
if len(apply) == 1 {
- join.apply = apply[0]
+ apply2 = apply[0]
+ }
+
+ join.apply = func(q *SelectQuery) *SelectQuery {
+ if apply1 != nil {
+ q = apply1(q)
+ }
+ if apply2 != nil {
+ q = apply2(q)
+ }
+
+ return q
}
return q
@@ -441,6 +553,11 @@ func (q *SelectQuery) appendQuery(
}
}
+ b, err = q.appendIndexHints(fmter, b)
+ if err != nil {
+ return nil, err
+ }
+
b, err = q.appendWhere(fmter, b, true)
if err != nil {
return nil, err
@@ -481,7 +598,7 @@ func (q *SelectQuery) appendQuery(
}
if fmter.Dialect().Features().Has(feature.OffsetFetch) {
- if q.offset != 0 {
+ if q.limit > 0 && q.offset > 0 {
b = append(b, " OFFSET "...)
b = strconv.AppendInt(b, int64(q.offset), 10)
b = append(b, " ROWS"...)
@@ -489,13 +606,23 @@ func (q *SelectQuery) appendQuery(
b = append(b, " FETCH NEXT "...)
b = strconv.AppendInt(b, int64(q.limit), 10)
b = append(b, " ROWS ONLY"...)
+ } else if q.limit > 0 {
+ b = append(b, " OFFSET 0 ROWS"...)
+
+ b = append(b, " FETCH NEXT "...)
+ b = strconv.AppendInt(b, int64(q.limit), 10)
+ b = append(b, " ROWS ONLY"...)
+ } else if q.offset > 0 {
+ b = append(b, " OFFSET "...)
+ b = strconv.AppendInt(b, int64(q.offset), 10)
+ b = append(b, " ROWS"...)
}
} else {
- if q.limit != 0 {
+ if q.limit > 0 {
b = append(b, " LIMIT "...)
b = strconv.AppendInt(b, int64(q.limit), 10)
}
- if q.offset != 0 {
+ if q.offset > 0 {
b = append(b, " OFFSET "...)
b = strconv.AppendInt(b, int64(q.offset), 10)
}
@@ -920,12 +1047,32 @@ func (q *SelectQuery) whereExists(ctx context.Context) (bool, error) {
return n == 1, nil
}
+func (q *SelectQuery) String() string {
+ buf, err := q.AppendQuery(q.db.Formatter(), nil)
+ if err != nil {
+ panic(err)
+ }
+
+ return string(buf)
+}
+
//------------------------------------------------------------------------------
+
+func (q *SelectQuery) QueryBuilder() QueryBuilder {
+ return &selectQueryBuilder{q}
+}
+
+func (q *SelectQuery) ApplyQueryBuilder(fn func(QueryBuilder) QueryBuilder) *SelectQuery {
+ return fn(q.QueryBuilder()).Unwrap().(*SelectQuery)
+}
+
type selectQueryBuilder struct {
*SelectQuery
}
-func (q *selectQueryBuilder) WhereGroup(sep string, fn func(QueryBuilder) QueryBuilder) QueryBuilder {
+func (q *selectQueryBuilder) WhereGroup(
+ sep string, fn func(QueryBuilder) QueryBuilder,
+) QueryBuilder {
q.SelectQuery = q.SelectQuery.WhereGroup(sep, func(qs *SelectQuery) *SelectQuery {
return fn(q).(*selectQueryBuilder).SelectQuery
})
@@ -961,10 +1108,6 @@ func (q *selectQueryBuilder) Unwrap() interface{} {
return q.SelectQuery
}
-func (q *SelectQuery) Query() QueryBuilder {
- return &selectQueryBuilder{q}
-}
-
//------------------------------------------------------------------------------
type joinQuery struct {