summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/bun/query_base.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_base.go')
-rw-r--r--vendor/github.com/uptrace/bun/query_base.go50
1 files changed, 23 insertions, 27 deletions
diff --git a/vendor/github.com/uptrace/bun/query_base.go b/vendor/github.com/uptrace/bun/query_base.go
index 1a7c32720..83cbd2605 100644
--- a/vendor/github.com/uptrace/bun/query_base.go
+++ b/vendor/github.com/uptrace/bun/query_base.go
@@ -3,6 +3,7 @@ package bun
import (
"context"
"database/sql"
+ "database/sql/driver"
"errors"
"fmt"
@@ -262,7 +263,10 @@ func (q *baseQuery) _excludeColumn(column string) bool {
//------------------------------------------------------------------------------
func (q *baseQuery) modelHasTableName() bool {
- return !q.modelTable.IsZero() || q.table != nil
+ if !q.modelTable.IsZero() {
+ return q.modelTable.Query != ""
+ }
+ return q.table != nil
}
func (q *baseQuery) hasTables() bool {
@@ -387,18 +391,10 @@ func (q *baseQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, e
}
func (q *baseQuery) getFields() ([]*schema.Field, error) {
- table := q.tableModel.Table()
-
if len(q.columns) == 0 {
- return table.Fields, nil
+ return q.table.Fields, nil
}
-
- fields, err := q._getFields(false)
- if err != nil {
- return nil, err
- }
-
- return fields, nil
+ return q._getFields(false)
}
func (q *baseQuery) getDataFields() ([]*schema.Field, error) {
@@ -435,28 +431,28 @@ func (q *baseQuery) scan(
query string,
model model,
hasDest bool,
-) (res result, _ error) {
+) (sql.Result, error) {
ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil)
rows, err := q.conn.QueryContext(ctx, query)
if err != nil {
q.db.afterQuery(ctx, event, nil, err)
- return res, err
+ return nil, err
}
defer rows.Close()
- n, err := model.ScanRows(ctx, rows)
+ numRow, err := model.ScanRows(ctx, rows)
if err != nil {
q.db.afterQuery(ctx, event, nil, err)
- return res, err
+ return nil, err
}
- res.n = n
- if n == 0 && hasDest && isSingleRowModel(model) {
+ if numRow == 0 && hasDest && isSingleRowModel(model) {
err = sql.ErrNoRows
}
- q.db.afterQuery(ctx, event, nil, err)
+ res := driver.RowsAffected(numRow)
+ q.db.afterQuery(ctx, event, res, err)
return res, err
}
@@ -465,18 +461,16 @@ func (q *baseQuery) exec(
ctx context.Context,
queryApp schema.QueryAppender,
query string,
-) (res result, _ error) {
+) (sql.Result, error) {
ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil)
- r, err := q.conn.ExecContext(ctx, query)
+ res, err := q.conn.ExecContext(ctx, query)
if err != nil {
q.db.afterQuery(ctx, event, nil, err)
return res, err
}
- res.r = r
-
- q.db.afterQuery(ctx, event, nil, err)
+ q.db.afterQuery(ctx, event, res, err)
return res, nil
}
@@ -556,10 +550,12 @@ func (q *whereBaseQuery) addWhereGroup(sep string, where []schema.QueryWithSep)
return
}
- where[0].Sep = ""
+ q.addWhere(schema.SafeQueryWithSep("", nil, sep))
+ q.addWhere(schema.SafeQueryWithSep("", nil, "("))
- q.addWhere(schema.SafeQueryWithSep("", nil, sep+"("))
+ where[0].Sep = ""
q.where = append(q.where, where...)
+
q.addWhere(schema.SafeQueryWithSep("", nil, ")"))
}
@@ -623,11 +619,11 @@ func appendWhere(
fmter schema.Formatter, b []byte, where []schema.QueryWithSep,
) (_ []byte, err error) {
for i, where := range where {
- if i > 0 || where.Sep == "(" {
+ if i > 0 {
b = append(b, where.Sep...)
}
- if where.Query == "" && where.Args == nil {
+ if where.Query == "" {
continue
}