summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_base.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_base.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_base.go47
1 files changed, 29 insertions, 18 deletions
diff --git a/vendor/github.com/uptrace/bun/query_base.go b/vendor/github.com/uptrace/bun/query_base.go
index 08ff8e5d9..b17498742 100644
--- a/vendor/github.com/uptrace/bun/query_base.go
+++ b/vendor/github.com/uptrace/bun/query_base.go
@@ -24,7 +24,7 @@ const (
type withQuery struct {
name string
- query schema.QueryAppender
+ query Query
recursive bool
}
@@ -114,8 +114,16 @@ func (q *baseQuery) DB() *DB {
return q.db
}
-func (q *baseQuery) GetConn() IConn {
- return q.conn
+func (q *baseQuery) resolveConn(query Query) IConn {
+ if q.conn != nil {
+ return q.conn
+ }
+ if q.db.resolver != nil {
+ if conn := q.db.resolver.ResolveConn(query); conn != nil {
+ return conn
+ }
+ }
+ return q.db.DB
}
func (q *baseQuery) GetModel() Model {
@@ -128,10 +136,8 @@ func (q *baseQuery) GetTableName() string {
}
for _, wq := range q.with {
- if v, ok := wq.query.(Query); ok {
- if model := v.GetModel(); model != nil {
- return v.GetTableName()
- }
+ if model := wq.query.GetModel(); model != nil {
+ return wq.query.GetTableName()
}
}
@@ -249,7 +255,7 @@ func (q *baseQuery) isSoftDelete() bool {
//------------------------------------------------------------------------------
-func (q *baseQuery) addWith(name string, query schema.QueryAppender, recursive bool) {
+func (q *baseQuery) addWith(name string, query Query, recursive bool) {
q.with = append(q.with, withQuery{
name: name,
query: query,
@@ -565,28 +571,33 @@ func (q *baseQuery) scan(
hasDest bool,
) (sql.Result, error) {
ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model)
+ res, err := q._scan(ctx, iquery, query, model, hasDest)
+ q.db.afterQuery(ctx, event, res, err)
+ return res, err
+}
- rows, err := q.conn.QueryContext(ctx, query)
+func (q *baseQuery) _scan(
+ ctx context.Context,
+ iquery Query,
+ query string,
+ model Model,
+ hasDest bool,
+) (sql.Result, error) {
+ rows, err := q.resolveConn(iquery).QueryContext(ctx, query)
if err != nil {
- q.db.afterQuery(ctx, event, nil, err)
return nil, err
}
defer rows.Close()
numRow, err := model.ScanRows(ctx, rows)
if err != nil {
- q.db.afterQuery(ctx, event, nil, err)
return nil, err
}
if numRow == 0 && hasDest && isSingleRowModel(model) {
- err = sql.ErrNoRows
+ return nil, sql.ErrNoRows
}
-
- res := driver.RowsAffected(numRow)
- q.db.afterQuery(ctx, event, res, err)
-
- return res, err
+ return driver.RowsAffected(numRow), nil
}
func (q *baseQuery) exec(
@@ -595,7 +606,7 @@ func (q *baseQuery) exec(
query string,
) (sql.Result, error) {
ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model)
- res, err := q.conn.ExecContext(ctx, query)
+ res, err := q.resolveConn(iquery).ExecContext(ctx, query)
q.db.afterQuery(ctx, event, res, err)
return res, err
}