From 9a53b1a8d19da525ca7ace957b2d32f85dbe0fe9 Mon Sep 17 00:00:00 2001 From: tobi <31960611+tsmethurst@users.noreply.github.com> Date: Wed, 29 Sep 2021 15:09:45 +0200 Subject: upstep bun to v1.0.9 (#252) --- vendor/github.com/uptrace/bun/query_base.go | 47 +++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 9 deletions(-) (limited to 'vendor/github.com/uptrace/bun/query_base.go') diff --git a/vendor/github.com/uptrace/bun/query_base.go b/vendor/github.com/uptrace/bun/query_base.go index 4e1151dbe..60226890f 100644 --- a/vendor/github.com/uptrace/bun/query_base.go +++ b/vendor/github.com/uptrace/bun/query_base.go @@ -68,10 +68,10 @@ type baseQuery struct { db *DB conn IConn - model model + model Model err error - tableModel tableModel + tableModel TableModel table *schema.Table with []withQuery @@ -86,10 +86,39 @@ func (q *baseQuery) DB() *DB { return q.db } +type query interface { + GetModel() Model + GetTableName() string +} + +var _ query = (*baseQuery)(nil) + func (q *baseQuery) GetModel() Model { return q.model } +func (q *baseQuery) GetTableName() string { + if q.table != nil { + return q.table.Name + } + + for _, wq := range q.with { + if v, ok := wq.query.(query); ok { + if model := v.GetModel(); model != nil { + return v.GetTableName() + } + } + } + + if q.modelTable.Query != "" { + return q.modelTable.Query + } + if len(q.tables) > 0 { + return q.tables[0].Query + } + return "" +} + func (q *baseQuery) setConn(db IConn) { // Unwrap Bun wrappers to not call query hooks twice. switch db := db.(type) { @@ -113,7 +142,7 @@ func (q *baseQuery) setTableModel(modeli interface{}) { } q.model = model - if tm, ok := model.(tableModel); ok { + if tm, ok := model.(TableModel); ok { q.tableModel = tm q.table = tm.Table() } @@ -125,7 +154,7 @@ func (q *baseQuery) setErr(err error) { } } -func (q *baseQuery) getModel(dest []interface{}) (model, error) { +func (q *baseQuery) getModel(dest []interface{}) (Model, error) { if len(dest) == 0 { if q.model != nil { return q.model, nil @@ -427,12 +456,12 @@ func (q *baseQuery) _getFields(omitPK bool) ([]*schema.Field, error) { func (q *baseQuery) scan( ctx context.Context, - queryApp schema.Query, + iquery IQuery, query string, - model model, + model Model, hasDest bool, ) (sql.Result, error) { - ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil) + ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, q.model) rows, err := q.conn.QueryContext(ctx, query) if err != nil { @@ -459,10 +488,10 @@ func (q *baseQuery) scan( func (q *baseQuery) exec( ctx context.Context, - queryApp schema.Query, + iquery IQuery, query string, ) (sql.Result, error) { - ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil) + ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, q.model) res, err := q.conn.ExecContext(ctx, query) if err != nil { -- cgit v1.2.3