diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_select.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/query_select.go | 60 |
1 files changed, 57 insertions, 3 deletions
diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go index 2b0872ae0..11761bb96 100644 --- a/vendor/github.com/uptrace/bun/query_select.go +++ b/vendor/github.com/uptrace/bun/query_select.go @@ -31,7 +31,8 @@ type SelectQuery struct { having []schema.QueryWithArgs selFor schema.QueryWithArgs - union []union + union []union + comment string } var _ Query = (*SelectQuery)(nil) @@ -381,6 +382,43 @@ func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQ return q } + q.applyToRelation(join, apply...) + + return q +} + +type RelationOpts struct { + // Apply applies additional options to the relation. + Apply func(*SelectQuery) *SelectQuery + // AdditionalJoinOnConditions adds additional conditions to the JOIN ON clause. + AdditionalJoinOnConditions []schema.QueryWithArgs +} + +// RelationWithOpts adds a relation to the query with additional options. +func (q *SelectQuery) RelationWithOpts(name string, opts RelationOpts) *SelectQuery { + if q.tableModel == nil { + q.setErr(errNilModel) + return q + } + + join := q.tableModel.join(name) + if join == nil { + q.setErr(fmt.Errorf("%s does not have relation=%q", q.table, name)) + return q + } + + if opts.Apply != nil { + q.applyToRelation(join, opts.Apply) + } + + if len(opts.AdditionalJoinOnConditions) > 0 { + join.additionalJoinOnConditions = opts.AdditionalJoinOnConditions + } + + return q +} + +func (q *SelectQuery) applyToRelation(join *relationJoin, apply ...func(*SelectQuery) *SelectQuery) { var apply1, apply2 func(*SelectQuery) *SelectQuery if len(join.Relation.Condition) > 0 { @@ -407,8 +445,6 @@ func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQ return q } - - return q } func (q *SelectQuery) forEachInlineRelJoin(fn func(*relationJoin) error) error { @@ -460,11 +496,21 @@ func (q *SelectQuery) selectJoins(ctx context.Context, joins []relationJoin) err //------------------------------------------------------------------------------ +// Comment adds a comment to the query, wrapped by /* ... */. +func (q *SelectQuery) Comment(comment string) *SelectQuery { + q.comment = comment + return q +} + +//------------------------------------------------------------------------------ + func (q *SelectQuery) Operation() string { return "SELECT" } func (q *SelectQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + b = appendComment(b, q.comment) + return q.appendQuery(fmter, b, false) } @@ -803,6 +849,14 @@ func (q *SelectQuery) scanResult(ctx context.Context, dest ...interface{}) (sql. if err != nil { return nil, err } + if len(dest) > 0 && q.tableModel != nil && len(q.tableModel.getJoins()) > 0 { + for _, j := range q.tableModel.getJoins() { + switch j.Relation.Type { + case schema.HasManyRelation, schema.ManyToManyRelation: + return nil, fmt.Errorf("When querying has-many or many-to-many relationships, you should use Model instead of the dest parameter in Scan.") + } + } + } if q.table != nil { if err := q.beforeSelectHook(ctx); err != nil { |