summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_base.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_base.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_base.go47
1 files changed, 38 insertions, 9 deletions
diff --git a/vendor/github.com/uptrace/bun/query_base.go b/vendor/github.com/uptrace/bun/query_base.go
index 4e1151dbe..60226890f 100644
--- a/vendor/github.com/uptrace/bun/query_base.go
+++ b/vendor/github.com/uptrace/bun/query_base.go
@@ -68,10 +68,10 @@ type baseQuery struct {
db *DB
conn IConn
- model model
+ model Model
err error
- tableModel tableModel
+ tableModel TableModel
table *schema.Table
with []withQuery
@@ -86,10 +86,39 @@ func (q *baseQuery) DB() *DB {
return q.db
}
+type query interface {
+ GetModel() Model
+ GetTableName() string
+}
+
+var _ query = (*baseQuery)(nil)
+
func (q *baseQuery) GetModel() Model {
return q.model
}
+func (q *baseQuery) GetTableName() string {
+ if q.table != nil {
+ return q.table.Name
+ }
+
+ for _, wq := range q.with {
+ if v, ok := wq.query.(query); ok {
+ if model := v.GetModel(); model != nil {
+ return v.GetTableName()
+ }
+ }
+ }
+
+ if q.modelTable.Query != "" {
+ return q.modelTable.Query
+ }
+ if len(q.tables) > 0 {
+ return q.tables[0].Query
+ }
+ return ""
+}
+
func (q *baseQuery) setConn(db IConn) {
// Unwrap Bun wrappers to not call query hooks twice.
switch db := db.(type) {
@@ -113,7 +142,7 @@ func (q *baseQuery) setTableModel(modeli interface{}) {
}
q.model = model
- if tm, ok := model.(tableModel); ok {
+ if tm, ok := model.(TableModel); ok {
q.tableModel = tm
q.table = tm.Table()
}
@@ -125,7 +154,7 @@ func (q *baseQuery) setErr(err error) {
}
}
-func (q *baseQuery) getModel(dest []interface{}) (model, error) {
+func (q *baseQuery) getModel(dest []interface{}) (Model, error) {
if len(dest) == 0 {
if q.model != nil {
return q.model, nil
@@ -427,12 +456,12 @@ func (q *baseQuery) _getFields(omitPK bool) ([]*schema.Field, error) {
func (q *baseQuery) scan(
ctx context.Context,
- queryApp schema.Query,
+ iquery IQuery,
query string,
- model model,
+ model Model,
hasDest bool,
) (sql.Result, error) {
- ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil)
+ ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, q.model)
rows, err := q.conn.QueryContext(ctx, query)
if err != nil {
@@ -459,10 +488,10 @@ func (q *baseQuery) scan(
func (q *baseQuery) exec(
ctx context.Context,
- queryApp schema.Query,
+ iquery IQuery,
query string,
) (sql.Result, error) {
- ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil)
+ ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, q.model)
res, err := q.conn.ExecContext(ctx, query)
if err != nil {