summaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc/pgconn/pgconn.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/jackc/pgconn/pgconn.go')
-rw-r--r--vendor/github.com/jackc/pgconn/pgconn.go1724
1 files changed, 1724 insertions, 0 deletions
diff --git a/vendor/github.com/jackc/pgconn/pgconn.go b/vendor/github.com/jackc/pgconn/pgconn.go
new file mode 100644
index 000000000..43b13e43a
--- /dev/null
+++ b/vendor/github.com/jackc/pgconn/pgconn.go
@@ -0,0 +1,1724 @@
+package pgconn
+
+import (
+ "context"
+ "crypto/md5"
+ "crypto/tls"
+ "encoding/binary"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+ "math"
+ "net"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/jackc/pgconn/internal/ctxwatch"
+ "github.com/jackc/pgio"
+ "github.com/jackc/pgproto3/v2"
+)
+
+const (
+ connStatusUninitialized = iota
+ connStatusConnecting
+ connStatusClosed
+ connStatusIdle
+ connStatusBusy
+)
+
+const wbufLen = 1024
+
+// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from
+// LISTEN/NOTIFY notification.
+type Notice PgError
+
+// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system
+type Notification struct {
+ PID uint32 // backend pid that sent the notification
+ Channel string // channel from which notification was received
+ Payload string
+}
+
+// DialFunc is a function that can be used to connect to a PostgreSQL server.
+type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
+
+// LookupFunc is a function that can be used to lookup IPs addrs from host.
+type LookupFunc func(ctx context.Context, host string) (addrs []string, err error)
+
+// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
+type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend
+
+// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
+// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
+// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY
+// notification.
+type NoticeHandler func(*PgConn, *Notice)
+
+// NotificationHandler is a function that can handle notifications received from the PostgreSQL server. Notifications
+// can be received at any time, usually during handling of a query response. The *PgConn is provided so the handler is
+// aware of the origin of the notice, but it must not invoke any query method. Be aware that this is distinct from a
+// notice event.
+type NotificationHandler func(*PgConn, *Notification)
+
+// Frontend used to receive messages from backend.
+type Frontend interface {
+ Receive() (pgproto3.BackendMessage, error)
+}
+
+// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
+type PgConn struct {
+ conn net.Conn // the underlying TCP or unix domain socket connection
+ pid uint32 // backend pid
+ secretKey uint32 // key to use to send a cancel query message to the server
+ parameterStatuses map[string]string // parameters that have been reported by the server
+ txStatus byte
+ frontend Frontend
+
+ config *Config
+
+ status byte // One of connStatus* constants
+
+ bufferingReceive bool
+ bufferingReceiveMux sync.Mutex
+ bufferingReceiveMsg pgproto3.BackendMessage
+ bufferingReceiveErr error
+
+ peekedMsg pgproto3.BackendMessage
+
+ // Reusable / preallocated resources
+ wbuf []byte // write buffer
+ resultReader ResultReader
+ multiResultReader MultiResultReader
+ contextWatcher *ctxwatch.ContextWatcher
+
+ cleanupDone chan struct{}
+}
+
+// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
+// to provide configuration. See documention for ParseConfig for details. ctx can be used to cancel a connect attempt.
+func Connect(ctx context.Context, connString string) (*PgConn, error) {
+ config, err := ParseConfig(connString)
+ if err != nil {
+ return nil, err
+ }
+
+ return ConnectConfig(ctx, config)
+}
+
+// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with
+// ParseConfig. ctx can be used to cancel a connect attempt.
+//
+// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An
+// authentication error will terminate the chain of attempts (like libpq:
+// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise,
+// if all attempts fail the last error is returned.
+func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) {
+ // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
+ // zero values.
+ if !config.createdByParseConfig {
+ panic("config must be created by ParseConfig")
+ }
+
+ // ConnectTimeout restricts the whole connection process.
+ if config.ConnectTimeout != 0 {
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithTimeout(ctx, config.ConnectTimeout)
+ defer cancel()
+ }
+ // Simplify usage by treating primary config and fallbacks the same.
+ fallbackConfigs := []*FallbackConfig{
+ {
+ Host: config.Host,
+ Port: config.Port,
+ TLSConfig: config.TLSConfig,
+ },
+ }
+ fallbackConfigs = append(fallbackConfigs, config.Fallbacks...)
+
+ fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs)
+ if err != nil {
+ return nil, &connectError{config: config, msg: "hostname resolving error", err: err}
+ }
+
+ if len(fallbackConfigs) == 0 {
+ return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
+ }
+
+ for _, fc := range fallbackConfigs {
+ pgConn, err = connect(ctx, config, fc)
+ if err == nil {
+ break
+ } else if pgerr, ok := err.(*PgError); ok {
+ err = &connectError{config: config, msg: "server error", err: pgerr}
+ ERRCODE_INVALID_PASSWORD := "28P01" // worng password
+ ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION := "28000" // db does not exist
+ if pgerr.Code == ERRCODE_INVALID_PASSWORD || pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION {
+ break
+ }
+ }
+ }
+
+ if err != nil {
+ return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError
+ }
+
+ if config.AfterConnect != nil {
+ err := config.AfterConnect(ctx, pgConn)
+ if err != nil {
+ pgConn.conn.Close()
+ return nil, &connectError{config: config, msg: "AfterConnect error", err: err}
+ }
+ }
+
+ return pgConn, nil
+}
+
+func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) {
+ var configs []*FallbackConfig
+
+ for _, fb := range fallbacks {
+ // skip resolve for unix sockets
+ if strings.HasPrefix(fb.Host, "/") {
+ configs = append(configs, &FallbackConfig{
+ Host: fb.Host,
+ Port: fb.Port,
+ TLSConfig: fb.TLSConfig,
+ })
+
+ continue
+ }
+
+ ips, err := lookupFn(ctx, fb.Host)
+ if err != nil {
+ return nil, err
+ }
+
+ for _, ip := range ips {
+ configs = append(configs, &FallbackConfig{
+ Host: ip,
+ Port: fb.Port,
+ TLSConfig: fb.TLSConfig,
+ })
+ }
+ }
+
+ return configs, nil
+}
+
+func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
+ pgConn := new(PgConn)
+ pgConn.config = config
+ pgConn.wbuf = make([]byte, 0, wbufLen)
+ pgConn.cleanupDone = make(chan struct{})
+
+ var err error
+ network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
+ pgConn.conn, err = config.DialFunc(ctx, network, address)
+ if err != nil {
+ var netErr net.Error
+ if errors.As(err, &netErr) && netErr.Timeout() {
+ err = &errTimeout{err: err}
+ }
+ return nil, &connectError{config: config, msg: "dial error", err: err}
+ }
+
+ pgConn.parameterStatuses = make(map[string]string)
+
+ if fallbackConfig.TLSConfig != nil {
+ if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil {
+ pgConn.conn.Close()
+ return nil, &connectError{config: config, msg: "tls error", err: err}
+ }
+ }
+
+ pgConn.status = connStatusConnecting
+ pgConn.contextWatcher = ctxwatch.NewContextWatcher(
+ func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
+ func() { pgConn.conn.SetDeadline(time.Time{}) },
+ )
+
+ pgConn.contextWatcher.Watch(ctx)
+ defer pgConn.contextWatcher.Unwatch()
+
+ pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn)
+
+ startupMsg := pgproto3.StartupMessage{
+ ProtocolVersion: pgproto3.ProtocolVersionNumber,
+ Parameters: make(map[string]string),
+ }
+
+ // Copy default run-time params
+ for k, v := range config.RuntimeParams {
+ startupMsg.Parameters[k] = v
+ }
+
+ startupMsg.Parameters["user"] = config.User
+ if config.Database != "" {
+ startupMsg.Parameters["database"] = config.Database
+ }
+
+ if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil {
+ pgConn.conn.Close()
+ return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
+ }
+
+ for {
+ msg, err := pgConn.receiveMessage()
+ if err != nil {
+ pgConn.conn.Close()
+ if err, ok := err.(*PgError); ok {
+ return nil, err
+ }
+ return nil, &connectError{config: config, msg: "failed to receive message", err: preferContextOverNetTimeoutError(ctx, err)}
+ }
+
+ switch msg := msg.(type) {
+ case *pgproto3.BackendKeyData:
+ pgConn.pid = msg.ProcessID
+ pgConn.secretKey = msg.SecretKey
+
+ case *pgproto3.AuthenticationOk:
+ case *pgproto3.AuthenticationCleartextPassword:
+ err = pgConn.txPasswordMessage(pgConn.config.Password)
+ if err != nil {
+ pgConn.conn.Close()
+ return nil, &connectError{config: config, msg: "failed to write password message", err: err}
+ }
+ case *pgproto3.AuthenticationMD5Password:
+ digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
+ err = pgConn.txPasswordMessage(digestedPassword)
+ if err != nil {
+ pgConn.conn.Close()
+ return nil, &connectError{config: config, msg: "failed to write password message", err: err}
+ }
+ case *pgproto3.AuthenticationSASL:
+ err = pgConn.scramAuth(msg.AuthMechanisms)
+ if err != nil {
+ pgConn.conn.Close()
+ return nil, &connectError{config: config, msg: "failed SASL auth", err: err}
+ }
+
+ case *pgproto3.ReadyForQuery:
+ pgConn.status = connStatusIdle
+ if config.ValidateConnect != nil {
+ // ValidateConnect may execute commands that cause the context to be watched again. Unwatch first to avoid
+ // the watch already in progress panic. This is that last thing done by this method so there is no need to
+ // restart the watch after ValidateConnect returns.
+ //
+ // See https://github.com/jackc/pgconn/issues/40.
+ pgConn.contextWatcher.Unwatch()
+
+ err := config.ValidateConnect(ctx, pgConn)
+ if err != nil {
+ pgConn.conn.Close()
+ return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
+ }
+ }
+ return pgConn, nil
+ case *pgproto3.ParameterStatus:
+ // handled by ReceiveMessage
+ case *pgproto3.ErrorResponse:
+ pgConn.conn.Close()
+ return nil, ErrorResponseToPgError(msg)
+ default:
+ pgConn.conn.Close()
+ return nil, &connectError{config: config, msg: "received unexpected message", err: err}
+ }
+ }
+}
+
+func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
+ err = binary.Write(pgConn.conn, binary.BigEndian, []int32{8, 80877103})
+ if err != nil {
+ return
+ }
+
+ response := make([]byte, 1)
+ if _, err = io.ReadFull(pgConn.conn, response); err != nil {
+ return
+ }
+
+ if response[0] != 'S' {
+ return errors.New("server refused TLS connection")
+ }
+
+ pgConn.conn = tls.Client(pgConn.conn, tlsConfig)
+
+ return nil
+}
+
+func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
+ msg := &pgproto3.PasswordMessage{Password: password}
+ _, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf))
+ return err
+}
+
+func hexMD5(s string) string {
+ hash := md5.New()
+ io.WriteString(hash, s)
+ return hex.EncodeToString(hash.Sum(nil))
+}
+
+func (pgConn *PgConn) signalMessage() chan struct{} {
+ if pgConn.bufferingReceive {
+ panic("BUG: signalMessage when already in progress")
+ }
+
+ pgConn.bufferingReceive = true
+ pgConn.bufferingReceiveMux.Lock()
+
+ ch := make(chan struct{})
+ go func() {
+ pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
+ pgConn.bufferingReceiveMux.Unlock()
+ close(ch)
+ }()
+
+ return ch
+}
+
+// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as
+// error to call SendBytes while reading the result of a query.
+//
+// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
+// See https://www.postgresql.org/docs/current/protocol.html.
+func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
+ if err := pgConn.lock(); err != nil {
+ return err
+ }
+ defer pgConn.unlock()
+
+ if ctx != context.Background() {
+ select {
+ case <-ctx.Done():
+ return newContextAlreadyDoneError(ctx)
+ default:
+ }
+ pgConn.contextWatcher.Watch(ctx)
+ defer pgConn.contextWatcher.Unwatch()
+ }
+
+ n, err := pgConn.conn.Write(buf)
+ if err != nil {
+ pgConn.asyncClose()
+ return &writeError{err: err, safeToRetry: n == 0}
+ }
+
+ return nil
+}
+
+// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
+// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
+// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
+// the OnNotification callback.
+//
+// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
+// See https://www.postgresql.org/docs/current/protocol.html.
+func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) {
+ if err := pgConn.lock(); err != nil {
+ return nil, err
+ }
+ defer pgConn.unlock()
+
+ if ctx != context.Background() {
+ select {
+ case <-ctx.Done():
+ return nil, newContextAlreadyDoneError(ctx)
+ default:
+ }
+ pgConn.contextWatcher.Watch(ctx)
+ defer pgConn.contextWatcher.Unwatch()
+ }
+
+ msg, err := pgConn.receiveMessage()
+ if err != nil {
+ err = &pgconnError{
+ msg: "receive message failed",
+ err: preferContextOverNetTimeoutError(ctx, err),
+ safeToRetry: true}
+ }
+ return msg, err
+}
+
+// peekMessage peeks at the next message without setting up context cancellation.
+func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
+ if pgConn.peekedMsg != nil {
+ return pgConn.peekedMsg, nil
+ }
+
+ var msg pgproto3.BackendMessage
+ var err error
+ if pgConn.bufferingReceive {
+ pgConn.bufferingReceiveMux.Lock()
+ msg = pgConn.bufferingReceiveMsg
+ err = pgConn.bufferingReceiveErr
+ pgConn.bufferingReceiveMux.Unlock()
+ pgConn.bufferingReceive = false
+
+ // If a timeout error happened in the background try the read again.
+ var netErr net.Error
+ if errors.As(err, &netErr) && netErr.Timeout() {
+ msg, err = pgConn.frontend.Receive()
+ }
+ } else {
+ msg, err = pgConn.frontend.Receive()
+ }
+
+ if err != nil {
+ // Close on anything other than timeout error - everything else is fatal
+ var netErr net.Error
+ isNetErr := errors.As(err, &netErr)
+ if !(isNetErr && netErr.Timeout()) {
+ pgConn.asyncClose()
+ }
+
+ return nil, err
+ }
+
+ pgConn.peekedMsg = msg
+ return msg, nil
+}
+
+// receiveMessage receives a message without setting up context cancellation
+func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
+ msg, err := pgConn.peekMessage()
+ if err != nil {
+ // Close on anything other than timeout error - everything else is fatal
+ var netErr net.Error
+ isNetErr := errors.As(err, &netErr)
+ if !(isNetErr && netErr.Timeout()) {
+ pgConn.asyncClose()
+ }
+
+ return nil, err
+ }
+ pgConn.peekedMsg = nil
+
+ switch msg := msg.(type) {
+ case *pgproto3.ReadyForQuery:
+ pgConn.txStatus = msg.TxStatus
+ case *pgproto3.ParameterStatus:
+ pgConn.parameterStatuses[msg.Name] = msg.Value
+ case *pgproto3.ErrorResponse:
+ if msg.Severity == "FATAL" {
+ pgConn.status = connStatusClosed
+ pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return.
+ close(pgConn.cleanupDone)
+ return nil, ErrorResponseToPgError(msg)
+ }
+ case *pgproto3.NoticeResponse:
+ if pgConn.config.OnNotice != nil {
+ pgConn.config.OnNotice(pgConn, noticeResponseToNotice(msg))
+ }
+ case *pgproto3.NotificationResponse:
+ if pgConn.config.OnNotification != nil {
+ pgConn.config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload})
+ }
+ }
+
+ return msg, nil
+}
+
+// Conn returns the underlying net.Conn.
+func (pgConn *PgConn) Conn() net.Conn {
+ return pgConn.conn
+}
+
+// PID returns the backend PID.
+func (pgConn *PgConn) PID() uint32 {
+ return pgConn.pid
+}
+
+// TxStatus returns the current TxStatus as reported by the server in the ReadyForQuery message.
+//
+// Possible return values:
+// 'I' - idle / not in transaction
+// 'T' - in a transaction
+// 'E' - in a failed transaction
+//
+// See https://www.postgresql.org/docs/current/protocol-message-formats.html.
+func (pgConn *PgConn) TxStatus() byte {
+ return pgConn.txStatus
+}
+
+// SecretKey returns the backend secret key used to send a cancel query message to the server.
+func (pgConn *PgConn) SecretKey() uint32 {
+ return pgConn.secretKey
+}
+
+// Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by
+// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The
+// underlying net.Conn.Close() will always be called regardless of any other errors.
+func (pgConn *PgConn) Close(ctx context.Context) error {
+ if pgConn.status == connStatusClosed {
+ return nil
+ }
+ pgConn.status = connStatusClosed
+
+ defer close(pgConn.cleanupDone)
+ defer pgConn.conn.Close()
+
+ if ctx != context.Background() {
+ // Close may be called while a cancellable query is in progress. This will most often be triggered by panic when
+ // a defer closes the connection (possibly indirectly via a transaction or a connection pool). Unwatch to end any
+ // previous watch. It is safe to Unwatch regardless of whether a watch is already is progress.
+ //
+ // See https://github.com/jackc/pgconn/issues/29
+ pgConn.contextWatcher.Unwatch()
+
+ pgConn.contextWatcher.Watch(ctx)
+ defer pgConn.contextWatcher.Unwatch()
+ }
+
+ // Ignore any errors sending Terminate message and waiting for server to close connection.
+ // This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully
+ // ignores errors.
+ //
+ // See https://github.com/jackc/pgx/issues/637
+ pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
+ pgConn.conn.Read(make([]byte, 1))
+
+ return pgConn.conn.Close()
+}
+
+// asyncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying
+// connection.
+func (pgConn *PgConn) asyncClose() {
+ if pgConn.status == connStatusClosed {
+ return
+ }
+ pgConn.status = connStatusClosed
+
+ go func() {
+ defer close(pgConn.cleanupDone)
+ defer pgConn.conn.Close()
+
+ deadline := time.Now().Add(time.Second * 15)
+
+ ctx, cancel := context.WithDeadline(context.Background(), deadline)
+ defer cancel()
+
+ pgConn.CancelRequest(ctx)
+
+ pgConn.conn.SetDeadline(deadline)
+
+ pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
+ pgConn.conn.Read(make([]byte, 1))
+ }()
+}
+
+// CleanupDone returns a channel that will be closed after all underlying resources have been cleaned up. A closed
+// connection is no longer usable, but underlying resources, in particular the net.Conn, may not have finished closing
+// yet. This is because certain errors such as a context cancellation require that the interrupted function call return
+// immediately, but the error may also cause the connection to be closed. In these cases the underlying resources are
+// closed asynchronously.
+//
+// This is only likely to be useful to connection pools. It gives them a way avoid establishing a new connection while
+// an old connection is still being cleaned up and thereby exceeding the maximum pool size.
+func (pgConn *PgConn) CleanupDone() chan (struct{}) {
+ return pgConn.cleanupDone
+}
+
+// IsClosed reports if the connection has been closed.
+//
+// CleanupDone() can be used to determine if all cleanup has been completed.
+func (pgConn *PgConn) IsClosed() bool {
+ return pgConn.status < connStatusIdle
+}
+
+// IsBusy reports if the connection is busy.
+func (pgConn *PgConn) IsBusy() bool {
+ return pgConn.status == connStatusBusy
+}
+
+// lock locks the connection.
+func (pgConn *PgConn) lock() error {
+ switch pgConn.status {
+ case connStatusBusy:
+ return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug.
+ case connStatusClosed:
+ return &connLockError{status: "conn closed"}
+ case connStatusUninitialized:
+ return &connLockError{status: "conn uninitialized"}
+ }
+ pgConn.status = connStatusBusy
+ return nil
+}
+
+func (pgConn *PgConn) unlock() {
+ switch pgConn.status {
+ case connStatusBusy:
+ pgConn.status = connStatusIdle
+ case connStatusClosed:
+ default:
+ panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package.
+ }
+}
+
+// ParameterStatus returns the value of a parameter reported by the server (e.g.
+// server_version). Returns an empty string for unknown parameters.
+func (pgConn *PgConn) ParameterStatus(key string) string {
+ return pgConn.parameterStatuses[key]
+}
+
+// CommandTag is the result of an Exec function
+type CommandTag []byte
+
+// RowsAffected returns the number of rows affected. If the CommandTag was not
+// for a row affecting command (e.g. "CREATE TABLE") then it returns 0.
+func (ct CommandTag) RowsAffected() int64 {
+ // Find last non-digit
+ idx := -1
+ for i := len(ct) - 1; i >= 0; i-- {
+ if ct[i] >= '0' && ct[i] <= '9' {
+ idx = i
+ } else {
+ break
+ }
+ }
+
+ if idx == -1 {
+ return 0
+ }
+
+ var n int64
+ for _, b := range ct[idx:] {
+ n = n*10 + int64(b-'0')
+ }
+
+ return n
+}
+
+func (ct CommandTag) String() string {
+ return string(ct)
+}
+
+// Insert is true if the command tag starts with "INSERT".
+func (ct CommandTag) Insert() bool {
+ return len(ct) >= 6 &&
+ ct[0] == 'I' &&
+ ct[1] == 'N' &&
+ ct[2] == 'S' &&
+ ct[3] == 'E' &&
+ ct[4] == 'R' &&
+ ct[5] == 'T'
+}
+
+// Update is true if the command tag starts with "UPDATE".
+func (ct CommandTag) Update() bool {
+ return len(ct) >= 6 &&
+ ct[0] == 'U' &&
+ ct[1] == 'P' &&
+ ct[2] == 'D' &&
+ ct[3] == 'A' &&
+ ct[4] == 'T' &&
+ ct[5] == 'E'
+}
+
+// Delete is true if the command tag starts with "DELETE".
+func (ct CommandTag) Delete() bool {
+ return len(ct) >= 6 &&
+ ct[0] == 'D' &&
+ ct[1] == 'E' &&
+ ct[2] == 'L' &&
+ ct[3] == 'E' &&
+ ct[4] == 'T' &&
+ ct[5] == 'E'
+}
+
+// Select is true if the command tag starts with "SELECT".
+func (ct CommandTag) Select() bool {
+ return len(ct) >= 6 &&
+ ct[0] == 'S' &&
+ ct[1] == 'E' &&
+ ct[2] == 'L' &&
+ ct[3] == 'E' &&
+ ct[4] == 'C' &&
+ ct[5] == 'T'
+}
+
+type StatementDescription struct {
+ Name string
+ SQL string
+ ParamOIDs []uint32
+ Fields []pgproto3.FieldDescription
+}
+
+// Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This
+// allows Prepare to also to describe statements without creating a server-side prepared statement.
+func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) {
+ if err := pgConn.lock(); err != nil {
+ return nil, err
+ }
+ defer pgConn.unlock()
+
+ if ctx != context.Background() {
+ select {
+ case <-ctx.Done():
+ return nil, newContextAlreadyDoneError(ctx)
+ default:
+ }
+ pgConn.contextWatcher.Watch(ctx)
+ defer pgConn.contextWatcher.Unwatch()
+ }
+
+ buf := pgConn.wbuf
+ buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
+ buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf)
+ buf = (&pgproto3.Sync{}).Encode(buf)
+
+ n, err := pgConn.conn.Write(buf)
+ if err != nil {
+ pgConn.asyncClose()
+ return nil, &writeError{err: err, safeToRetry: n == 0}
+ }
+
+ psd := &StatementDescription{Name: name, SQL: sql}
+
+ var parseErr error
+
+readloop:
+ for {
+ msg, err := pgConn.receiveMessage()
+ if err != nil {
+ pgConn.asyncClose()
+ return nil, preferContextOverNetTimeoutError(ctx, err)
+ }
+
+ switch msg := msg.(type) {
+ case *pgproto3.ParameterDescription:
+ psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs))
+ copy(psd.ParamOIDs, msg.ParameterOIDs)
+ case *pgproto3.RowDescription:
+ psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields))
+ copy(psd.Fields, msg.Fields)
+ case *pgproto3.ErrorResponse:
+ parseErr = ErrorResponseToPgError(msg)
+ case *pgproto3.ReadyForQuery:
+ break readloop
+ }
+ }
+
+ if parseErr != nil {
+ return nil, parseErr
+ }
+ return psd, nil
+}
+
+// ErrorResponseToPgError converts a wire protocol error message to a *PgError.
+func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
+ return &PgError{
+ Severity: msg.Severity,
+ Code: string(msg.Code),
+ Message: string(msg.Message),
+ Detail: string(msg.Detail),
+ Hint: msg.Hint,
+ Position: msg.Position,
+ InternalPosition: msg.InternalPosition,
+ InternalQuery: string(msg.InternalQuery),
+ Where: string(msg.Where),
+ SchemaName: string(msg.SchemaName),
+ TableName: string(msg.TableName),
+ ColumnName: string(msg.ColumnName),
+ DataTypeName: string(msg.DataTypeName),
+ ConstraintName: msg.ConstraintName,
+ File: string(msg.File),
+ Line: msg.Line,
+ Routine: string(msg.Routine),
+ }
+}
+
+func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice {
+ pgerr := ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg))
+ return (*Notice)(pgerr)
+}
+
+// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel
+// request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there
+// is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9
+func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
+ // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing
+ // the connection config. This is important in high availability configurations where fallback connections may be
+ // specified or DNS may be used to load balance.
+ serverAddr := pgConn.conn.RemoteAddr()
+ cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String())
+ if err != nil {
+ return err
+ }
+ defer cancelConn.Close()
+
+ if ctx != context.Background() {
+ contextWatcher := ctxwatch.NewContextWatcher(
+ func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
+ func() { cancelConn.SetDeadline(time.Time{}) },
+ )
+ contextWatcher.Watch(ctx)
+ defer contextWatcher.Unwatch()
+ }
+
+ buf := make([]byte, 16)
+ binary.BigEndian.PutUint32(buf[0:4], 16)
+ binary.BigEndian.PutUint32(buf[4:8], 80877102)
+ binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid))
+ binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey))
+ _, err = cancelConn.Write(buf)
+ if err != nil {
+ return err
+ }
+
+ _, err = cancelConn.Read(buf)
+ if err != io.EOF {
+ return err
+ }
+
+ return nil
+}
+
+// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not
+// received.
+func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
+ if err := pgConn.lock(); err != nil {
+ return err
+ }
+ defer pgConn.unlock()
+
+ if ctx != context.Background() {
+ select {
+ case <-ctx.Done():
+ return newContextAlreadyDoneError(ctx)
+ default:
+ }
+
+ pgConn.contextWatcher.Watch(ctx)
+ defer pgConn.contextWatcher.Unwatch()
+ }
+
+ for {
+ msg, err := pgConn.receiveMessage()
+ if err != nil {
+ return preferContextOverNetTimeoutError(ctx, err)
+ }
+
+ switch msg.(type) {
+ case *pgproto3.NotificationResponse:
+ return nil
+ }
+ }
+}
+
+// Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is
+// implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control
+// statements.
+//
+// Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries.
+func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
+ if err := pgConn.lock(); err != nil {
+ return &MultiResultReader{
+ closed: true,
+ err: err,
+ }
+ }
+
+ pgConn.multiResultReader = MultiResultReader{
+ pgConn: pgConn,
+ ctx: ctx,
+ }
+ multiResult := &pgConn.multiResultReader
+ if ctx != context.Background() {
+ select {
+ case <-ctx.Done():
+ multiResult.closed = true
+ multiResult.err = newContextAlreadyDoneError(ctx)
+ pgConn.unlock()
+ return multiResult
+ default:
+ }
+ pgConn.contextWatcher.Watch(ctx)
+ }
+
+ buf := pgConn.wbuf
+ buf = (&pgproto3.Query{String: sql}).Encode(buf)
+
+ n, err := pgConn.conn.Write(buf)
+ if err != nil {
+ pgConn.asyncClose()
+ pgConn.contextWatcher.Unwatch()
+ multiResult.closed = true
+ multiResult.err = &writeError{err: err, safeToRetry: n == 0}
+ pgConn.unlock()
+ return multiResult
+ }
+
+ return multiResult
+}
+
+// ReceiveResults reads the result that might be returned by Postgres after a SendBytes
+// (e.a. after sending a CopyDone in a copy-both situation).
+//
+// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
+// See https://www.postgresql.org/docs/current/protocol.html.
+func (pgConn *PgConn) ReceiveResults(ctx context.Context) *MultiResultReader {
+ if err := pgConn.lock(); err != nil {
+ return &MultiResultReader{
+ closed: true,
+ err: err,
+ }
+ }
+
+ pgConn.multiResultReader = MultiResultReader{
+ pgConn: pgConn,
+ ctx: ctx,
+ }
+ multiResult := &pgConn.multiResultReader
+ if ctx != context.Background() {
+ select {
+ case <-ctx.Done():
+ multiResult.closed = true
+ multiResult.err = newContextAlreadyDoneError(ctx)
+ pgConn.unlock()
+ return multiResult
+ default:
+ }
+ pgConn.contextWatcher.Watch(ctx)
+ }
+
+ return multiResult
+}
+
+// ExecParams executes a command via the PostgreSQL extended query protocol.
+//
+// sql is a SQL command string. It may only contain one query. Parameter substitution is positional using $1, $2, $3,
+// etc.
+//
+// paramValues are the parameter values. It must be encoded in the format given by paramFormats.
+//
+// paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for
+// all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter.
+// ExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues).
+//
+// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or
+// binary format. If paramFormats is nil all params are text format. ExecParams will panic if
+// len(paramFormats) is not 0, 1, or len(paramValues).
+//
+// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or
+// binary format. If resultFormats is nil all results will be in text format.
+//
+// ResultReader must be closed before PgConn can be used again.
+func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader {
+ result := pgConn.execExtendedPrefix(ctx, paramValues)
+ if result.closed {
+ return result
+ }
+
+ buf := pgConn.wbuf
+ buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
+ buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
+
+ pgConn.execExtendedSuffix(buf, result)
+
+ return result
+}
+
+// ExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol.
+//
+// paramValues are the parameter values. It must be encoded in the format given by paramFormats.
+//
+// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or
+// binary format. If paramFormats is nil all params are text format. ExecPrepared will panic if
+// len(paramFormats) is not 0, 1, or len(paramValues).
+//
+// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or
+// binary format. If resultFormats is nil all results will be in text format.
+//
+// ResultReader must be closed before PgConn can be used again.
+func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader {
+ result := pgConn.execExtendedPrefix(ctx, paramValues)
+ if result.closed {
+ return result
+ }
+
+ buf := pgConn.wbuf
+ buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
+
+ pgConn.execExtendedSuffix(buf, result)
+
+ return result
+}
+
+func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader {
+ pgConn.resultReader = ResultReader{
+ pgConn: pgConn,
+ ctx: ctx,
+ }
+ result := &pgConn.resultReader
+
+ if err := pgConn.lock(); err != nil {
+ result.concludeCommand(nil, err)
+ result.closed = true
+ return result
+ }
+
+ if len(paramValues) > math.MaxUint16 {
+ result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16))
+ result.closed = true
+ pgConn.unlock()
+ return result
+ }
+
+ if ctx != context.Background() {
+ select {
+ case <-ctx.Done():
+ result.concludeCommand(nil, newContextAlreadyDoneError(ctx))
+ result.closed = true
+ pgConn.unlock()
+ return result
+ default:
+ }
+ pgConn.contextWatcher.Watch(ctx)
+ }
+
+ return result
+}
+
+func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
+ buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf)
+ buf = (&pgproto3.Execute{}).Encode(buf)
+ buf = (&pgproto3.Sync{}).Encode(buf)
+
+ n, err := pgConn.conn.Write(buf)
+ if err != nil {
+ pgConn.asyncClose()
+ result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0})
+ pgConn.contextWatcher.Unwatch()
+ result.closed = true
+ pgConn.unlock()
+ return
+ }
+
+ result.readUntilRowDescription()
+}
+
+// CopyTo executes the copy command sql and copies the results to w.
+func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
+ if err := pgConn.lock(); err != nil {
+ return nil, err
+ }
+
+ if ctx != context.Background() {
+ select {
+ case <-ctx.Done():
+ pgConn.unlock()
+ return nil, newContextAlreadyDoneError(ctx)
+ default:
+ }
+ pgConn.contextWatcher.Watch(ctx)
+ defer pgConn.contextWatcher.Unwatch()
+ }
+
+ // Send copy to command
+ buf := pgConn.wbuf
+ buf = (&pgproto3.Query{String: sql}).Encode(buf)
+
+ n, err := pgConn.conn.Write(buf)
+ if err != nil {
+ pgConn.asyncClose()
+ pgConn.unlock()
+ return nil, &writeError{err: err, safeToRetry: n == 0}
+ }
+
+ // Read results
+ var commandTag CommandTag
+ var pgErr error
+ for {
+ msg, err := pgConn.receiveMessage()
+ if err != nil {
+ pgConn.asyncClose()
+ return nil, preferContextOverNetTimeoutError(ctx, err)
+ }
+
+ switch msg := msg.(type) {
+ case *pgproto3.CopyDone:
+ case *pgproto3.CopyData:
+ _, err := w.Write(msg.Data)
+ if err != nil {
+ pgConn.asyncClose()
+ return nil, err
+ }
+ case *pgproto3.ReadyForQuery:
+ pgConn.unlock()
+ return commandTag, pgErr
+ case *pgproto3.CommandComplete:
+ commandTag = CommandTag(msg.CommandTag)
+ case *pgproto3.ErrorResponse:
+ pgErr = ErrorResponseToPgError(msg)
+ }
+ }
+}
+
+// CopyFrom executes the copy command sql and copies all of r to the PostgreSQL server.
+//
+// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r
+// could still block.
+func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
+ if err := pgConn.lock(); err != nil {
+ return nil, err
+ }
+ defer pgConn.unlock()
+
+ if ctx != context.Background() {
+ select {
+ case <-ctx.Done():
+ return nil, newContextAlreadyDoneError(ctx)
+ default:
+ }
+ pgConn.contextWatcher.Watch(ctx)
+ defer pgConn.contextWatcher.Unwatch()
+ }
+
+ // Send copy to command
+ buf := pgConn.wbuf
+ buf = (&pgproto3.Query{String: sql}).Encode(buf)
+
+ n, err := pgConn.conn.Write(buf)
+ if err != nil {
+ pgConn.asyncClose()
+ return nil, &writeError{err: err, safeToRetry: n == 0}
+ }
+
+ // Read until copy in response or error.
+ var commandTag CommandTag
+ var pgErr error
+ pendingCopyInResponse := true
+ for pendingCopyInResponse {
+ msg, err := pgConn.receiveMessage()
+ if err != nil {
+ pgConn.asyncClose()
+ return nil, preferContextOverNetTimeoutError(ctx, err)
+ }
+
+ switch msg := msg.(type) {
+ case *pgproto3.CopyInResponse:
+ pendingCopyInResponse = false
+ case *pgproto3.ErrorResponse:
+ pgErr = ErrorResponseToPgError(msg)
+ case *pgproto3.ReadyForQuery:
+ return commandTag, pgErr
+ }
+ }
+
+ // Send copy data
+ abortCopyChan := make(chan struct{})
+ copyErrChan := make(chan error, 1)
+ signalMessageChan := pgConn.signalMessage()
+
+ go func() {
+ buf := make([]byte, 0, 65536)
+ buf = append(buf, 'd')
+ sp := len(buf)
+
+ for {
+ n, readErr := r.Read(buf[5:cap(buf)])
+ if n > 0 {
+ buf = buf[0 : n+5]
+ pgio.SetInt32(buf[sp:], int32(n+4))
+
+ _, writeErr := pgConn.conn.Write(buf)
+ if writeErr != nil {
+ // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine.
+ pgConn.conn.Close()
+
+ copyErrChan <- writeErr
+ return
+ }
+ }
+ if readErr != nil {
+ copyErrChan <- readErr
+ return
+ }
+
+ select {
+ case <-abortCopyChan:
+ return
+ default:
+ }
+ }
+ }()
+
+ var copyErr error
+ for copyErr == nil && pgErr == nil {
+ select {
+ case copyErr = <-copyErrChan:
+ case <-signalMessageChan:
+ msg, err := pgConn.receiveMessage()
+ if err != nil {
+ pgConn.asyncClose()
+ return nil, preferContextOverNetTimeoutError(ctx, err)
+ }
+
+ switch msg := msg.(type) {
+ case *pgproto3.ErrorResponse:
+ pgErr = ErrorResponseToPgError(msg)
+ default:
+ signalMessageChan = pgConn.signalMessage()
+ }
+ }
+ }
+ close(abortCopyChan)
+
+ buf = buf[:0]
+ if copyErr == io.EOF || pgErr != nil {
+ copyDone := &pgproto3.CopyDone{}
+ buf = copyDone.Encode(buf)
+ } else {
+ copyFail := &pgproto3.CopyFail{Message: copyErr.Error()}
+ buf = copyFail.Encode(buf)
+ }
+ _, err = pgConn.conn.Write(buf)
+ if err != nil {
+ pgConn.asyncClose()
+ return nil, err
+ }
+
+ // Read results
+ for {
+ msg, err := pgConn.receiveMessage()
+ if err != nil {
+ pgConn.asyncClose()
+ return nil, preferContextOverNetTimeoutError(ctx, err)
+ }
+
+ switch msg := msg.(type) {
+ case *pgproto3.ReadyForQuery:
+ return commandTag, pgErr
+ case *pgproto3.CommandComplete:
+ commandTag = CommandTag(msg.CommandTag)
+ case *pgproto3.ErrorResponse:
+ pgErr = ErrorResponseToPgError(msg)
+ }
+ }
+}
+
+// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
+type MultiResultReader struct {
+ pgConn *PgConn
+ ctx context.Context
+
+ rr *ResultReader
+
+ closed bool
+ err error
+}
+
+// ReadAll reads all available results. Calling ReadAll is mutually exclusive with all other MultiResultReader methods.
+func (mrr *MultiResultReader) ReadAll() ([]*Result, error) {
+ var results []*Result
+
+ for mrr.NextResult() {
+ results = append(results, mrr.ResultReader().Read())
+ }
+ err := mrr.Close()
+
+ return results, err
+}
+
+func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) {
+ msg, err := mrr.pgConn.receiveMessage()
+
+ if err != nil {
+ mrr.pgConn.contextWatcher.Unwatch()
+ mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err)
+ mrr.closed = true
+ mrr.pgConn.asyncClose()
+ return nil, mrr.err
+ }
+
+ switch msg := msg.(type) {
+ case *pgproto3.ReadyForQuery:
+ mrr.pgConn.contextWatcher.Unwatch()
+ mrr.closed = true
+ mrr.pgConn.unlock()
+ case *pgproto3.ErrorResponse:
+ mrr.err = ErrorResponseToPgError(msg)
+ }
+
+ return msg, nil
+}
+
+// NextResult returns advances the MultiResultReader to the next result and returns true if a result is available.
+func (mrr *MultiResultReader) NextResult() bool {
+ for !mrr.closed && mrr.err == nil {
+ msg, err := mrr.receiveMessage()
+ if err != nil {
+ return false
+ }
+
+ switch msg := msg.(type) {
+ case *pgproto3.RowDescription:
+ mrr.pgConn.resultReader = ResultReader{
+ pgConn: mrr.pgConn,
+ multiResultReader: mrr,
+ ctx: mrr.ctx,
+ fieldDescriptions: msg.Fields,
+ }
+ mrr.rr = &mrr.pgConn.resultReader
+ return true
+ case *pgproto3.CommandComplete:
+ mrr.pgConn.resultReader = ResultReader{
+ commandTag: CommandTag(msg.CommandTag),
+ commandConcluded: true,
+ closed: true,
+ }
+ mrr.rr = &mrr.pgConn.resultReader
+ return true
+ case *pgproto3.EmptyQueryResponse:
+ return false
+ }
+ }
+
+ return false
+}
+
+// ResultReader returns the current ResultReader.
+func (mrr *MultiResultReader) ResultReader() *ResultReader {
+ return mrr.rr
+}
+
+// Close closes the MultiResultReader and returns the first error that occurred during the MultiResultReader's use.
+func (mrr *MultiResultReader) Close() error {
+ for !mrr.closed {
+ _, err := mrr.receiveMessage()
+ if err != nil {
+ return mrr.err
+ }
+ }
+
+ return mrr.err
+}
+
+// ResultReader is a reader for the result of a single query.
+type ResultReader struct {
+ pgConn *PgConn
+ multiResultReader *MultiResultReader
+ ctx context.Context
+
+ fieldDescriptions []pgproto3.FieldDescription
+ rowValues [][]byte
+ commandTag CommandTag
+ commandConcluded bool
+ closed bool
+ err error
+}
+
+// Result is the saved query response that is returned by calling Read on a ResultReader.
+type Result struct {
+ FieldDescriptions []pgproto3.FieldDescription
+ Rows [][][]byte
+ CommandTag CommandTag
+ Err error
+}
+
+// Read saves the query response to a Result.
+func (rr *ResultReader) Read() *Result {
+ br := &Result{}
+
+ for rr.NextRow() {
+ if br.FieldDescriptions == nil {
+ br.FieldDescriptions = make([]pgproto3.FieldDescription, len(rr.FieldDescriptions()))
+ copy(br.FieldDescriptions, rr.FieldDescriptions())
+ }
+
+ row := make([][]byte, len(rr.Values()))
+ copy(row, rr.Values())
+ br.Rows = append(br.Rows, row)
+ }
+
+ br.CommandTag, br.Err = rr.Close()
+
+ return br
+}
+
+// NextRow advances the ResultReader to the next row and returns true if a row is available.
+func (rr *ResultReader) NextRow() bool {
+ for !rr.commandConcluded {
+ msg, err := rr.receiveMessage()
+ if err != nil {
+ return false
+ }
+
+ switch msg := msg.(type) {
+ case *pgproto3.DataRow:
+ rr.rowValues = msg.Values
+ return true
+ }
+ }
+
+ return false
+}
+
+// FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until
+// the ResultReader is closed.
+func (rr *ResultReader) FieldDescriptions() []pgproto3.FieldDescription {
+ return rr.fieldDescriptions
+}
+
+// Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only
+// valid until the next NextRow call or the ResultReader is closed. However, the underlying byte data is safe to
+// retain a reference to and mutate.
+func (rr *ResultReader) Values() [][]byte {
+ return rr.rowValues
+}
+
+// Close consumes any remaining result data and returns the command tag or
+// error.
+func (rr *ResultReader) Close() (CommandTag, error) {
+ if rr.closed {
+ return rr.commandTag, rr.err
+ }
+ rr.closed = true
+
+ for !rr.commandConcluded {
+ _, err := rr.receiveMessage()
+ if err != nil {
+ return nil, rr.err
+ }
+ }
+
+ if rr.multiResultReader == nil {
+ for {
+ msg, err := rr.receiveMessage()
+ if err != nil {
+ return nil, rr.err
+ }
+
+ switch msg := msg.(type) {
+ // Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete.
+ case *pgproto3.ErrorResponse:
+ rr.err = ErrorResponseToPgError(msg)
+ case *pgproto3.ReadyForQuery:
+ rr.pgConn.contextWatcher.Unwatch()
+ rr.pgConn.unlock()
+ return rr.commandTag, rr.err
+ }
+ }
+ }
+
+ return rr.commandTag, rr.err
+}
+
+// readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any
+// error will be stored in the ResultReader.
+func (rr *ResultReader) readUntilRowDescription() {
+ for !rr.commandConcluded {
+ // Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method.
+ // This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are
+ // manually used to construct a query that does not issue a describe statement.
+ msg, _ := rr.pgConn.peekMessage()
+ if _, ok := msg.(*pgproto3.DataRow); ok {
+ return
+ }
+
+ // Consume the message
+ msg, _ = rr.receiveMessage()
+ if _, ok := msg.(*pgproto3.RowDescription); ok {
+ return
+ }
+ }
+}
+
+func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) {
+ if rr.multiResultReader == nil {
+ msg, err = rr.pgConn.receiveMessage()
+ } else {
+ msg, err = rr.multiResultReader.receiveMessage()
+ }
+
+ if err != nil {
+ err = preferContextOverNetTimeoutError(rr.ctx, err)
+ rr.concludeCommand(nil, err)
+ rr.pgConn.contextWatcher.Unwatch()
+ rr.closed = true
+ if rr.multiResultReader == nil {
+ rr.pgConn.asyncClose()
+ }
+
+ return nil, rr.err
+ }
+
+ switch msg := msg.(type) {
+ case *pgproto3.RowDescription:
+ rr.fieldDescriptions = msg.Fields
+ case *pgproto3.CommandComplete:
+ rr.concludeCommand(CommandTag(msg.CommandTag), nil)
+ case *pgproto3.EmptyQueryResponse:
+ rr.concludeCommand(nil, nil)
+ case *pgproto3.ErrorResponse:
+ rr.concludeCommand(nil, ErrorResponseToPgError(msg))
+ }
+
+ return msg, nil
+}
+
+func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
+ // Keep the first error that is recorded. Store the error before checking if the command is already concluded to
+ // allow for receiving an error after CommandComplete but before ReadyForQuery.
+ if err != nil && rr.err == nil {
+ rr.err = err
+ }
+
+ if rr.commandConcluded {
+ return
+ }
+
+ rr.commandTag = commandTag
+ rr.rowValues = nil
+ rr.commandConcluded = true
+}
+
+// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip.
+type Batch struct {
+ buf []byte
+}
+
+// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions.
+func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
+ batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
+ batch.ExecPrepared("", paramValues, paramFormats, resultFormats)
+}
+
+// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions.
+func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
+ batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
+ batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
+ batch.buf = (&pgproto3.Execute{}).Encode(batch.buf)
+}
+
+// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a
+// transaction is already in progress or SQL contains transaction control statements.
+func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
+ if err := pgConn.lock(); err != nil {
+ return &MultiResultReader{
+ closed: true,
+ err: err,
+ }
+ }
+
+ pgConn.multiResultReader = MultiResultReader{
+ pgConn: pgConn,
+ ctx: ctx,
+ }
+ multiResult := &pgConn.multiResultReader
+
+ if ctx != context.Background() {
+ select {
+ case <-ctx.Done():
+ multiResult.closed = true
+ multiResult.err = newContextAlreadyDoneError(ctx)
+ pgConn.unlock()
+ return multiResult
+ default:
+ }
+ pgConn.contextWatcher.Watch(ctx)
+ }
+
+ batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
+
+ // A large batch can deadlock without concurrent reading and writing. If the Write fails the underlying net.Conn is
+ // closed. This is all that can be done without introducing a race condition or adding a concurrent safe communication
+ // channel to relay the error back. The practical effect of this is that the underlying Write error is not reported.
+ // The error the code reading the batch results receives will be a closed connection error.
+ //
+ // See https://github.com/jackc/pgx/issues/374.
+ go func() {
+ _, err := pgConn.conn.Write(batch.buf)
+ if err != nil {
+ pgConn.conn.Close()
+ }
+ }()
+
+ return multiResult
+}
+
+// EscapeString escapes a string such that it can safely be interpolated into a SQL command string. It does not include
+// the surrounding single quotes.
+//
+// The current implementation requires that standard_conforming_strings=on and client_encoding="UTF8". If these
+// conditions are not met an error will be returned. It is possible these restrictions will be lifted in the future.
+func (pgConn *PgConn) EscapeString(s string) (string, error) {
+ if pgConn.ParameterStatus("standard_conforming_strings") != "on" {
+ return "", errors.New("EscapeString must be run with standard_conforming_strings=on")
+ }
+
+ if pgConn.ParameterStatus("client_encoding") != "UTF8" {
+ return "", errors.New("EscapeString must be run with client_encoding=UTF8")
+ }
+
+ return strings.Replace(s, "'", "''", -1), nil
+}
+
+// HijackedConn is the result of hijacking a connection.
+//
+// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
+// compatibility.
+type HijackedConn struct {
+ Conn net.Conn // the underlying TCP or unix domain socket connection
+ PID uint32 // backend pid
+ SecretKey uint32 // key to use to send a cancel query message to the server
+ ParameterStatuses map[string]string // parameters that have been reported by the server
+ TxStatus byte
+ Frontend Frontend
+ Config *Config
+}
+
+// Hijack extracts the internal connection data. pgConn must be in an idle state. pgConn is unusable after hijacking.
+// Hijacking is typically only useful when using pgconn to establish a connection, but taking complete control of the
+// raw connection after that (e.g. a load balancer or proxy).
+//
+// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
+// compatibility.
+func (pgConn *PgConn) Hijack() (*HijackedConn, error) {
+ if err := pgConn.lock(); err != nil {
+ return nil, err
+ }
+ pgConn.status = connStatusClosed
+
+ return &HijackedConn{
+ Conn: pgConn.conn,
+ PID: pgConn.pid,
+ SecretKey: pgConn.secretKey,
+ ParameterStatuses: pgConn.parameterStatuses,
+ TxStatus: pgConn.txStatus,
+ Frontend: pgConn.frontend,
+ Config: pgConn.config,
+ }, nil
+}
+
+// Construct created a PgConn from an already established connection to a PostgreSQL server. This is the inverse of
+// PgConn.Hijack. The connection must be in an idle state.
+//
+// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
+// compatibility.
+func Construct(hc *HijackedConn) (*PgConn, error) {
+ pgConn := &PgConn{
+ conn: hc.Conn,
+ pid: hc.PID,
+ secretKey: hc.SecretKey,
+ parameterStatuses: hc.ParameterStatuses,
+ txStatus: hc.TxStatus,
+ frontend: hc.Frontend,
+ config: hc.Config,
+
+ status: connStatusIdle,
+
+ wbuf: make([]byte, 0, wbufLen),
+ cleanupDone: make(chan struct{}),
+ }
+
+ pgConn.contextWatcher = ctxwatch.NewContextWatcher(
+ func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
+ func() { pgConn.conn.SetDeadline(time.Time{}) },
+ )
+
+ return pgConn, nil
+}