diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/relation_join.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/relation_join.go | 86 |
1 files changed, 85 insertions, 1 deletions
diff --git a/vendor/github.com/uptrace/bun/relation_join.go b/vendor/github.com/uptrace/bun/relation_join.go index e8074e0c6..19dda774e 100644 --- a/vendor/github.com/uptrace/bun/relation_join.go +++ b/vendor/github.com/uptrace/bun/relation_join.go @@ -4,6 +4,7 @@ import ( "context" "reflect" + "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/internal" "github.com/uptrace/bun/schema" ) @@ -60,6 +61,14 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery { q = q.Model(hasManyModel) var where []byte + + if q.db.dialect.Features().Has(feature.CompositeIn) { + return j.manyQueryCompositeIn(where, q) + } + return j.manyQueryMulti(where, q) +} + +func (j *relationJoin) manyQueryCompositeIn(where []byte, q *SelectQuery) *SelectQuery { if len(j.Relation.JoinFields) > 1 { where = append(where, '(') } @@ -88,6 +97,29 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery { return q } +func (j *relationJoin) manyQueryMulti(where []byte, q *SelectQuery) *SelectQuery { + where = appendMultiValues( + q.db.Formatter(), + where, + j.JoinModel.rootValue(), + j.JoinModel.parentIndex(), + j.Relation.BaseFields, + j.Relation.JoinFields, + j.JoinModel.Table().SQLAlias, + ) + + q = q.Where(internal.String(where)) + + if j.Relation.PolymorphicField != nil { + q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue) + } + + j.applyTo(q) + q = q.Apply(j.hasManyColumns) + + return q +} + func (j *relationJoin) hasManyColumns(q *SelectQuery) *SelectQuery { b := make([]byte, 0, 32) @@ -151,7 +183,7 @@ func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery { //nolint var join []byte join = append(join, "JOIN "...) - join = fmter.AppendQuery(join, string(j.Relation.M2MTable.Name)) + join = fmter.AppendQuery(join, string(j.Relation.M2MTable.SQLName)) join = append(join, " AS "...) join = append(join, j.Relation.M2MTable.SQLAlias...) join = append(join, " ON ("...) @@ -312,3 +344,55 @@ func appendChildValues( } return b } + +// appendMultiValues is an alternative to appendChildValues that doesn't use the sql keyword ID +// but instead use a old style ((k1=v1) AND (k2=v2)) OR (...) of conditions. +func appendMultiValues( + fmter schema.Formatter, b []byte, v reflect.Value, index []int, baseFields, joinFields []*schema.Field, joinTable schema.Safe, +) []byte { + // This is based on a mix of appendChildValues and query_base.appendColumns + + // These should never missmatch in length but nice to know if it does + if len(joinFields) != len(baseFields) { + panic("not reached") + } + + // walk the relations + b = append(b, '(') + seen := make(map[string]struct{}) + walk(v, index, func(v reflect.Value) { + start := len(b) + for i, f := range baseFields { + if i > 0 { + b = append(b, " AND "...) + } + if len(baseFields) > 1 { + b = append(b, '(') + } + // Field name + b = append(b, joinTable...) + b = append(b, '.') + b = append(b, []byte(joinFields[i].SQLName)...) + + // Equals value + b = append(b, '=') + b = f.AppendValue(fmter, b, v) + if len(baseFields) > 1 { + b = append(b, ')') + } + } + + b = append(b, ") OR ("...) + + if _, ok := seen[string(b[start:])]; ok { + b = b[:start] + } else { + seen[string(b[start:])] = struct{}{} + } + }) + if len(seen) > 0 { + b = b[:len(b)-6] // trim ") OR (" + } + b = append(b, ')') + return b +} |