summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_select.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_select.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_select.go55
1 files changed, 30 insertions, 25 deletions
diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go
index 1f63686ad..7ff93366f 100644
--- a/vendor/github.com/uptrace/bun/query_select.go
+++ b/vendor/github.com/uptrace/bun/query_select.go
@@ -286,41 +286,38 @@ func (q *SelectQuery) joinOn(cond string, args []interface{}, sep string) *Selec
//------------------------------------------------------------------------------
-// Relation adds a relation to the query. Relation name can be:
-// - RelationName to select all columns,
-// - RelationName.column_name,
-// - RelationName._ to join relation without selecting relation columns.
+// Relation adds a relation to the query.
func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQuery) *SelectQuery {
+ if len(apply) > 1 {
+ panic("only one apply function is supported")
+ }
+
if q.tableModel == nil {
q.setErr(errNilModel)
return q
}
- var fn func(*SelectQuery) *SelectQuery
-
- if len(apply) == 1 {
- fn = apply[0]
- } else if len(apply) > 1 {
- panic("only one apply function is supported")
- }
-
- join := q.tableModel.Join(name, fn)
+ join := q.tableModel.Join(name)
if join == nil {
q.setErr(fmt.Errorf("%s does not have relation=%q", q.table, name))
return q
}
+ if len(apply) == 1 {
+ join.apply = apply[0]
+ }
+
return q
}
-func (q *SelectQuery) forEachHasOneJoin(fn func(*join) error) error {
+func (q *SelectQuery) forEachHasOneJoin(fn func(*relationJoin) error) error {
if q.tableModel == nil {
return nil
}
return q._forEachHasOneJoin(fn, q.tableModel.GetJoins())
}
-func (q *SelectQuery) _forEachHasOneJoin(fn func(*join) error, joins []join) error {
+func (q *SelectQuery) _forEachHasOneJoin(fn func(*relationJoin) error, joins []relationJoin) error {
for i := range joins {
j := &joins[i]
switch j.Relation.Type {
@@ -336,16 +333,23 @@ func (q *SelectQuery) _forEachHasOneJoin(fn func(*join) error, joins []join) err
return nil
}
-func (q *SelectQuery) selectJoins(ctx context.Context, joins []join) error {
- var err error
+func (q *SelectQuery) selectJoins(ctx context.Context, joins []relationJoin) error {
for i := range joins {
j := &joins[i]
+
+ var err error
+
switch j.Relation.Type {
case schema.HasOneRelation, schema.BelongsToRelation:
err = q.selectJoins(ctx, j.JoinModel.GetJoins())
+ case schema.HasManyRelation:
+ err = j.selectMany(ctx, q.db.NewSelect())
+ case schema.ManyToManyRelation:
+ err = j.selectM2M(ctx, q.db.NewSelect())
default:
- err = j.Select(ctx, q.db.NewSelect())
+ panic("not reached")
}
+
if err != nil {
return err
}
@@ -415,7 +419,7 @@ func (q *SelectQuery) appendQuery(
}
}
- if err := q.forEachHasOneJoin(func(j *join) error {
+ if err := q.forEachHasOneJoin(func(j *relationJoin) error {
b = append(b, ' ')
b, err = j.appendHasOneJoin(fmter, b, q)
return err
@@ -545,13 +549,13 @@ func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte,
b = append(b, '*')
}
- if err := q.forEachHasOneJoin(func(j *join) error {
+ if err := q.forEachHasOneJoin(func(join *relationJoin) error {
if len(b) != start {
b = append(b, ", "...)
start = len(b)
}
- b, err = q.appendHasOneColumns(fmter, b, j)
+ b, err = q.appendHasOneColumns(fmter, b, join)
if err != nil {
return err
}
@@ -567,18 +571,19 @@ func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte,
}
func (q *SelectQuery) appendHasOneColumns(
- fmter schema.Formatter, b []byte, join *join,
+ fmter schema.Formatter, b []byte, join *relationJoin,
) (_ []byte, err error) {
- join.applyQuery(q)
+ join.applyTo(q)
if join.columns != nil {
+ table := join.JoinModel.Table()
for i, col := range join.columns {
if i > 0 {
b = append(b, ", "...)
}
if col.Args == nil {
- if field, ok := q.table.FieldMap[col.Query]; ok {
+ if field, ok := table.FieldMap[col.Query]; ok {
b = join.appendAlias(fmter, b)
b = append(b, '.')
b = append(b, field.SQLName...)
@@ -691,7 +696,7 @@ func (q *SelectQuery) Scan(ctx context.Context, dest ...interface{}) error {
return err
}
- if res.n > 0 {
+ if n, _ := res.RowsAffected(); n > 0 {
if tableModel, ok := model.(tableModel); ok {
if err := q.selectJoins(ctx, tableModel.GetJoins()); err != nil {
return err