diff options
Diffstat (limited to 'vendor/github.com/go-pg/pg/v10/orm/join.go')
-rw-r--r-- | vendor/github.com/go-pg/pg/v10/orm/join.go | 351 |
1 files changed, 351 insertions, 0 deletions
diff --git a/vendor/github.com/go-pg/pg/v10/orm/join.go b/vendor/github.com/go-pg/pg/v10/orm/join.go new file mode 100644 index 000000000..2b64ba1b8 --- /dev/null +++ b/vendor/github.com/go-pg/pg/v10/orm/join.go @@ -0,0 +1,351 @@ +package orm + +import ( + "reflect" + + "github.com/go-pg/pg/v10/internal" + "github.com/go-pg/pg/v10/types" +) + +type join struct { + Parent *join + BaseModel TableModel + JoinModel TableModel + Rel *Relation + + ApplyQuery func(*Query) (*Query, error) + Columns []string + on []*condAppender +} + +func (j *join) AppendOn(app *condAppender) { + j.on = append(j.on, app) +} + +func (j *join) Select(fmter QueryFormatter, q *Query) error { + switch j.Rel.Type { + case HasManyRelation: + return j.selectMany(fmter, q) + case Many2ManyRelation: + return j.selectM2M(fmter, q) + } + panic("not reached") +} + +func (j *join) selectMany(_ QueryFormatter, q *Query) error { + q, err := j.manyQuery(q) + if err != nil { + return err + } + if q == nil { + return nil + } + return q.Select() +} + +func (j *join) manyQuery(q *Query) (*Query, error) { + manyModel := newManyModel(j) + if manyModel == nil { + return nil, nil + } + + q = q.Model(manyModel) + if j.ApplyQuery != nil { + var err error + q, err = j.ApplyQuery(q) + if err != nil { + return nil, err + } + } + + if len(q.columns) == 0 { + q.columns = append(q.columns, &hasManyColumnsAppender{j}) + } + + baseTable := j.BaseModel.Table() + var where []byte + if len(j.Rel.JoinFKs) > 1 { + where = append(where, '(') + } + where = appendColumns(where, j.JoinModel.Table().Alias, j.Rel.JoinFKs) + if len(j.Rel.JoinFKs) > 1 { + where = append(where, ')') + } + where = append(where, " IN ("...) + where = appendChildValues( + where, j.JoinModel.Root(), j.JoinModel.ParentIndex(), j.Rel.BaseFKs) + where = append(where, ")"...) + q = q.Where(internal.BytesToString(where)) + + if j.Rel.Polymorphic != nil { + q = q.Where(`? IN (?, ?)`, + j.Rel.Polymorphic.Column, + baseTable.ModelName, baseTable.TypeName) + } + + return q, nil +} + +func (j *join) selectM2M(fmter QueryFormatter, q *Query) error { + q, err := j.m2mQuery(fmter, q) + if err != nil { + return err + } + if q == nil { + return nil + } + return q.Select() +} + +func (j *join) m2mQuery(fmter QueryFormatter, q *Query) (*Query, error) { + m2mModel := newM2MModel(j) + if m2mModel == nil { + return nil, nil + } + + q = q.Model(m2mModel) + if j.ApplyQuery != nil { + var err error + q, err = j.ApplyQuery(q) + if err != nil { + return nil, err + } + } + + if len(q.columns) == 0 { + q.columns = append(q.columns, &hasManyColumnsAppender{j}) + } + + index := j.JoinModel.ParentIndex() + baseTable := j.BaseModel.Table() + + //nolint + var join []byte + join = append(join, "JOIN "...) + join = fmter.FormatQuery(join, string(j.Rel.M2MTableName)) + join = append(join, " AS "...) + join = append(join, j.Rel.M2MTableAlias...) + join = append(join, " ON ("...) + for i, col := range j.Rel.M2MBaseFKs { + if i > 0 { + join = append(join, ", "...) + } + join = append(join, j.Rel.M2MTableAlias...) + join = append(join, '.') + join = types.AppendIdent(join, col, 1) + } + join = append(join, ") IN ("...) + join = appendChildValues(join, j.BaseModel.Root(), index, baseTable.PKs) + join = append(join, ")"...) + q = q.Join(internal.BytesToString(join)) + + joinTable := j.JoinModel.Table() + for i, col := range j.Rel.M2MJoinFKs { + pk := joinTable.PKs[i] + q = q.Where("?.? = ?.?", + joinTable.Alias, pk.Column, + j.Rel.M2MTableAlias, types.Ident(col)) + } + + return q, nil +} + +func (j *join) hasParent() bool { + if j.Parent != nil { + switch j.Parent.Rel.Type { + case HasOneRelation, BelongsToRelation: + return true + } + } + return false +} + +func (j *join) appendAlias(b []byte) []byte { + b = append(b, '"') + b = appendAlias(b, j) + b = append(b, '"') + return b +} + +func (j *join) appendAliasColumn(b []byte, column string) []byte { + b = append(b, '"') + b = appendAlias(b, j) + b = append(b, "__"...) + b = append(b, column...) + b = append(b, '"') + return b +} + +func (j *join) appendBaseAlias(b []byte) []byte { + if j.hasParent() { + b = append(b, '"') + b = appendAlias(b, j.Parent) + b = append(b, '"') + return b + } + return append(b, j.BaseModel.Table().Alias...) +} + +func (j *join) appendSoftDelete(b []byte, flags queryFlag) []byte { + b = append(b, '.') + b = append(b, j.JoinModel.Table().SoftDeleteField.Column...) + if hasFlag(flags, deletedFlag) { + b = append(b, " IS NOT NULL"...) + } else { + b = append(b, " IS NULL"...) + } + return b +} + +func appendAlias(b []byte, j *join) []byte { + if j.hasParent() { + b = appendAlias(b, j.Parent) + b = append(b, "__"...) + } + b = append(b, j.Rel.Field.SQLName...) + return b +} + +func (j *join) appendHasOneColumns(b []byte) []byte { + if j.Columns == nil { + for i, f := range j.JoinModel.Table().Fields { + if i > 0 { + b = append(b, ", "...) + } + b = j.appendAlias(b) + b = append(b, '.') + b = append(b, f.Column...) + b = append(b, " AS "...) + b = j.appendAliasColumn(b, f.SQLName) + } + return b + } + + for i, column := range j.Columns { + if i > 0 { + b = append(b, ", "...) + } + b = j.appendAlias(b) + b = append(b, '.') + b = types.AppendIdent(b, column, 1) + b = append(b, " AS "...) + b = j.appendAliasColumn(b, column) + } + + return b +} + +func (j *join) appendHasOneJoin(fmter QueryFormatter, b []byte, q *Query) (_ []byte, err error) { + isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.hasFlag(allWithDeletedFlag) + + b = append(b, "LEFT JOIN "...) + b = fmter.FormatQuery(b, string(j.JoinModel.Table().SQLNameForSelects)) + b = append(b, " AS "...) + b = j.appendAlias(b) + + b = append(b, " ON "...) + + if isSoftDelete { + b = append(b, '(') + } + + if len(j.Rel.BaseFKs) > 1 { + b = append(b, '(') + } + for i, baseFK := range j.Rel.BaseFKs { + if i > 0 { + b = append(b, " AND "...) + } + b = j.appendAlias(b) + b = append(b, '.') + b = append(b, j.Rel.JoinFKs[i].Column...) + b = append(b, " = "...) + b = j.appendBaseAlias(b) + b = append(b, '.') + b = append(b, baseFK.Column...) + } + if len(j.Rel.BaseFKs) > 1 { + b = append(b, ')') + } + + for _, on := range j.on { + b = on.AppendSep(b) + b, err = on.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + if isSoftDelete { + b = append(b, ')') + } + + if isSoftDelete { + b = append(b, " AND "...) + b = j.appendAlias(b) + b = j.appendSoftDelete(b, q.flags) + } + + return b, nil +} + +type hasManyColumnsAppender struct { + *join +} + +var _ QueryAppender = (*hasManyColumnsAppender)(nil) + +func (q *hasManyColumnsAppender) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { + if q.Rel.M2MTableAlias != "" { + b = append(b, q.Rel.M2MTableAlias...) + b = append(b, ".*, "...) + } + + joinTable := q.JoinModel.Table() + + if q.Columns != nil { + for i, column := range q.Columns { + if i > 0 { + b = append(b, ", "...) + } + b = append(b, joinTable.Alias...) + b = append(b, '.') + b = types.AppendIdent(b, column, 1) + } + return b, nil + } + + b = appendColumns(b, joinTable.Alias, joinTable.Fields) + return b, nil +} + +func appendChildValues(b []byte, v reflect.Value, index []int, fields []*Field) []byte { + seen := make(map[string]struct{}) + walk(v, index, func(v reflect.Value) { + start := len(b) + + if len(fields) > 1 { + b = append(b, '(') + } + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + b = f.AppendValue(b, v, 1) + } + if len(fields) > 1 { + b = append(b, ')') + } + b = append(b, ", "...) + + if _, ok := seen[string(b[start:])]; ok { + b = b[:start] + } else { + seen[string(b[start:])] = struct{}{} + } + }) + if len(seen) > 0 { + b = b[:len(b)-2] // trim ", " + } + return b +} |