diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/relation_join.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/relation_join.go | 314 |
1 files changed, 314 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/bun/relation_join.go b/vendor/github.com/uptrace/bun/relation_join.go new file mode 100644 index 000000000..e8074e0c6 --- /dev/null +++ b/vendor/github.com/uptrace/bun/relation_join.go @@ -0,0 +1,314 @@ +package bun + +import ( + "context" + "reflect" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type relationJoin struct { + Parent *relationJoin + BaseModel TableModel + JoinModel TableModel + Relation *schema.Relation + + apply func(*SelectQuery) *SelectQuery + columns []schema.QueryWithArgs +} + +func (j *relationJoin) applyTo(q *SelectQuery) { + if j.apply == nil { + return + } + + var table *schema.Table + var columns []schema.QueryWithArgs + + // Save state. + table, q.table = q.table, j.JoinModel.Table() + columns, q.columns = q.columns, nil + + q = j.apply(q) + + // Restore state. + q.table = table + j.columns, q.columns = q.columns, columns +} + +func (j *relationJoin) Select(ctx context.Context, q *SelectQuery) error { + switch j.Relation.Type { + } + panic("not reached") +} + +func (j *relationJoin) selectMany(ctx context.Context, q *SelectQuery) error { + q = j.manyQuery(q) + if q == nil { + return nil + } + return q.Scan(ctx) +} + +func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery { + hasManyModel := newHasManyModel(j) + if hasManyModel == nil { + return nil + } + + q = q.Model(hasManyModel) + + var where []byte + if len(j.Relation.JoinFields) > 1 { + where = append(where, '(') + } + where = appendColumns(where, j.JoinModel.Table().SQLAlias, j.Relation.JoinFields) + if len(j.Relation.JoinFields) > 1 { + where = append(where, ')') + } + where = append(where, " IN ("...) + where = appendChildValues( + q.db.Formatter(), + where, + j.JoinModel.rootValue(), + j.JoinModel.parentIndex(), + j.Relation.BaseFields, + ) + where = append(where, ")"...) + q = q.Where(internal.String(where)) + + if j.Relation.PolymorphicField != nil { + q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue) + } + + j.applyTo(q) + q = q.Apply(j.hasManyColumns) + + return q +} + +func (j *relationJoin) hasManyColumns(q *SelectQuery) *SelectQuery { + b := make([]byte, 0, 32) + + joinTable := j.JoinModel.Table() + if len(j.columns) > 0 { + for i, col := range j.columns { + if i > 0 { + b = append(b, ", "...) + } + + if col.Args == nil { + if field, ok := joinTable.FieldMap[col.Query]; ok { + b = append(b, joinTable.SQLAlias...) + b = append(b, '.') + b = append(b, field.SQLName...) + continue + } + } + + var err error + b, err = col.AppendQuery(q.db.fmter, b) + if err != nil { + q.setErr(err) + return q + } + + } + } else { + b = appendColumns(b, joinTable.SQLAlias, joinTable.Fields) + } + + q = q.ColumnExpr(internal.String(b)) + + return q +} + +func (j *relationJoin) selectM2M(ctx context.Context, q *SelectQuery) error { + q = j.m2mQuery(q) + if q == nil { + return nil + } + return q.Scan(ctx) +} + +func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery { + fmter := q.db.fmter + + m2mModel := newM2MModel(j) + if m2mModel == nil { + return nil + } + q = q.Model(m2mModel) + + index := j.JoinModel.parentIndex() + baseTable := j.BaseModel.Table() + + if j.Relation.M2MTable != nil { + q = q.ColumnExpr(string(j.Relation.M2MTable.SQLAlias) + ".*") + } + + //nolint + var join []byte + join = append(join, "JOIN "...) + join = fmter.AppendQuery(join, string(j.Relation.M2MTable.Name)) + join = append(join, " AS "...) + join = append(join, j.Relation.M2MTable.SQLAlias...) + join = append(join, " ON ("...) + for i, col := range j.Relation.M2MBaseFields { + if i > 0 { + join = append(join, ", "...) + } + join = append(join, j.Relation.M2MTable.SQLAlias...) + join = append(join, '.') + join = append(join, col.SQLName...) + } + join = append(join, ") IN ("...) + join = appendChildValues(fmter, join, j.BaseModel.rootValue(), index, baseTable.PKs) + join = append(join, ")"...) + q = q.Join(internal.String(join)) + + joinTable := j.JoinModel.Table() + for i, m2mJoinField := range j.Relation.M2MJoinFields { + joinField := j.Relation.JoinFields[i] + q = q.Where("?.? = ?.?", + joinTable.SQLAlias, joinField.SQLName, + j.Relation.M2MTable.SQLAlias, m2mJoinField.SQLName) + } + + j.applyTo(q) + q = q.Apply(j.hasManyColumns) + + return q +} + +func (j *relationJoin) hasParent() bool { + if j.Parent != nil { + switch j.Parent.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + return true + } + } + return false +} + +func (j *relationJoin) appendAlias(fmter schema.Formatter, b []byte) []byte { + quote := fmter.IdentQuote() + + b = append(b, quote) + b = appendAlias(b, j) + b = append(b, quote) + return b +} + +func (j *relationJoin) appendAliasColumn(fmter schema.Formatter, b []byte, column string) []byte { + quote := fmter.IdentQuote() + + b = append(b, quote) + b = appendAlias(b, j) + b = append(b, "__"...) + b = append(b, column...) + b = append(b, quote) + return b +} + +func (j *relationJoin) appendBaseAlias(fmter schema.Formatter, b []byte) []byte { + quote := fmter.IdentQuote() + + if j.hasParent() { + b = append(b, quote) + b = appendAlias(b, j.Parent) + b = append(b, quote) + return b + } + return append(b, j.BaseModel.Table().SQLAlias...) +} + +func (j *relationJoin) appendSoftDelete(b []byte, flags internal.Flag) []byte { + b = append(b, '.') + b = append(b, j.JoinModel.Table().SoftDeleteField.SQLName...) + if flags.Has(deletedFlag) { + b = append(b, " IS NOT NULL"...) + } else { + b = append(b, " IS NULL"...) + } + return b +} + +func appendAlias(b []byte, j *relationJoin) []byte { + if j.hasParent() { + b = appendAlias(b, j.Parent) + b = append(b, "__"...) + } + b = append(b, j.Relation.Field.Name...) + return b +} + +func (j *relationJoin) appendHasOneJoin( + fmter schema.Formatter, b []byte, q *SelectQuery, +) (_ []byte, err error) { + isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.flags.Has(allWithDeletedFlag) + + b = append(b, "LEFT JOIN "...) + b = fmter.AppendQuery(b, string(j.JoinModel.Table().SQLNameForSelects)) + b = append(b, " AS "...) + b = j.appendAlias(fmter, b) + + b = append(b, " ON "...) + + b = append(b, '(') + for i, baseField := range j.Relation.BaseFields { + if i > 0 { + b = append(b, " AND "...) + } + b = j.appendAlias(fmter, b) + b = append(b, '.') + b = append(b, j.Relation.JoinFields[i].SQLName...) + b = append(b, " = "...) + b = j.appendBaseAlias(fmter, b) + b = append(b, '.') + b = append(b, baseField.SQLName...) + } + b = append(b, ')') + + if isSoftDelete { + b = append(b, " AND "...) + b = j.appendAlias(fmter, b) + b = j.appendSoftDelete(b, q.flags) + } + + return b, nil +} + +func appendChildValues( + fmter schema.Formatter, b []byte, v reflect.Value, index []int, fields []*schema.Field, +) []byte { + seen := make(map[string]struct{}) + walk(v, index, func(v reflect.Value) { + start := len(b) + + if len(fields) > 1 { + b = append(b, '(') + } + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + b = f.AppendValue(fmter, b, v) + } + if len(fields) > 1 { + b = append(b, ')') + } + b = append(b, ", "...) + + if _, ok := seen[string(b[start:])]; ok { + b = b[:start] + } else { + seen[string(b[start:])] = struct{}{} + } + }) + if len(seen) > 0 { + b = b[:len(b)-2] // trim ", " + } + return b +} |