summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/relation_join.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/relation_join.go')
-rw-r--r--vendor/github.com/uptrace/bun/relation_join.go86
1 files changed, 85 insertions, 1 deletions
diff --git a/vendor/github.com/uptrace/bun/relation_join.go b/vendor/github.com/uptrace/bun/relation_join.go
index e8074e0c6..19dda774e 100644
--- a/vendor/github.com/uptrace/bun/relation_join.go
+++ b/vendor/github.com/uptrace/bun/relation_join.go
@@ -4,6 +4,7 @@ import (
"context"
"reflect"
+ "github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
@@ -60,6 +61,14 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery {
q = q.Model(hasManyModel)
var where []byte
+
+ if q.db.dialect.Features().Has(feature.CompositeIn) {
+ return j.manyQueryCompositeIn(where, q)
+ }
+ return j.manyQueryMulti(where, q)
+}
+
+func (j *relationJoin) manyQueryCompositeIn(where []byte, q *SelectQuery) *SelectQuery {
if len(j.Relation.JoinFields) > 1 {
where = append(where, '(')
}
@@ -88,6 +97,29 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery {
return q
}
+func (j *relationJoin) manyQueryMulti(where []byte, q *SelectQuery) *SelectQuery {
+ where = appendMultiValues(
+ q.db.Formatter(),
+ where,
+ j.JoinModel.rootValue(),
+ j.JoinModel.parentIndex(),
+ j.Relation.BaseFields,
+ j.Relation.JoinFields,
+ j.JoinModel.Table().SQLAlias,
+ )
+
+ q = q.Where(internal.String(where))
+
+ if j.Relation.PolymorphicField != nil {
+ q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue)
+ }
+
+ j.applyTo(q)
+ q = q.Apply(j.hasManyColumns)
+
+ return q
+}
+
func (j *relationJoin) hasManyColumns(q *SelectQuery) *SelectQuery {
b := make([]byte, 0, 32)
@@ -151,7 +183,7 @@ func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery {
//nolint
var join []byte
join = append(join, "JOIN "...)
- join = fmter.AppendQuery(join, string(j.Relation.M2MTable.Name))
+ join = fmter.AppendQuery(join, string(j.Relation.M2MTable.SQLName))
join = append(join, " AS "...)
join = append(join, j.Relation.M2MTable.SQLAlias...)
join = append(join, " ON ("...)
@@ -312,3 +344,55 @@ func appendChildValues(
}
return b
}
+
+// appendMultiValues is an alternative to appendChildValues that doesn't use the sql keyword ID
+// but instead use a old style ((k1=v1) AND (k2=v2)) OR (...) of conditions.
+func appendMultiValues(
+ fmter schema.Formatter, b []byte, v reflect.Value, index []int, baseFields, joinFields []*schema.Field, joinTable schema.Safe,
+) []byte {
+ // This is based on a mix of appendChildValues and query_base.appendColumns
+
+ // These should never missmatch in length but nice to know if it does
+ if len(joinFields) != len(baseFields) {
+ panic("not reached")
+ }
+
+ // walk the relations
+ b = append(b, '(')
+ seen := make(map[string]struct{})
+ walk(v, index, func(v reflect.Value) {
+ start := len(b)
+ for i, f := range baseFields {
+ if i > 0 {
+ b = append(b, " AND "...)
+ }
+ if len(baseFields) > 1 {
+ b = append(b, '(')
+ }
+ // Field name
+ b = append(b, joinTable...)
+ b = append(b, '.')
+ b = append(b, []byte(joinFields[i].SQLName)...)
+
+ // Equals value
+ b = append(b, '=')
+ b = f.AppendValue(fmter, b, v)
+ if len(baseFields) > 1 {
+ b = append(b, ')')
+ }
+ }
+
+ b = append(b, ") OR ("...)
+
+ if _, ok := seen[string(b[start:])]; ok {
+ b = b[:start]
+ } else {
+ seen[string(b[start:])] = struct{}{}
+ }
+ })
+ if len(seen) > 0 {
+ b = b[:len(b)-6] // trim ") OR ("
+ }
+ b = append(b, ')')
+ return b
+}