diff options
author | 2021-08-25 15:34:33 +0200 | |
---|---|---|
committer | 2021-08-25 15:34:33 +0200 | |
commit | 2dc9fc1626507bb54417fc4a1920b847cafb27a2 (patch) | |
tree | 4ddeac479b923db38090aac8bd9209f3646851c1 /vendor/github.com/uptrace/bun/query_base.go | |
parent | Manually approves followers (#146) (diff) | |
download | gotosocial-2dc9fc1626507bb54417fc4a1920b847cafb27a2.tar.xz |
Pg to bun (#148)
* start moving to bun
* changing more stuff
* more
* and yet more
* tests passing
* seems stable now
* more big changes
* small fix
* little fixes
Diffstat (limited to 'vendor/github.com/uptrace/bun/query_base.go')
-rw-r--r-- | vendor/github.com/uptrace/bun/query_base.go | 874 |
1 files changed, 874 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/bun/query_base.go b/vendor/github.com/uptrace/bun/query_base.go new file mode 100644 index 000000000..1a7c32720 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_base.go @@ -0,0 +1,874 @@ +package bun + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +const ( + wherePKFlag internal.Flag = 1 << iota + forceDeleteFlag + deletedFlag + allWithDeletedFlag +) + +type withQuery struct { + name string + query schema.QueryAppender +} + +// IConn is a common interface for *sql.DB, *sql.Conn, and *sql.Tx. +type IConn interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} + +var ( + _ IConn = (*sql.DB)(nil) + _ IConn = (*sql.Conn)(nil) + _ IConn = (*sql.Tx)(nil) + _ IConn = (*DB)(nil) + _ IConn = (*Conn)(nil) + _ IConn = (*Tx)(nil) +) + +// IDB is a common interface for *bun.DB, bun.Conn, and bun.Tx. +type IDB interface { + IConn + + NewValues(model interface{}) *ValuesQuery + NewSelect() *SelectQuery + NewInsert() *InsertQuery + NewUpdate() *UpdateQuery + NewDelete() *DeleteQuery + NewCreateTable() *CreateTableQuery + NewDropTable() *DropTableQuery + NewCreateIndex() *CreateIndexQuery + NewDropIndex() *DropIndexQuery + NewTruncateTable() *TruncateTableQuery + NewAddColumn() *AddColumnQuery + NewDropColumn() *DropColumnQuery +} + +var ( + _ IConn = (*DB)(nil) + _ IConn = (*Conn)(nil) + _ IConn = (*Tx)(nil) +) + +type baseQuery struct { + db *DB + conn IConn + + model model + err error + + tableModel tableModel + table *schema.Table + + with []withQuery + modelTable schema.QueryWithArgs + tables []schema.QueryWithArgs + columns []schema.QueryWithArgs + + flags internal.Flag +} + +func (q *baseQuery) DB() *DB { + return q.db +} + +func (q *baseQuery) GetModel() Model { + return q.model +} + +func (q *baseQuery) setConn(db IConn) { + // Unwrap Bun wrappers to not call query hooks twice. + switch db := db.(type) { + case *DB: + q.conn = db.DB + case Conn: + q.conn = db.Conn + case Tx: + q.conn = db.Tx + default: + q.conn = db + } +} + +// TODO: rename to setModel +func (q *baseQuery) setTableModel(modeli interface{}) { + model, err := newSingleModel(q.db, modeli) + if err != nil { + q.setErr(err) + return + } + + q.model = model + if tm, ok := model.(tableModel); ok { + q.tableModel = tm + q.table = tm.Table() + } +} + +func (q *baseQuery) setErr(err error) { + if q.err == nil { + q.err = err + } +} + +func (q *baseQuery) getModel(dest []interface{}) (model, error) { + if len(dest) == 0 { + if q.model != nil { + return q.model, nil + } + return nil, errNilModel + } + return newModel(q.db, dest) +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) checkSoftDelete() error { + if q.table == nil { + return errors.New("bun: can't use soft deletes without a table") + } + if q.table.SoftDeleteField == nil { + return fmt.Errorf("%s does not have a soft delete field", q.table) + } + if q.tableModel == nil { + return errors.New("bun: can't use soft deletes without a table model") + } + return nil +} + +// Deleted adds `WHERE deleted_at IS NOT NULL` clause for soft deleted models. +func (q *baseQuery) whereDeleted() { + if err := q.checkSoftDelete(); err != nil { + q.setErr(err) + return + } + q.flags = q.flags.Set(deletedFlag) + q.flags = q.flags.Remove(allWithDeletedFlag) +} + +// AllWithDeleted changes query to return all rows including soft deleted ones. +func (q *baseQuery) whereAllWithDeleted() { + if err := q.checkSoftDelete(); err != nil { + q.setErr(err) + return + } + q.flags = q.flags.Set(allWithDeletedFlag) + q.flags = q.flags.Remove(deletedFlag) +} + +func (q *baseQuery) isSoftDelete() bool { + if q.table != nil { + return q.table.SoftDeleteField != nil && !q.flags.Has(allWithDeletedFlag) + } + return false +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) addWith(name string, query schema.QueryAppender) { + q.with = append(q.with, withQuery{ + name: name, + query: query, + }) +} + +func (q *baseQuery) appendWith(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if len(q.with) == 0 { + return b, nil + } + + b = append(b, "WITH "...) + for i, with := range q.with { + if i > 0 { + b = append(b, ", "...) + } + + b = fmter.AppendIdent(b, with.name) + if q, ok := with.query.(schema.ColumnsAppender); ok { + b = append(b, " ("...) + b, err = q.AppendColumns(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ")"...) + } + + b = append(b, " AS ("...) + + b, err = with.query.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, ')') + } + b = append(b, ' ') + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) addTable(table schema.QueryWithArgs) { + q.tables = append(q.tables, table) +} + +func (q *baseQuery) addColumn(column schema.QueryWithArgs) { + q.columns = append(q.columns, column) +} + +func (q *baseQuery) excludeColumn(columns []string) { + if q.columns == nil { + for _, f := range q.table.Fields { + q.columns = append(q.columns, schema.UnsafeIdent(f.Name)) + } + } + + if len(columns) == 1 && columns[0] == "*" { + q.columns = make([]schema.QueryWithArgs, 0) + return + } + + for _, column := range columns { + if !q._excludeColumn(column) { + q.setErr(fmt.Errorf("bun: can't find column=%q", column)) + return + } + } +} + +func (q *baseQuery) _excludeColumn(column string) bool { + for i, col := range q.columns { + if col.Args == nil && col.Query == column { + q.columns = append(q.columns[:i], q.columns[i+1:]...) + return true + } + } + return false +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) modelHasTableName() bool { + return !q.modelTable.IsZero() || q.table != nil +} + +func (q *baseQuery) hasTables() bool { + return q.modelHasTableName() || len(q.tables) > 0 +} + +func (q *baseQuery) appendTables( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + return q._appendTables(fmter, b, false) +} + +func (q *baseQuery) appendTablesWithAlias( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + return q._appendTables(fmter, b, true) +} + +func (q *baseQuery) _appendTables( + fmter schema.Formatter, b []byte, withAlias bool, +) (_ []byte, err error) { + startLen := len(b) + + if q.modelHasTableName() { + if !q.modelTable.IsZero() { + b, err = q.modelTable.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } else { + b = fmter.AppendQuery(b, string(q.table.SQLNameForSelects)) + if withAlias && q.table.SQLAlias != q.table.SQLNameForSelects { + b = append(b, " AS "...) + b = append(b, q.table.SQLAlias...) + } + } + } + + for _, table := range q.tables { + if len(b) > startLen { + b = append(b, ", "...) + } + b, err = table.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *baseQuery) appendFirstTable(fmter schema.Formatter, b []byte) ([]byte, error) { + return q._appendFirstTable(fmter, b, false) +} + +func (q *baseQuery) appendFirstTableWithAlias( + fmter schema.Formatter, b []byte, +) ([]byte, error) { + return q._appendFirstTable(fmter, b, true) +} + +func (q *baseQuery) _appendFirstTable( + fmter schema.Formatter, b []byte, withAlias bool, +) ([]byte, error) { + if !q.modelTable.IsZero() { + return q.modelTable.AppendQuery(fmter, b) + } + + if q.table != nil { + b = fmter.AppendQuery(b, string(q.table.SQLName)) + if withAlias { + b = append(b, " AS "...) + b = append(b, q.table.SQLAlias...) + } + return b, nil + } + + if len(q.tables) > 0 { + return q.tables[0].AppendQuery(fmter, b) + } + + return nil, errors.New("bun: query does not have a table") +} + +func (q *baseQuery) hasMultiTables() bool { + if q.modelHasTableName() { + return len(q.tables) >= 1 + } + return len(q.tables) >= 2 +} + +func (q *baseQuery) appendOtherTables(fmter schema.Formatter, b []byte) (_ []byte, err error) { + tables := q.tables + if !q.modelHasTableName() { + tables = tables[1:] + } + for i, table := range tables { + if i > 0 { + b = append(b, ", "...) + } + b, err = table.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) { + for i, f := range q.columns { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *baseQuery) getFields() ([]*schema.Field, error) { + table := q.tableModel.Table() + + if len(q.columns) == 0 { + return table.Fields, nil + } + + fields, err := q._getFields(false) + if err != nil { + return nil, err + } + + return fields, nil +} + +func (q *baseQuery) getDataFields() ([]*schema.Field, error) { + if len(q.columns) == 0 { + return q.table.DataFields, nil + } + return q._getFields(true) +} + +func (q *baseQuery) _getFields(omitPK bool) ([]*schema.Field, error) { + fields := make([]*schema.Field, 0, len(q.columns)) + for _, col := range q.columns { + if col.Args != nil { + continue + } + + field, err := q.table.Field(col.Query) + if err != nil { + return nil, err + } + + if omitPK && field.IsPK { + continue + } + + fields = append(fields, field) + } + return fields, nil +} + +func (q *baseQuery) scan( + ctx context.Context, + queryApp schema.QueryAppender, + query string, + model model, + hasDest bool, +) (res result, _ error) { + ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil) + + rows, err := q.conn.QueryContext(ctx, query) + if err != nil { + q.db.afterQuery(ctx, event, nil, err) + return res, err + } + defer rows.Close() + + n, err := model.ScanRows(ctx, rows) + if err != nil { + q.db.afterQuery(ctx, event, nil, err) + return res, err + } + + res.n = n + if n == 0 && hasDest && isSingleRowModel(model) { + err = sql.ErrNoRows + } + + q.db.afterQuery(ctx, event, nil, err) + + return res, err +} + +func (q *baseQuery) exec( + ctx context.Context, + queryApp schema.QueryAppender, + query string, +) (res result, _ error) { + ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil) + + r, err := q.conn.ExecContext(ctx, query) + if err != nil { + q.db.afterQuery(ctx, event, nil, err) + return res, err + } + + res.r = r + + q.db.afterQuery(ctx, event, nil, err) + return res, nil +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) AppendNamedArg(fmter schema.Formatter, b []byte, name string) ([]byte, bool) { + if q.table == nil { + return b, false + } + + if m, ok := q.tableModel.(*structTableModel); ok { + if b, ok := m.AppendNamedArg(fmter, b, name); ok { + return b, ok + } + } + + switch name { + case "TableName": + b = fmter.AppendQuery(b, string(q.table.SQLName)) + return b, true + case "TableAlias": + b = fmter.AppendQuery(b, string(q.table.SQLAlias)) + return b, true + case "PKs": + b = appendColumns(b, "", q.table.PKs) + return b, true + case "TablePKs": + b = appendColumns(b, q.table.SQLAlias, q.table.PKs) + return b, true + case "Columns": + b = appendColumns(b, "", q.table.Fields) + return b, true + case "TableColumns": + b = appendColumns(b, q.table.SQLAlias, q.table.Fields) + return b, true + } + + return b, false +} + +func appendColumns(b []byte, table schema.Safe, fields []*schema.Field) []byte { + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + if len(table) > 0 { + b = append(b, table...) + b = append(b, '.') + } + b = append(b, f.SQLName...) + } + return b +} + +func formatterWithModel(fmter schema.Formatter, model schema.NamedArgAppender) schema.Formatter { + if fmter.IsNop() { + return fmter + } + return fmter.WithArg(model) +} + +//------------------------------------------------------------------------------ + +type whereBaseQuery struct { + baseQuery + + where []schema.QueryWithSep +} + +func (q *whereBaseQuery) addWhere(where schema.QueryWithSep) { + q.where = append(q.where, where) +} + +func (q *whereBaseQuery) addWhereGroup(sep string, where []schema.QueryWithSep) { + if len(where) == 0 { + return + } + + where[0].Sep = "" + + q.addWhere(schema.SafeQueryWithSep("", nil, sep+"(")) + q.where = append(q.where, where...) + q.addWhere(schema.SafeQueryWithSep("", nil, ")")) +} + +func (q *whereBaseQuery) mustAppendWhere( + fmter schema.Formatter, b []byte, withAlias bool, +) ([]byte, error) { + if len(q.where) == 0 && !q.flags.Has(wherePKFlag) { + err := errors.New("bun: Update and Delete queries require at least one Where") + return nil, err + } + return q.appendWhere(fmter, b, withAlias) +} + +func (q *whereBaseQuery) appendWhere( + fmter schema.Formatter, b []byte, withAlias bool, +) (_ []byte, err error) { + if len(q.where) == 0 && !q.isSoftDelete() && !q.flags.Has(wherePKFlag) { + return b, nil + } + + b = append(b, " WHERE "...) + startLen := len(b) + + if len(q.where) > 0 { + b, err = appendWhere(fmter, b, q.where) + if err != nil { + return nil, err + } + } + + if q.isSoftDelete() { + if len(b) > startLen { + b = append(b, " AND "...) + } + if withAlias { + b = append(b, q.tableModel.Table().SQLAlias...) + b = append(b, '.') + } + b = append(b, q.tableModel.Table().SoftDeleteField.SQLName...) + if q.flags.Has(deletedFlag) { + b = append(b, " IS NOT NULL"...) + } else { + b = append(b, " IS NULL"...) + } + } + + if q.flags.Has(wherePKFlag) { + if len(b) > startLen { + b = append(b, " AND "...) + } + b, err = q.appendWherePK(fmter, b, withAlias) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func appendWhere( + fmter schema.Formatter, b []byte, where []schema.QueryWithSep, +) (_ []byte, err error) { + for i, where := range where { + if i > 0 || where.Sep == "(" { + b = append(b, where.Sep...) + } + + if where.Query == "" && where.Args == nil { + continue + } + + b = append(b, '(') + b, err = where.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ')') + } + return b, nil +} + +func (q *whereBaseQuery) appendWherePK( + fmter schema.Formatter, b []byte, withAlias bool, +) (_ []byte, err error) { + if q.table == nil { + err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model) + return nil, err + } + if err := q.table.CheckPKs(); err != nil { + return nil, err + } + + switch model := q.tableModel.(type) { + case *structTableModel: + return q.appendWherePKStruct(fmter, b, model, withAlias) + case *sliceTableModel: + return q.appendWherePKSlice(fmter, b, model, withAlias) + } + + return nil, fmt.Errorf("bun: WherePK does not support %T", q.tableModel) +} + +func (q *whereBaseQuery) appendWherePKStruct( + fmter schema.Formatter, b []byte, model *structTableModel, withAlias bool, +) (_ []byte, err error) { + if !model.strct.IsValid() { + return nil, errNilModel + } + + isTemplate := fmter.IsNop() + b = append(b, '(') + for i, f := range q.table.PKs { + if i > 0 { + b = append(b, " AND "...) + } + if withAlias { + b = append(b, q.table.SQLAlias...) + b = append(b, '.') + } + b = append(b, f.SQLName...) + b = append(b, " = "...) + if isTemplate { + b = append(b, '?') + } else { + b = f.AppendValue(fmter, b, model.strct) + } + } + b = append(b, ')') + return b, nil +} + +func (q *whereBaseQuery) appendWherePKSlice( + fmter schema.Formatter, b []byte, model *sliceTableModel, withAlias bool, +) (_ []byte, err error) { + if len(q.table.PKs) > 1 { + b = append(b, '(') + } + if withAlias { + b = appendColumns(b, q.table.SQLAlias, q.table.PKs) + } else { + b = appendColumns(b, "", q.table.PKs) + } + if len(q.table.PKs) > 1 { + b = append(b, ')') + } + + b = append(b, " IN ("...) + + isTemplate := fmter.IsNop() + slice := model.slice + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + if isTemplate { + break + } + b = append(b, ", "...) + } + + el := indirect(slice.Index(i)) + + if len(q.table.PKs) > 1 { + b = append(b, '(') + } + for i, f := range q.table.PKs { + if i > 0 { + b = append(b, ", "...) + } + if isTemplate { + b = append(b, '?') + } else { + b = f.AppendValue(fmter, b, el) + } + } + if len(q.table.PKs) > 1 { + b = append(b, ')') + } + } + + b = append(b, ')') + + return b, nil +} + +//------------------------------------------------------------------------------ + +type returningQuery struct { + returning []schema.QueryWithArgs + returningFields []*schema.Field +} + +func (q *returningQuery) addReturning(ret schema.QueryWithArgs) { + q.returning = append(q.returning, ret) +} + +func (q *returningQuery) addReturningField(field *schema.Field) { + if len(q.returning) > 0 { + return + } + for _, f := range q.returningFields { + if f == field { + return + } + } + q.returningFields = append(q.returningFields, field) +} + +func (q *returningQuery) hasReturning() bool { + if len(q.returning) == 1 { + switch q.returning[0].Query { + case "null", "NULL": + return false + } + } + return len(q.returning) > 0 || len(q.returningFields) > 0 +} + +func (q *returningQuery) appendReturning( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + if !q.hasReturning() { + return b, nil + } + + b = append(b, " RETURNING "...) + + for i, f := range q.returning { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + if len(q.returning) > 0 { + return b, nil + } + + b = appendColumns(b, "", q.returningFields) + return b, nil +} + +//------------------------------------------------------------------------------ + +type columnValue struct { + column string + value schema.QueryWithArgs +} + +type customValueQuery struct { + modelValues map[string]schema.QueryWithArgs + extraValues []columnValue +} + +func (q *customValueQuery) addValue( + table *schema.Table, column string, value string, args []interface{}, +) { + if _, ok := table.FieldMap[column]; ok { + if q.modelValues == nil { + q.modelValues = make(map[string]schema.QueryWithArgs) + } + q.modelValues[column] = schema.SafeQuery(value, args) + } else { + q.extraValues = append(q.extraValues, columnValue{ + column: column, + value: schema.SafeQuery(value, args), + }) + } +} + +//------------------------------------------------------------------------------ + +type setQuery struct { + set []schema.QueryWithArgs +} + +func (q *setQuery) addSet(set schema.QueryWithArgs) { + q.set = append(q.set, set) +} + +func (q setQuery) appendSet(fmter schema.Formatter, b []byte) (_ []byte, err error) { + for i, f := range q.set { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +//------------------------------------------------------------------------------ + +type cascadeQuery struct { + restrict bool +} + +func (q cascadeQuery) appendCascade(fmter schema.Formatter, b []byte) []byte { + if !fmter.HasFeature(feature.TableCascade) { + return b + } + if q.restrict { + b = append(b, " RESTRICT"...) + } else { + b = append(b, " CASCADE"...) + } + return b +} |