diff options
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_base.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/query_base.go | 50 |
1 files changed, 23 insertions, 27 deletions
diff --git a/vendor/github.com/uptrace/bun/query_base.go b/vendor/github.com/uptrace/bun/query_base.go index 1a7c32720..83cbd2605 100644 --- a/vendor/github.com/uptrace/bun/query_base.go +++ b/vendor/github.com/uptrace/bun/query_base.go @@ -3,6 +3,7 @@ package bun import ( "context" "database/sql" + "database/sql/driver" "errors" "fmt" @@ -262,7 +263,10 @@ func (q *baseQuery) _excludeColumn(column string) bool { //------------------------------------------------------------------------------ func (q *baseQuery) modelHasTableName() bool { - return !q.modelTable.IsZero() || q.table != nil + if !q.modelTable.IsZero() { + return q.modelTable.Query != "" + } + return q.table != nil } func (q *baseQuery) hasTables() bool { @@ -387,18 +391,10 @@ func (q *baseQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, e } func (q *baseQuery) getFields() ([]*schema.Field, error) { - table := q.tableModel.Table() - if len(q.columns) == 0 { - return table.Fields, nil + return q.table.Fields, nil } - - fields, err := q._getFields(false) - if err != nil { - return nil, err - } - - return fields, nil + return q._getFields(false) } func (q *baseQuery) getDataFields() ([]*schema.Field, error) { @@ -435,28 +431,28 @@ func (q *baseQuery) scan( query string, model model, hasDest bool, -) (res result, _ error) { +) (sql.Result, error) { ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil) rows, err := q.conn.QueryContext(ctx, query) if err != nil { q.db.afterQuery(ctx, event, nil, err) - return res, err + return nil, err } defer rows.Close() - n, err := model.ScanRows(ctx, rows) + numRow, err := model.ScanRows(ctx, rows) if err != nil { q.db.afterQuery(ctx, event, nil, err) - return res, err + return nil, err } - res.n = n - if n == 0 && hasDest && isSingleRowModel(model) { + if numRow == 0 && hasDest && isSingleRowModel(model) { err = sql.ErrNoRows } - q.db.afterQuery(ctx, event, nil, err) + res := driver.RowsAffected(numRow) + q.db.afterQuery(ctx, event, res, err) return res, err } @@ -465,18 +461,16 @@ func (q *baseQuery) exec( ctx context.Context, queryApp schema.QueryAppender, query string, -) (res result, _ error) { +) (sql.Result, error) { ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil) - r, err := q.conn.ExecContext(ctx, query) + res, err := q.conn.ExecContext(ctx, query) if err != nil { q.db.afterQuery(ctx, event, nil, err) return res, err } - res.r = r - - q.db.afterQuery(ctx, event, nil, err) + q.db.afterQuery(ctx, event, res, err) return res, nil } @@ -556,10 +550,12 @@ func (q *whereBaseQuery) addWhereGroup(sep string, where []schema.QueryWithSep) return } - where[0].Sep = "" + q.addWhere(schema.SafeQueryWithSep("", nil, sep)) + q.addWhere(schema.SafeQueryWithSep("", nil, "(")) - q.addWhere(schema.SafeQueryWithSep("", nil, sep+"(")) + where[0].Sep = "" q.where = append(q.where, where...) + q.addWhere(schema.SafeQueryWithSep("", nil, ")")) } @@ -623,11 +619,11 @@ func appendWhere( fmter schema.Formatter, b []byte, where []schema.QueryWithSep, ) (_ []byte, err error) { for i, where := range where { - if i > 0 || where.Sep == "(" { + if i > 0 { b = append(b, where.Sep...) } - if where.Query == "" && where.Args == nil { + if where.Query == "" { continue } |