summaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go')
-rw-r--r--vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go301
1 files changed, 201 insertions, 100 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go
index 8656ea518..9f84605fe 100644
--- a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go
+++ b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go
@@ -13,11 +13,12 @@ import (
"net"
"strconv"
"strings"
+ "sync"
"time"
"github.com/jackc/pgx/v5/internal/iobufpool"
- "github.com/jackc/pgx/v5/internal/nbconn"
"github.com/jackc/pgx/v5/internal/pgio"
+ "github.com/jackc/pgx/v5/pgconn/internal/bgreader"
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
"github.com/jackc/pgx/v5/pgproto3"
)
@@ -65,17 +66,24 @@ type NotificationHandler func(*PgConn, *Notification)
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
type PgConn struct {
- conn nbconn.Conn // the non-blocking wrapper for the underlying TCP or unix domain socket connection
+ conn net.Conn
pid uint32 // backend pid
secretKey uint32 // key to use to send a cancel query message to the server
parameterStatuses map[string]string // parameters that have been reported by the server
txStatus byte
frontend *pgproto3.Frontend
+ bgReader *bgreader.BGReader
+ slowWriteTimer *time.Timer
config *Config
status byte // One of connStatus* constants
+ bufferingReceive bool
+ bufferingReceiveMux sync.Mutex
+ bufferingReceiveMsg pgproto3.BackendMessage
+ bufferingReceiveErr error
+
peekedMsg pgproto3.BackendMessage
// Reusable / preallocated resources
@@ -266,14 +274,13 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if err != nil {
return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
}
- nbNetConn := nbconn.NewNetConn(netConn, false)
- pgConn.conn = nbNetConn
- pgConn.contextWatcher = newContextWatcher(nbNetConn)
+ pgConn.conn = netConn
+ pgConn.contextWatcher = newContextWatcher(netConn)
pgConn.contextWatcher.Watch(ctx)
if fallbackConfig.TLSConfig != nil {
- nbTLSConn, err := startTLS(nbNetConn, fallbackConfig.TLSConfig)
+ nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig)
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
if err != nil {
netConn.Close()
@@ -289,7 +296,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.parameterStatuses = make(map[string]string)
pgConn.status = connStatusConnecting
- pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn)
+ pgConn.bgReader = bgreader.New(pgConn.conn)
+ pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start)
+ pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn)
startupMsg := pgproto3.StartupMessage{
ProtocolVersion: pgproto3.ProtocolVersionNumber,
@@ -307,9 +316,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
}
pgConn.frontend.Send(&startupMsg)
- if err := pgConn.frontend.Flush(); err != nil {
+ if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
pgConn.conn.Close()
- return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
+ return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
}
for {
@@ -392,7 +401,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {
)
}
-func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, error) {
+func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
if err != nil {
return nil, err
@@ -407,17 +416,12 @@ func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, err
return nil, errors.New("server refused TLS connection")
}
- tlsConn, err := nbconn.TLSClient(conn, tlsConfig)
- if err != nil {
- return nil, err
- }
-
- return tlsConn, nil
+ return tls.Client(conn, tlsConfig), nil
}
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password})
- return pgConn.frontend.Flush()
+ return pgConn.flushWithPotentialWriteReadDeadlock()
}
func hexMD5(s string) string {
@@ -426,6 +430,24 @@ func hexMD5(s string) string {
return hex.EncodeToString(hash.Sum(nil))
}
+func (pgConn *PgConn) signalMessage() chan struct{} {
+ if pgConn.bufferingReceive {
+ panic("BUG: signalMessage when already in progress")
+ }
+
+ pgConn.bufferingReceive = true
+ pgConn.bufferingReceiveMux.Lock()
+
+ ch := make(chan struct{})
+ go func() {
+ pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
+ pgConn.bufferingReceiveMux.Unlock()
+ close(ch)
+ }()
+
+ return ch
+}
+
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
@@ -465,13 +487,25 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
return pgConn.peekedMsg, nil
}
- msg, err := pgConn.frontend.Receive()
+ var msg pgproto3.BackendMessage
+ var err error
+ if pgConn.bufferingReceive {
+ pgConn.bufferingReceiveMux.Lock()
+ msg = pgConn.bufferingReceiveMsg
+ err = pgConn.bufferingReceiveErr
+ pgConn.bufferingReceiveMux.Unlock()
+ pgConn.bufferingReceive = false
- if err != nil {
- if errors.Is(err, nbconn.ErrWouldBlock) {
- return nil, err
+ // If a timeout error happened in the background try the read again.
+ var netErr net.Error
+ if errors.As(err, &netErr) && netErr.Timeout() {
+ msg, err = pgConn.frontend.Receive()
}
+ } else {
+ msg, err = pgConn.frontend.Receive()
+ }
+ if err != nil {
// Close on anything other than timeout error - everything else is fatal
var netErr net.Error
isNetErr := errors.As(err, &netErr)
@@ -582,7 +616,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
//
// See https://github.com/jackc/pgx/issues/637
pgConn.frontend.Send(&pgproto3.Terminate{})
- pgConn.frontend.Flush()
+ pgConn.flushWithPotentialWriteReadDeadlock()
return pgConn.conn.Close()
}
@@ -609,7 +643,7 @@ func (pgConn *PgConn) asyncClose() {
pgConn.conn.SetDeadline(deadline)
pgConn.frontend.Send(&pgproto3.Terminate{})
- pgConn.frontend.Flush()
+ pgConn.flushWithPotentialWriteReadDeadlock()
}()
}
@@ -784,7 +818,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
pgConn.frontend.SendSync(&pgproto3.Sync{})
- err := pgConn.frontend.Flush()
+ err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
return nil, err
@@ -857,9 +891,28 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
// the connection config. This is important in high availability configurations where fallback connections may be
// specified or DNS may be used to load balance.
serverAddr := pgConn.conn.RemoteAddr()
- cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String())
+ var serverNetwork string
+ var serverAddress string
+ if serverAddr.Network() == "unix" {
+ // for unix sockets, RemoteAddr() calls getpeername() which returns the name the
+ // server passed to bind(). For Postgres, this is always a relative path "./.s.PGSQL.5432"
+ // so connecting to it will fail. Fall back to the config's value
+ serverNetwork, serverAddress = NetworkAddress(pgConn.config.Host, pgConn.config.Port)
+ } else {
+ serverNetwork, serverAddress = serverAddr.Network(), serverAddr.String()
+ }
+ cancelConn, err := pgConn.config.DialFunc(ctx, serverNetwork, serverAddress)
if err != nil {
- return err
+ // In case of unix sockets, RemoteAddr() returns only the file part of the path. If the
+ // first connect failed, try the config.
+ if serverAddr.Network() != "unix" {
+ return err
+ }
+ serverNetwork, serverAddr := NetworkAddress(pgConn.config.Host, pgConn.config.Port)
+ cancelConn, err = pgConn.config.DialFunc(ctx, serverNetwork, serverAddr)
+ if err != nil {
+ return err
+ }
}
defer cancelConn.Close()
@@ -877,17 +930,11 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
binary.BigEndian.PutUint32(buf[4:8], 80877102)
binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid))
binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey))
+ // Postgres will process the request and close the connection
+ // so when don't need to read the reply
+ // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.6.7.10
_, err = cancelConn.Write(buf)
- if err != nil {
- return err
- }
-
- _, err = cancelConn.Read(buf)
- if err != io.EOF {
- return err
- }
-
- return nil
+ return err
}
// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not
@@ -953,7 +1000,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
}
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
- err := pgConn.frontend.Flush()
+ err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
pgConn.contextWatcher.Unwatch()
@@ -1064,7 +1111,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) {
pgConn.frontend.SendExecute(&pgproto3.Execute{})
pgConn.frontend.SendSync(&pgproto3.Sync{})
- err := pgConn.frontend.Flush()
+ err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
result.concludeCommand(CommandTag{}, err)
@@ -1097,7 +1144,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
// Send copy to command
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
- err := pgConn.frontend.Flush()
+ err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
pgConn.unlock()
@@ -1153,85 +1200,91 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
defer pgConn.contextWatcher.Unwatch()
}
- // Send copy to command
+ // Send copy from query
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
- err := pgConn.frontend.Flush()
+ err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
return CommandTag{}, err
}
- err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline)
- if err != nil {
- pgConn.asyncClose()
- return CommandTag{}, err
- }
- nonblocking := true
- defer func() {
- if nonblocking {
- pgConn.conn.SetReadDeadline(time.Time{})
- }
- }()
+ // Send copy data
+ abortCopyChan := make(chan struct{})
+ copyErrChan := make(chan error, 1)
+ signalMessageChan := pgConn.signalMessage()
+ var wg sync.WaitGroup
+ wg.Add(1)
- buf := iobufpool.Get(65536)
- defer iobufpool.Put(buf)
- (*buf)[0] = 'd'
+ go func() {
+ defer wg.Done()
+ buf := iobufpool.Get(65536)
+ defer iobufpool.Put(buf)
+ (*buf)[0] = 'd'
- var readErr, pgErr error
- for pgErr == nil {
- // Read chunk from r.
- var n int
- n, readErr = r.Read((*buf)[5:cap(*buf)])
+ for {
+ n, readErr := r.Read((*buf)[5:cap(*buf)])
+ if n > 0 {
+ *buf = (*buf)[0 : n+5]
+ pgio.SetInt32((*buf)[1:], int32(n+4))
- // Send chunk to PostgreSQL.
- if n > 0 {
- *buf = (*buf)[0 : n+5]
- pgio.SetInt32((*buf)[1:], int32(n+4))
+ writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
+ if writeErr != nil {
+ // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. Not
+ // setting pgConn.status or closing pgConn.cleanupDone for the same reason.
+ pgConn.conn.Close()
- writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
- if writeErr != nil {
- pgConn.asyncClose()
- return CommandTag{}, err
+ copyErrChan <- writeErr
+ return
+ }
+ }
+ if readErr != nil {
+ copyErrChan <- readErr
+ return
}
- }
- // Abort loop if there was a read error.
- if readErr != nil {
- break
+ select {
+ case <-abortCopyChan:
+ return
+ default:
+ }
}
+ }()
- // Read messages until error or none available.
- for pgErr == nil {
- msg, err := pgConn.receiveMessage()
- if err != nil {
- if errors.Is(err, nbconn.ErrWouldBlock) {
- break
- }
- pgConn.asyncClose()
+ var pgErr error
+ var copyErr error
+ for copyErr == nil && pgErr == nil {
+ select {
+ case copyErr = <-copyErrChan:
+ case <-signalMessageChan:
+ // If pgConn.receiveMessage encounters an error it will call pgConn.asyncClose. But that is a race condition with
+ // the goroutine. So instead check pgConn.bufferingReceiveErr which will have been set by the signalMessage. If an
+ // error is found then forcibly close the connection without sending the Terminate message.
+ if err := pgConn.bufferingReceiveErr; err != nil {
+ pgConn.status = connStatusClosed
+ pgConn.conn.Close()
+ close(pgConn.cleanupDone)
return CommandTag{}, normalizeTimeoutError(ctx, err)
}
+ msg, _ := pgConn.receiveMessage()
switch msg := msg.(type) {
case *pgproto3.ErrorResponse:
pgErr = ErrorResponseToPgError(msg)
- break
+ default:
+ signalMessageChan = pgConn.signalMessage()
}
}
}
+ close(abortCopyChan)
+ // Make sure io goroutine finishes before writing.
+ wg.Wait()
- err = pgConn.conn.SetReadDeadline(time.Time{})
- if err != nil {
- pgConn.asyncClose()
- return CommandTag{}, err
- }
- nonblocking = false
-
- if readErr == io.EOF || pgErr != nil {
+ if copyErr == io.EOF || pgErr != nil {
pgConn.frontend.Send(&pgproto3.CopyDone{})
} else {
- pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()})
+ pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
}
- err = pgConn.frontend.Flush()
+ err = pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
return CommandTag{}, err
@@ -1426,7 +1479,8 @@ func (rr *ResultReader) NextRow() bool {
}
// FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until
-// the ResultReader is closed.
+// the ResultReader is closed. It may return nil (for example, if the query did not return a result set or an error was
+// encountered.)
func (rr *ResultReader) FieldDescriptions() []FieldDescription {
return rr.fieldDescriptions
}
@@ -1592,7 +1646,9 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
+ pgConn.enterPotentialWriteReadDeadlock()
_, err := pgConn.conn.Write(batch.buf)
+ pgConn.exitPotentialWriteReadDeadlock()
if err != nil {
multiResult.closed = true
multiResult.err = err
@@ -1620,29 +1676,72 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) {
return strings.Replace(s, "'", "''", -1), nil
}
-// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by reading and
-// buffering until the read would block or an error occurs. This can be used to check if the server has closed the
-// connection. If this is done immediately before sending a query it reduces the chances a query will be sent that fails
+// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by doing a read
+// with a very short deadline. This can be useful because a TCP connection can be broken such that a write will appear
+// to succeed even though it will never actually reach the server. Reading immediately before a write will detect this
+// condition. If this is done immediately before sending a query it reduces the chances a query will be sent that fails
// without the client knowing whether the server received it or not.
+//
+// Deprecated: CheckConn is deprecated in favor of Ping. CheckConn cannot detect all types of broken connections where
+// the write would still appear to succeed. Prefer Ping unless on a high latency connection.
func (pgConn *PgConn) CheckConn() error {
- err := pgConn.conn.BufferReadUntilBlock()
- if err != nil && !errors.Is(err, nbconn.ErrWouldBlock) {
- return err
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
+ defer cancel()
+
+ _, err := pgConn.ReceiveMessage(ctx)
+ if err != nil {
+ if !Timeout(err) {
+ return err
+ }
}
+
return nil
}
+// Ping pings the server. This can be useful because a TCP connection can be broken such that a write will appear to
+// succeed even though it will never actually reach the server. Pinging immediately before sending a query reduces the
+// chances a query will be sent that fails without the client knowing whether the server received it or not.
+func (pgConn *PgConn) Ping(ctx context.Context) error {
+ return pgConn.Exec(ctx, "-- ping").Close()
+}
+
// makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory.
func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag {
return CommandTag{s: string(buf)}
}
+// enterPotentialWriteReadDeadlock must be called before a write that could deadlock if the server is simultaneously
+// blocked writing to us.
+func (pgConn *PgConn) enterPotentialWriteReadDeadlock() {
+ // The time to wait is somewhat arbitrary. A Write should only take as long as the syscall and memcpy to the OS
+ // outbound network buffer unless the buffer is full (which potentially is a block). It needs to be long enough for
+ // the normal case, but short enough not to kill performance if a block occurs.
+ //
+ // In addition, on Windows the default timer resolution is 15.6ms. So setting the timer to less than that is
+ // ineffective.
+ pgConn.slowWriteTimer.Reset(15 * time.Millisecond)
+}
+
+// exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock.
+func (pgConn *PgConn) exitPotentialWriteReadDeadlock() {
+ if !pgConn.slowWriteTimer.Reset(time.Duration(math.MaxInt64)) {
+ pgConn.slowWriteTimer.Stop()
+ }
+}
+
+func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error {
+ pgConn.enterPotentialWriteReadDeadlock()
+ err := pgConn.frontend.Flush()
+ pgConn.exitPotentialWriteReadDeadlock()
+ return err
+}
+
// HijackedConn is the result of hijacking a connection.
//
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
// compatibility.
type HijackedConn struct {
- Conn nbconn.Conn // the non-blocking wrapper of the underlying TCP or unix domain socket connection
+ Conn net.Conn
PID uint32 // backend pid
SecretKey uint32 // key to use to send a cancel query message to the server
ParameterStatuses map[string]string // parameters that have been reported by the server
@@ -1695,6 +1794,8 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
}
pgConn.contextWatcher = newContextWatcher(pgConn.conn)
+ pgConn.bgReader = bgreader.New(pgConn.conn)
+ pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start)
return pgConn, nil
}
@@ -1817,7 +1918,7 @@ func (p *Pipeline) Flush() error {
return errors.New("pipeline closed")
}
- err := p.conn.frontend.Flush()
+ err := p.conn.flushWithPotentialWriteReadDeadlock()
if err != nil {
err = normalizeTimeoutError(p.ctx, err)