diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_select.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/query_select.go | 161 |
1 files changed, 152 insertions, 9 deletions
diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go index 56ba310e5..b61bcfaf0 100644 --- a/vendor/github.com/uptrace/bun/query_select.go +++ b/vendor/github.com/uptrace/bun/query_select.go @@ -10,6 +10,8 @@ import ( "strings" "sync" + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/internal" "github.com/uptrace/bun/schema" @@ -22,6 +24,7 @@ type union struct { type SelectQuery struct { whereBaseQuery + idxHintsQuery distinctOn []schema.QueryWithArgs joins []joinQuery @@ -159,6 +162,92 @@ func (q *SelectQuery) WhereAllWithDeleted() *SelectQuery { //------------------------------------------------------------------------------ +func (q *SelectQuery) UseIndex(indexes ...string) *SelectQuery { + if q.db.dialect.Name() == dialect.MySQL { + q.addUseIndex(indexes...) + } + return q +} + +func (q *SelectQuery) UseIndexForJoin(indexes ...string) *SelectQuery { + if q.db.dialect.Name() == dialect.MySQL { + q.addUseIndexForJoin(indexes...) + } + return q +} + +func (q *SelectQuery) UseIndexForOrderBy(indexes ...string) *SelectQuery { + if q.db.dialect.Name() == dialect.MySQL { + q.addUseIndexForOrderBy(indexes...) + } + return q +} + +func (q *SelectQuery) UseIndexForGroupBy(indexes ...string) *SelectQuery { + if q.db.dialect.Name() == dialect.MySQL { + q.addUseIndexForGroupBy(indexes...) + } + return q +} + +func (q *SelectQuery) IgnoreIndex(indexes ...string) *SelectQuery { + if q.db.dialect.Name() == dialect.MySQL { + q.addIgnoreIndex(indexes...) + } + return q +} + +func (q *SelectQuery) IgnoreIndexForJoin(indexes ...string) *SelectQuery { + if q.db.dialect.Name() == dialect.MySQL { + q.addIgnoreIndexForJoin(indexes...) + } + return q +} + +func (q *SelectQuery) IgnoreIndexForOrderBy(indexes ...string) *SelectQuery { + if q.db.dialect.Name() == dialect.MySQL { + q.addIgnoreIndexForOrderBy(indexes...) + } + return q +} + +func (q *SelectQuery) IgnoreIndexForGroupBy(indexes ...string) *SelectQuery { + if q.db.dialect.Name() == dialect.MySQL { + q.addIgnoreIndexForGroupBy(indexes...) + } + return q +} + +func (q *SelectQuery) ForceIndex(indexes ...string) *SelectQuery { + if q.db.dialect.Name() == dialect.MySQL { + q.addForceIndex(indexes...) + } + return q +} + +func (q *SelectQuery) ForceIndexForJoin(indexes ...string) *SelectQuery { + if q.db.dialect.Name() == dialect.MySQL { + q.addForceIndexForJoin(indexes...) + } + return q +} + +func (q *SelectQuery) ForceIndexForOrderBy(indexes ...string) *SelectQuery { + if q.db.dialect.Name() == dialect.MySQL { + q.addForceIndexForOrderBy(indexes...) + } + return q +} + +func (q *SelectQuery) ForceIndexForGroupBy(indexes ...string) *SelectQuery { + if q.db.dialect.Name() == dialect.MySQL { + q.addForceIndexForGroupBy(indexes...) + } + return q +} + +//------------------------------------------------------------------------------ + func (q *SelectQuery) Group(columns ...string) *SelectQuery { for _, column := range columns { q.group = append(q.group, schema.UnsafeIdent(column)) @@ -305,8 +394,31 @@ func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQ return q } + var apply1, apply2 func(*SelectQuery) *SelectQuery + + if len(join.Relation.Condition) > 0 { + apply1 = func(q *SelectQuery) *SelectQuery { + for _, opt := range join.Relation.Condition { + q.addWhere(schema.SafeQueryWithSep(opt, nil, " AND ")) + } + + return q + } + } + if len(apply) == 1 { - join.apply = apply[0] + apply2 = apply[0] + } + + join.apply = func(q *SelectQuery) *SelectQuery { + if apply1 != nil { + q = apply1(q) + } + if apply2 != nil { + q = apply2(q) + } + + return q } return q @@ -441,6 +553,11 @@ func (q *SelectQuery) appendQuery( } } + b, err = q.appendIndexHints(fmter, b) + if err != nil { + return nil, err + } + b, err = q.appendWhere(fmter, b, true) if err != nil { return nil, err @@ -481,7 +598,7 @@ func (q *SelectQuery) appendQuery( } if fmter.Dialect().Features().Has(feature.OffsetFetch) { - if q.offset != 0 { + if q.limit > 0 && q.offset > 0 { b = append(b, " OFFSET "...) b = strconv.AppendInt(b, int64(q.offset), 10) b = append(b, " ROWS"...) @@ -489,13 +606,23 @@ func (q *SelectQuery) appendQuery( b = append(b, " FETCH NEXT "...) b = strconv.AppendInt(b, int64(q.limit), 10) b = append(b, " ROWS ONLY"...) + } else if q.limit > 0 { + b = append(b, " OFFSET 0 ROWS"...) + + b = append(b, " FETCH NEXT "...) + b = strconv.AppendInt(b, int64(q.limit), 10) + b = append(b, " ROWS ONLY"...) + } else if q.offset > 0 { + b = append(b, " OFFSET "...) + b = strconv.AppendInt(b, int64(q.offset), 10) + b = append(b, " ROWS"...) } } else { - if q.limit != 0 { + if q.limit > 0 { b = append(b, " LIMIT "...) b = strconv.AppendInt(b, int64(q.limit), 10) } - if q.offset != 0 { + if q.offset > 0 { b = append(b, " OFFSET "...) b = strconv.AppendInt(b, int64(q.offset), 10) } @@ -920,12 +1047,32 @@ func (q *SelectQuery) whereExists(ctx context.Context) (bool, error) { return n == 1, nil } +func (q *SelectQuery) String() string { + buf, err := q.AppendQuery(q.db.Formatter(), nil) + if err != nil { + panic(err) + } + + return string(buf) +} + //------------------------------------------------------------------------------ + +func (q *SelectQuery) QueryBuilder() QueryBuilder { + return &selectQueryBuilder{q} +} + +func (q *SelectQuery) ApplyQueryBuilder(fn func(QueryBuilder) QueryBuilder) *SelectQuery { + return fn(q.QueryBuilder()).Unwrap().(*SelectQuery) +} + type selectQueryBuilder struct { *SelectQuery } -func (q *selectQueryBuilder) WhereGroup(sep string, fn func(QueryBuilder) QueryBuilder) QueryBuilder { +func (q *selectQueryBuilder) WhereGroup( + sep string, fn func(QueryBuilder) QueryBuilder, +) QueryBuilder { q.SelectQuery = q.SelectQuery.WhereGroup(sep, func(qs *SelectQuery) *SelectQuery { return fn(q).(*selectQueryBuilder).SelectQuery }) @@ -961,10 +1108,6 @@ func (q *selectQueryBuilder) Unwrap() interface{} { return q.SelectQuery } -func (q *SelectQuery) Query() QueryBuilder { - return &selectQueryBuilder{q} -} - //------------------------------------------------------------------------------ type joinQuery struct { |