diff options
Diffstat (limited to 'vendor/github.com/jackc/pgx/v4/conn.go')
-rw-r--r-- | vendor/github.com/jackc/pgx/v4/conn.go | 850 |
1 files changed, 850 insertions, 0 deletions
diff --git a/vendor/github.com/jackc/pgx/v4/conn.go b/vendor/github.com/jackc/pgx/v4/conn.go new file mode 100644 index 000000000..9636f2fd6 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/conn.go @@ -0,0 +1,850 @@ +package pgx + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/jackc/pgconn" + "github.com/jackc/pgconn/stmtcache" + "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4/internal/sanitize" +) + +// ConnConfig contains all the options used to establish a connection. It must be created by ParseConfig and +// then it can be modified. A manually initialized ConnConfig will cause ConnectConfig to panic. +type ConnConfig struct { + pgconn.Config + Logger Logger + LogLevel LogLevel + + // Original connection string that was parsed into config. + connString string + + // BuildStatementCache creates the stmtcache.Cache implementation for connections created with this config. Set + // to nil to disable automatic prepared statements. + BuildStatementCache BuildStatementCacheFunc + + // PreferSimpleProtocol disables implicit prepared statement usage. By default pgx automatically uses the extended + // protocol. This can improve performance due to being able to use the binary format. It also does not rely on client + // side parameter sanitization. However, it does incur two round-trips per query (unless using a prepared statement) + // and may be incompatible proxies such as PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be + // used by default. The same functionality can be controlled on a per query basis by setting + // QueryExOptions.SimpleProtocol. + PreferSimpleProtocol bool + + createdByParseConfig bool // Used to enforce created by ParseConfig rule. +} + +// Copy returns a deep copy of the config that is safe to use and modify. +// The only exception is the tls.Config: +// according to the tls.Config docs it must not be modified after creation. +func (cc *ConnConfig) Copy() *ConnConfig { + newConfig := new(ConnConfig) + *newConfig = *cc + newConfig.Config = *newConfig.Config.Copy() + return newConfig +} + +func (cc *ConnConfig) ConnString() string { return cc.connString } + +// BuildStatementCacheFunc is a function that can be used to create a stmtcache.Cache implementation for connection. +type BuildStatementCacheFunc func(conn *pgconn.PgConn) stmtcache.Cache + +// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. Use a connection pool to manage access +// to multiple database connections from multiple goroutines. +type Conn struct { + pgConn *pgconn.PgConn + config *ConnConfig // config used when establishing this connection + preparedStatements map[string]*pgconn.StatementDescription + stmtcache stmtcache.Cache + logger Logger + logLevel LogLevel + + notifications []*pgconn.Notification + + doneChan chan struct{} + closedChan chan error + + connInfo *pgtype.ConnInfo + + wbuf []byte + preallocatedRows []connRows + eqb extendedQueryBuilder +} + +// Identifier a PostgreSQL identifier or name. Identifiers can be composed of +// multiple parts such as ["schema", "table"] or ["table", "column"]. +type Identifier []string + +// Sanitize returns a sanitized string safe for SQL interpolation. +func (ident Identifier) Sanitize() string { + parts := make([]string, len(ident)) + for i := range ident { + s := strings.ReplaceAll(ident[i], string([]byte{0}), "") + parts[i] = `"` + strings.ReplaceAll(s, `"`, `""`) + `"` + } + return strings.Join(parts, ".") +} + +// ErrNoRows occurs when rows are expected but none are returned. +var ErrNoRows = errors.New("no rows in result set") + +// ErrInvalidLogLevel occurs on attempt to set an invalid log level. +var ErrInvalidLogLevel = errors.New("invalid log level") + +// Connect establishes a connection with a PostgreSQL server with a connection string. See +// pgconn.Connect for details. +func Connect(ctx context.Context, connString string) (*Conn, error) { + connConfig, err := ParseConfig(connString) + if err != nil { + return nil, err + } + return connect(ctx, connConfig) +} + +// Connect establishes a connection with a PostgreSQL server with a configuration struct. connConfig must have been +// created by ParseConfig. +func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { + return connect(ctx, connConfig) +} + +// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that pgconn.ParseConfig +// does. In addition, it accepts the following options: +// +// statement_cache_capacity +// The maximum size of the automatic statement cache. Set to 0 to disable automatic statement caching. Default: 512. +// +// statement_cache_mode +// Possible values: "prepare" and "describe". "prepare" will create prepared statements on the PostgreSQL server. +// "describe" will use the anonymous prepared statement to describe a statement without creating a statement on the +// server. "describe" is primarily useful when the environment does not allow prepared statements such as when +// running a connection pooler like PgBouncer. Default: "prepare" +// +// prefer_simple_protocol +// Possible values: "true" and "false". Use the simple protocol instead of extended protocol. Default: false +func ParseConfig(connString string) (*ConnConfig, error) { + config, err := pgconn.ParseConfig(connString) + if err != nil { + return nil, err + } + + var buildStatementCache BuildStatementCacheFunc + statementCacheCapacity := 512 + statementCacheMode := stmtcache.ModePrepare + if s, ok := config.RuntimeParams["statement_cache_capacity"]; ok { + delete(config.RuntimeParams, "statement_cache_capacity") + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, fmt.Errorf("cannot parse statement_cache_capacity: %w", err) + } + statementCacheCapacity = int(n) + } + + if s, ok := config.RuntimeParams["statement_cache_mode"]; ok { + delete(config.RuntimeParams, "statement_cache_mode") + switch s { + case "prepare": + statementCacheMode = stmtcache.ModePrepare + case "describe": + statementCacheMode = stmtcache.ModeDescribe + default: + return nil, fmt.Errorf("invalid statement_cache_mod: %s", s) + } + } + + if statementCacheCapacity > 0 { + buildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { + return stmtcache.New(conn, statementCacheMode, statementCacheCapacity) + } + } + + preferSimpleProtocol := false + if s, ok := config.RuntimeParams["prefer_simple_protocol"]; ok { + delete(config.RuntimeParams, "prefer_simple_protocol") + if b, err := strconv.ParseBool(s); err == nil { + preferSimpleProtocol = b + } else { + return nil, fmt.Errorf("invalid prefer_simple_protocol: %v", err) + } + } + + connConfig := &ConnConfig{ + Config: *config, + createdByParseConfig: true, + LogLevel: LogLevelInfo, + BuildStatementCache: buildStatementCache, + PreferSimpleProtocol: preferSimpleProtocol, + connString: connString, + } + + return connConfig, nil +} + +func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { + // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from + // zero values. + if !config.createdByParseConfig { + panic("config must be created by ParseConfig") + } + originalConfig := config + + // This isn't really a deep copy. But it is enough to avoid the config.Config.OnNotification mutation from affecting + // other connections with the same config. See https://github.com/jackc/pgx/issues/618. + { + configCopy := *config + config = &configCopy + } + + c = &Conn{ + config: originalConfig, + connInfo: pgtype.NewConnInfo(), + logLevel: config.LogLevel, + logger: config.Logger, + } + + // Only install pgx notification system if no other callback handler is present. + if config.Config.OnNotification == nil { + config.Config.OnNotification = c.bufferNotifications + } else { + if c.shouldLog(LogLevelDebug) { + c.log(ctx, LogLevelDebug, "pgx notification handler disabled by application supplied OnNotification", map[string]interface{}{"host": config.Config.Host}) + } + } + + if c.shouldLog(LogLevelInfo) { + c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host}) + } + c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config) + if err != nil { + if c.shouldLog(LogLevelError) { + c.log(ctx, LogLevelError, "connect failed", map[string]interface{}{"err": err}) + } + return nil, err + } + + c.preparedStatements = make(map[string]*pgconn.StatementDescription) + c.doneChan = make(chan struct{}) + c.closedChan = make(chan error) + c.wbuf = make([]byte, 0, 1024) + + if c.config.BuildStatementCache != nil { + c.stmtcache = c.config.BuildStatementCache(c.pgConn) + } + + // Replication connections can't execute the queries to + // populate the c.PgTypes and c.pgsqlAfInet + if _, ok := config.Config.RuntimeParams["replication"]; ok { + return c, nil + } + + return c, nil +} + +// Close closes a connection. It is safe to call Close on a already closed +// connection. +func (c *Conn) Close(ctx context.Context) error { + if c.IsClosed() { + return nil + } + + err := c.pgConn.Close(ctx) + if c.shouldLog(LogLevelInfo) { + c.log(ctx, LogLevelInfo, "closed connection", nil) + } + return err +} + +// Prepare creates a prepared statement with name and sql. sql can contain placeholders +// for bound parameters. These placeholders are referenced positional as $1, $2, etc. +// +// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same +// name and sql arguments. This allows a code path to Prepare and Query/Exec without +// concern for if the statement has already been prepared. +func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) { + if name != "" { + var ok bool + if sd, ok = c.preparedStatements[name]; ok && sd.SQL == sql { + return sd, nil + } + } + + if c.shouldLog(LogLevelError) { + defer func() { + if err != nil { + c.log(ctx, LogLevelError, "Prepare failed", map[string]interface{}{"err": err, "name": name, "sql": sql}) + } + }() + } + + sd, err = c.pgConn.Prepare(ctx, name, sql, nil) + if err != nil { + return nil, err + } + + if name != "" { + c.preparedStatements[name] = sd + } + + return sd, nil +} + +// Deallocate released a prepared statement +func (c *Conn) Deallocate(ctx context.Context, name string) error { + delete(c.preparedStatements, name) + _, err := c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(name)).ReadAll() + return err +} + +func (c *Conn) bufferNotifications(_ *pgconn.PgConn, n *pgconn.Notification) { + c.notifications = append(c.notifications, n) +} + +// WaitForNotification waits for a PostgreSQL notification. It wraps the underlying pgconn notification system in a +// slightly more convenient form. +func (c *Conn) WaitForNotification(ctx context.Context) (*pgconn.Notification, error) { + var n *pgconn.Notification + + // Return already received notification immediately + if len(c.notifications) > 0 { + n = c.notifications[0] + c.notifications = c.notifications[1:] + return n, nil + } + + err := c.pgConn.WaitForNotification(ctx) + if len(c.notifications) > 0 { + n = c.notifications[0] + c.notifications = c.notifications[1:] + } + return n, err +} + +func (c *Conn) IsClosed() bool { + return c.pgConn.IsClosed() +} + +func (c *Conn) die(err error) { + if c.IsClosed() { + return + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // force immediate hard cancel + c.pgConn.Close(ctx) +} + +func (c *Conn) shouldLog(lvl LogLevel) bool { + return c.logger != nil && c.logLevel >= lvl +} + +func (c *Conn) log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) { + if data == nil { + data = map[string]interface{}{} + } + if c.pgConn != nil && c.pgConn.PID() != 0 { + data["pid"] = c.pgConn.PID() + } + + c.logger.Log(ctx, lvl, msg, data) +} + +func quoteIdentifier(s string) string { + return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` +} + +func (c *Conn) Ping(ctx context.Context) error { + _, err := c.Exec(ctx, ";") + return err +} + +func connInfoFromRows(rows Rows, err error) (map[string]uint32, error) { + if err != nil { + return nil, err + } + defer rows.Close() + + nameOIDs := make(map[string]uint32, 256) + for rows.Next() { + var oid uint32 + var name pgtype.Text + if err = rows.Scan(&oid, &name); err != nil { + return nil, err + } + + nameOIDs[name.String] = oid + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return nameOIDs, err +} + +// PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the +// PostgreSQL connection than pgx exposes. +// +// It is strongly recommended that the connection be idle (no in-progress queries) before the underlying *pgconn.PgConn +// is used and the connection must be returned to the same state before any *pgx.Conn methods are again used. +func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn } + +// StatementCache returns the statement cache used for this connection. +func (c *Conn) StatementCache() stmtcache.Cache { return c.stmtcache } + +// ConnInfo returns the connection info used for this connection. +func (c *Conn) ConnInfo() *pgtype.ConnInfo { return c.connInfo } + +// Config returns a copy of config that was used to establish this connection. +func (c *Conn) Config() *ConnConfig { return c.config.Copy() } + +// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced +// positionally from the sql string as $1, $2, etc. +func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { + startTime := time.Now() + + commandTag, err := c.exec(ctx, sql, arguments...) + if err != nil { + if c.shouldLog(LogLevelError) { + c.log(ctx, LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err}) + } + return commandTag, err + } + + if c.shouldLog(LogLevelInfo) { + endTime := time.Now() + c.log(ctx, LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) + } + + return commandTag, err +} + +func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { + simpleProtocol := c.config.PreferSimpleProtocol + +optionLoop: + for len(arguments) > 0 { + switch arg := arguments[0].(type) { + case QuerySimpleProtocol: + simpleProtocol = bool(arg) + arguments = arguments[1:] + default: + break optionLoop + } + } + + if sd, ok := c.preparedStatements[sql]; ok { + return c.execPrepared(ctx, sd, arguments) + } + + if simpleProtocol { + return c.execSimpleProtocol(ctx, sql, arguments) + } + + if len(arguments) == 0 { + return c.execSimpleProtocol(ctx, sql, arguments) + } + + if c.stmtcache != nil { + sd, err := c.stmtcache.Get(ctx, sql) + if err != nil { + return nil, err + } + + if c.stmtcache.Mode() == stmtcache.ModeDescribe { + return c.execParams(ctx, sd, arguments) + } + return c.execPrepared(ctx, sd, arguments) + } + + sd, err := c.Prepare(ctx, "", sql) + if err != nil { + return nil, err + } + return c.execPrepared(ctx, sd, arguments) +} + +func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []interface{}) (commandTag pgconn.CommandTag, err error) { + if len(arguments) > 0 { + sql, err = c.sanitizeForSimpleQuery(sql, arguments...) + if err != nil { + return nil, err + } + } + + mrr := c.pgConn.Exec(ctx, sql) + for mrr.NextResult() { + commandTag, err = mrr.ResultReader().Close() + } + err = mrr.Close() + return commandTag, err +} + +func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, arguments []interface{}) error { + if len(sd.ParamOIDs) != len(arguments) { + return fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(arguments)) + } + + c.eqb.Reset() + + args, err := convertDriverValuers(arguments) + if err != nil { + return err + } + + for i := range args { + err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i]) + if err != nil { + return err + } + } + + for i := range sd.Fields { + c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) + } + + return nil +} + +func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) { + err := c.execParamsAndPreparedPrefix(sd, arguments) + if err != nil { + return nil, err + } + + result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read() + return result.CommandTag, result.Err +} + +func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) { + err := c.execParamsAndPreparedPrefix(sd, arguments) + if err != nil { + return nil, err + } + + result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read() + return result.CommandTag, result.Err +} + +func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows { + if len(c.preallocatedRows) == 0 { + c.preallocatedRows = make([]connRows, 64) + } + + r := &c.preallocatedRows[len(c.preallocatedRows)-1] + c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1] + + r.ctx = ctx + r.logger = c + r.connInfo = c.connInfo + r.startTime = time.Now() + r.sql = sql + r.args = args + r.conn = c + + return r +} + +// QuerySimpleProtocol controls whether the simple or extended protocol is used to send the query. +type QuerySimpleProtocol bool + +// QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position. +type QueryResultFormats []int16 + +// QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID. +type QueryResultFormatsByOID map[uint32]int16 + +// Query executes sql with args. If there is an error the returned Rows will be returned in an error state. So it is +// allowed to ignore the error returned from Query and handle it in Rows. +// +// For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and +// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely +// needed. See the documentation for those types for details. +func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { + var resultFormats QueryResultFormats + var resultFormatsByOID QueryResultFormatsByOID + simpleProtocol := c.config.PreferSimpleProtocol + +optionLoop: + for len(args) > 0 { + switch arg := args[0].(type) { + case QueryResultFormats: + resultFormats = arg + args = args[1:] + case QueryResultFormatsByOID: + resultFormatsByOID = arg + args = args[1:] + case QuerySimpleProtocol: + simpleProtocol = bool(arg) + args = args[1:] + default: + break optionLoop + } + } + + rows := c.getRows(ctx, sql, args) + + var err error + sd, ok := c.preparedStatements[sql] + + if simpleProtocol && !ok { + sql, err = c.sanitizeForSimpleQuery(sql, args...) + if err != nil { + rows.fatal(err) + return rows, err + } + + mrr := c.pgConn.Exec(ctx, sql) + if mrr.NextResult() { + rows.resultReader = mrr.ResultReader() + rows.multiResultReader = mrr + } else { + err = mrr.Close() + rows.fatal(err) + return rows, err + } + + return rows, nil + } + + c.eqb.Reset() + + if !ok { + if c.stmtcache != nil { + sd, err = c.stmtcache.Get(ctx, sql) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + } else { + sd, err = c.pgConn.Prepare(ctx, "", sql, nil) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + } + } + if len(sd.ParamOIDs) != len(args) { + rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args))) + return rows, rows.err + } + + rows.sql = sd.SQL + + args, err = convertDriverValuers(args) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + + for i := range args { + err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i]) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + } + + if resultFormatsByOID != nil { + resultFormats = make([]int16, len(sd.Fields)) + for i := range resultFormats { + resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)] + } + } + + if resultFormats == nil { + for i := range sd.Fields { + c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) + } + + resultFormats = c.eqb.resultFormats + } + + if c.stmtcache != nil && c.stmtcache.Mode() == stmtcache.ModeDescribe { + rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats) + } else { + rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) + } + + return rows, rows.err +} + +// QueryRow is a convenience wrapper over Query. Any error that occurs while +// querying is deferred until calling Scan on the returned Row. That Row will +// error with ErrNoRows if no rows are returned. +func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { + rows, _ := c.Query(ctx, sql, args...) + return (*connRow)(rows.(*connRows)) +} + +// QueryFuncRow is the argument to the QueryFunc callback function. +// +// QueryFuncRow is an interface instead of a struct to allow tests to mock QueryFunc. However, adding a method to an +// interface is technically a breaking change. Because of this the QueryFuncRow interface is partially excluded from +// semantic version requirements. Methods will not be removed or changed, but new methods may be added. +type QueryFuncRow interface { + FieldDescriptions() []pgproto3.FieldDescription + + // RawValues returns the unparsed bytes of the row values. The returned [][]byte is only valid during the current + // function call. However, the underlying byte data is safe to retain a reference to and mutate. + RawValues() [][]byte +} + +// QueryFunc executes sql with args. For each row returned by the query the values will scanned into the elements of +// scans and f will be called. If any row fails to scan or f returns an error the query will be aborted and the error +// will be returned. +func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { + rows, err := c.Query(ctx, sql, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + err = rows.Scan(scans...) + if err != nil { + return nil, err + } + + err = f(rows) + if err != nil { + return nil, err + } + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return rows.CommandTag(), nil +} + +// SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless +// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection +// is used again. +func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { + simpleProtocol := c.config.PreferSimpleProtocol + var sb strings.Builder + if simpleProtocol { + for i, bi := range b.items { + if i > 0 { + sb.WriteByte(';') + } + sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + sb.WriteString(sql) + } + mrr := c.pgConn.Exec(ctx, sb.String()) + return &batchResults{ + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + ix: 0, + } + } + + distinctUnpreparedQueries := map[string]struct{}{} + + for _, bi := range b.items { + if _, ok := c.preparedStatements[bi.query]; ok { + continue + } + distinctUnpreparedQueries[bi.query] = struct{}{} + } + + var stmtCache stmtcache.Cache + if len(distinctUnpreparedQueries) > 0 { + if c.stmtcache != nil && c.stmtcache.Cap() >= len(distinctUnpreparedQueries) { + stmtCache = c.stmtcache + } else { + stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries)) + } + + for sql, _ := range distinctUnpreparedQueries { + _, err := stmtCache.Get(ctx, sql) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + } + } + + batch := &pgconn.Batch{} + + for _, bi := range b.items { + c.eqb.Reset() + + sd := c.preparedStatements[bi.query] + if sd == nil { + var err error + sd, err = stmtCache.Get(ctx, bi.query) + if err != nil { + // the stmtCache was prefilled from distinctUnpreparedQueries above so we are guaranteed no errors + panic("BUG: unexpected error from stmtCache") + } + } + + if len(sd.ParamOIDs) != len(bi.arguments) { + return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")} + } + + args, err := convertDriverValuers(bi.arguments) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + + for i := range args { + err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i]) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + } + + for i := range sd.Fields { + c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) + } + + if sd.Name == "" { + batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats) + } else { + batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats) + } + } + + mrr := c.pgConn.ExecBatch(ctx, batch) + + return &batchResults{ + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + ix: 0, + } +} + +func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, error) { + if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" { + return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on") + } + + if c.pgConn.ParameterStatus("client_encoding") != "UTF8" { + return "", errors.New("simple protocol queries must be run with client_encoding=UTF8") + } + + var err error + valueArgs := make([]interface{}, len(args)) + for i, a := range args { + valueArgs[i], err = convertSimpleArgument(c.connInfo, a) + if err != nil { + return "", err + } + } + + return sanitize.SanitizeSQL(sql, valueArgs...) +} |