diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/relation_join.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/relation_join.go | 36 |
1 files changed, 35 insertions, 1 deletions
diff --git a/vendor/github.com/uptrace/bun/relation_join.go b/vendor/github.com/uptrace/bun/relation_join.go index 487f776ed..47f27afd5 100644 --- a/vendor/github.com/uptrace/bun/relation_join.go +++ b/vendor/github.com/uptrace/bun/relation_join.go @@ -16,6 +16,8 @@ type relationJoin struct { JoinModel TableModel Relation *schema.Relation + additionalJoinOnConditions []schema.QueryWithArgs + apply func(*SelectQuery) *SelectQuery columns []schema.QueryWithArgs } @@ -63,7 +65,7 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery { var where []byte - if q.db.dialect.Features().Has(feature.CompositeIn) { + if q.db.HasFeature(feature.CompositeIn) { return j.manyQueryCompositeIn(where, q) } return j.manyQueryMulti(where, q) @@ -86,6 +88,11 @@ func (j *relationJoin) manyQueryCompositeIn(where []byte, q *SelectQuery) *Selec j.Relation.BasePKs, ) where = append(where, ")"...) + if len(j.additionalJoinOnConditions) > 0 { + where = append(where, " AND "...) + where = appendAdditionalJoinOnConditions(q.db.Formatter(), where, j.additionalJoinOnConditions) + } + q = q.Where(internal.String(where)) if j.Relation.PolymorphicField != nil { @@ -111,6 +118,10 @@ func (j *relationJoin) manyQueryMulti(where []byte, q *SelectQuery) *SelectQuery q = q.Where(internal.String(where)) + if len(j.additionalJoinOnConditions) > 0 { + q = q.Where(internal.String(appendAdditionalJoinOnConditions(q.db.Formatter(), []byte{}, j.additionalJoinOnConditions))) + } + if j.Relation.PolymorphicField != nil { q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue) } @@ -204,6 +215,12 @@ func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery { join = append(join, ") IN ("...) join = appendChildValues(fmter, join, j.BaseModel.rootValue(), index, j.Relation.BasePKs) join = append(join, ")"...) + + if len(j.additionalJoinOnConditions) > 0 { + join = append(join, " AND "...) + join = appendAdditionalJoinOnConditions(fmter, join, j.additionalJoinOnConditions) + } + q = q.Join(internal.String(join)) joinTable := j.JoinModel.Table() @@ -330,6 +347,11 @@ func (j *relationJoin) appendHasOneJoin( b = j.appendSoftDelete(fmter, b, q.flags) } + if len(j.additionalJoinOnConditions) > 0 { + b = append(b, " AND "...) + b = appendAdditionalJoinOnConditions(fmter, b, j.additionalJoinOnConditions) + } + return b, nil } @@ -417,3 +439,15 @@ func appendMultiValues( b = append(b, ')') return b } + +func appendAdditionalJoinOnConditions( + fmter schema.Formatter, b []byte, conditions []schema.QueryWithArgs, +) []byte { + for i, cond := range conditions { + if i > 0 { + b = append(b, " AND "...) + } + b = fmter.AppendQuery(b, cond.Query, cond.Args...) + } + return b +} |