diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_select.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/query_select.go | 55 |
1 files changed, 30 insertions, 25 deletions
diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go index 1f63686ad..7ff93366f 100644 --- a/vendor/github.com/uptrace/bun/query_select.go +++ b/vendor/github.com/uptrace/bun/query_select.go @@ -286,41 +286,38 @@ func (q *SelectQuery) joinOn(cond string, args []interface{}, sep string) *Selec //------------------------------------------------------------------------------ -// Relation adds a relation to the query. Relation name can be: -// - RelationName to select all columns, -// - RelationName.column_name, -// - RelationName._ to join relation without selecting relation columns. +// Relation adds a relation to the query. func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQuery) *SelectQuery { + if len(apply) > 1 { + panic("only one apply function is supported") + } + if q.tableModel == nil { q.setErr(errNilModel) return q } - var fn func(*SelectQuery) *SelectQuery - - if len(apply) == 1 { - fn = apply[0] - } else if len(apply) > 1 { - panic("only one apply function is supported") - } - - join := q.tableModel.Join(name, fn) + join := q.tableModel.Join(name) if join == nil { q.setErr(fmt.Errorf("%s does not have relation=%q", q.table, name)) return q } + if len(apply) == 1 { + join.apply = apply[0] + } + return q } -func (q *SelectQuery) forEachHasOneJoin(fn func(*join) error) error { +func (q *SelectQuery) forEachHasOneJoin(fn func(*relationJoin) error) error { if q.tableModel == nil { return nil } return q._forEachHasOneJoin(fn, q.tableModel.GetJoins()) } -func (q *SelectQuery) _forEachHasOneJoin(fn func(*join) error, joins []join) error { +func (q *SelectQuery) _forEachHasOneJoin(fn func(*relationJoin) error, joins []relationJoin) error { for i := range joins { j := &joins[i] switch j.Relation.Type { @@ -336,16 +333,23 @@ func (q *SelectQuery) _forEachHasOneJoin(fn func(*join) error, joins []join) err return nil } -func (q *SelectQuery) selectJoins(ctx context.Context, joins []join) error { - var err error +func (q *SelectQuery) selectJoins(ctx context.Context, joins []relationJoin) error { for i := range joins { j := &joins[i] + + var err error + switch j.Relation.Type { case schema.HasOneRelation, schema.BelongsToRelation: err = q.selectJoins(ctx, j.JoinModel.GetJoins()) + case schema.HasManyRelation: + err = j.selectMany(ctx, q.db.NewSelect()) + case schema.ManyToManyRelation: + err = j.selectM2M(ctx, q.db.NewSelect()) default: - err = j.Select(ctx, q.db.NewSelect()) + panic("not reached") } + if err != nil { return err } @@ -415,7 +419,7 @@ func (q *SelectQuery) appendQuery( } } - if err := q.forEachHasOneJoin(func(j *join) error { + if err := q.forEachHasOneJoin(func(j *relationJoin) error { b = append(b, ' ') b, err = j.appendHasOneJoin(fmter, b, q) return err @@ -545,13 +549,13 @@ func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, b = append(b, '*') } - if err := q.forEachHasOneJoin(func(j *join) error { + if err := q.forEachHasOneJoin(func(join *relationJoin) error { if len(b) != start { b = append(b, ", "...) start = len(b) } - b, err = q.appendHasOneColumns(fmter, b, j) + b, err = q.appendHasOneColumns(fmter, b, join) if err != nil { return err } @@ -567,18 +571,19 @@ func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, } func (q *SelectQuery) appendHasOneColumns( - fmter schema.Formatter, b []byte, join *join, + fmter schema.Formatter, b []byte, join *relationJoin, ) (_ []byte, err error) { - join.applyQuery(q) + join.applyTo(q) if join.columns != nil { + table := join.JoinModel.Table() for i, col := range join.columns { if i > 0 { b = append(b, ", "...) } if col.Args == nil { - if field, ok := q.table.FieldMap[col.Query]; ok { + if field, ok := table.FieldMap[col.Query]; ok { b = join.appendAlias(fmter, b) b = append(b, '.') b = append(b, field.SQLName...) @@ -691,7 +696,7 @@ func (q *SelectQuery) Scan(ctx context.Context, dest ...interface{}) error { return err } - if res.n > 0 { + if n, _ := res.RowsAffected(); n > 0 { if tableModel, ok := model.(tableModel); ok { if err := q.selectJoins(ctx, tableModel.GetJoins()); err != nil { return err |