summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/relation_join.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/relation_join.go')
-rw-r--r--vendor/github.com/uptrace/bun/relation_join.go36
1 files changed, 35 insertions, 1 deletions
diff --git a/vendor/github.com/uptrace/bun/relation_join.go b/vendor/github.com/uptrace/bun/relation_join.go
index 487f776ed..47f27afd5 100644
--- a/vendor/github.com/uptrace/bun/relation_join.go
+++ b/vendor/github.com/uptrace/bun/relation_join.go
@@ -16,6 +16,8 @@ type relationJoin struct {
JoinModel TableModel
Relation *schema.Relation
+ additionalJoinOnConditions []schema.QueryWithArgs
+
apply func(*SelectQuery) *SelectQuery
columns []schema.QueryWithArgs
}
@@ -63,7 +65,7 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery {
var where []byte
- if q.db.dialect.Features().Has(feature.CompositeIn) {
+ if q.db.HasFeature(feature.CompositeIn) {
return j.manyQueryCompositeIn(where, q)
}
return j.manyQueryMulti(where, q)
@@ -86,6 +88,11 @@ func (j *relationJoin) manyQueryCompositeIn(where []byte, q *SelectQuery) *Selec
j.Relation.BasePKs,
)
where = append(where, ")"...)
+ if len(j.additionalJoinOnConditions) > 0 {
+ where = append(where, " AND "...)
+ where = appendAdditionalJoinOnConditions(q.db.Formatter(), where, j.additionalJoinOnConditions)
+ }
+
q = q.Where(internal.String(where))
if j.Relation.PolymorphicField != nil {
@@ -111,6 +118,10 @@ func (j *relationJoin) manyQueryMulti(where []byte, q *SelectQuery) *SelectQuery
q = q.Where(internal.String(where))
+ if len(j.additionalJoinOnConditions) > 0 {
+ q = q.Where(internal.String(appendAdditionalJoinOnConditions(q.db.Formatter(), []byte{}, j.additionalJoinOnConditions)))
+ }
+
if j.Relation.PolymorphicField != nil {
q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue)
}
@@ -204,6 +215,12 @@ func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery {
join = append(join, ") IN ("...)
join = appendChildValues(fmter, join, j.BaseModel.rootValue(), index, j.Relation.BasePKs)
join = append(join, ")"...)
+
+ if len(j.additionalJoinOnConditions) > 0 {
+ join = append(join, " AND "...)
+ join = appendAdditionalJoinOnConditions(fmter, join, j.additionalJoinOnConditions)
+ }
+
q = q.Join(internal.String(join))
joinTable := j.JoinModel.Table()
@@ -330,6 +347,11 @@ func (j *relationJoin) appendHasOneJoin(
b = j.appendSoftDelete(fmter, b, q.flags)
}
+ if len(j.additionalJoinOnConditions) > 0 {
+ b = append(b, " AND "...)
+ b = appendAdditionalJoinOnConditions(fmter, b, j.additionalJoinOnConditions)
+ }
+
return b, nil
}
@@ -417,3 +439,15 @@ func appendMultiValues(
b = append(b, ')')
return b
}
+
+func appendAdditionalJoinOnConditions(
+ fmter schema.Formatter, b []byte, conditions []schema.QueryWithArgs,
+) []byte {
+ for i, cond := range conditions {
+ if i > 0 {
+ b = append(b, " AND "...)
+ }
+ b = fmter.AppendQuery(b, cond.Query, cond.Args...)
+ }
+ return b
+}