summaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc/pgx/v5/pgconn
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/pgconn')
-rw-r--r--vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go4
-rw-r--r--vendor/github.com/jackc/pgx/v5/pgconn/internal/bgreader/bgreader.go132
-rw-r--r--vendor/github.com/jackc/pgx/v5/pgconn/krb5.go2
-rw-r--r--vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go305
4 files changed, 338 insertions, 105 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go b/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go
index 6ca9e3379..8c4b2de3c 100644
--- a/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go
+++ b/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go
@@ -42,7 +42,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
Data: sc.clientFirstMessage(),
}
c.frontend.Send(saslInitialResponse)
- err = c.frontend.Flush()
+ err = c.flushWithPotentialWriteReadDeadlock()
if err != nil {
return err
}
@@ -62,7 +62,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
Data: []byte(sc.clientFinalMessage()),
}
c.frontend.Send(saslResponse)
- err = c.frontend.Flush()
+ err = c.flushWithPotentialWriteReadDeadlock()
if err != nil {
return err
}
diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/internal/bgreader/bgreader.go b/vendor/github.com/jackc/pgx/v5/pgconn/internal/bgreader/bgreader.go
new file mode 100644
index 000000000..aa1a3d39c
--- /dev/null
+++ b/vendor/github.com/jackc/pgx/v5/pgconn/internal/bgreader/bgreader.go
@@ -0,0 +1,132 @@
+// Package bgreader provides a io.Reader that can optionally buffer reads in the background.
+package bgreader
+
+import (
+ "io"
+ "sync"
+
+ "github.com/jackc/pgx/v5/internal/iobufpool"
+)
+
+const (
+ bgReaderStatusStopped = iota
+ bgReaderStatusRunning
+ bgReaderStatusStopping
+)
+
+// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use.
+type BGReader struct {
+ r io.Reader
+
+ cond *sync.Cond
+ bgReaderStatus int32
+ readResults []readResult
+}
+
+type readResult struct {
+ buf *[]byte
+ err error
+}
+
+// Start starts the backgrounder reader. If the background reader is already running this is a no-op. The background
+// reader will stop automatically when the underlying reader returns an error.
+func (r *BGReader) Start() {
+ r.cond.L.Lock()
+ defer r.cond.L.Unlock()
+
+ switch r.bgReaderStatus {
+ case bgReaderStatusStopped:
+ r.bgReaderStatus = bgReaderStatusRunning
+ go r.bgRead()
+ case bgReaderStatusRunning:
+ // no-op
+ case bgReaderStatusStopping:
+ r.bgReaderStatus = bgReaderStatusRunning
+ }
+}
+
+// Stop tells the background reader to stop after the in progress Read returns. It is safe to call Stop when the
+// background reader is not running.
+func (r *BGReader) Stop() {
+ r.cond.L.Lock()
+ defer r.cond.L.Unlock()
+
+ switch r.bgReaderStatus {
+ case bgReaderStatusStopped:
+ // no-op
+ case bgReaderStatusRunning:
+ r.bgReaderStatus = bgReaderStatusStopping
+ case bgReaderStatusStopping:
+ // no-op
+ }
+}
+
+func (r *BGReader) bgRead() {
+ keepReading := true
+ for keepReading {
+ buf := iobufpool.Get(8192)
+ n, err := r.r.Read(*buf)
+ *buf = (*buf)[:n]
+
+ r.cond.L.Lock()
+ r.readResults = append(r.readResults, readResult{buf: buf, err: err})
+ if r.bgReaderStatus == bgReaderStatusStopping || err != nil {
+ r.bgReaderStatus = bgReaderStatusStopped
+ keepReading = false
+ }
+ r.cond.L.Unlock()
+ r.cond.Broadcast()
+ }
+}
+
+// Read implements the io.Reader interface.
+func (r *BGReader) Read(p []byte) (int, error) {
+ r.cond.L.Lock()
+ defer r.cond.L.Unlock()
+
+ if len(r.readResults) > 0 {
+ return r.readFromReadResults(p)
+ }
+
+ // There are no unread background read results and the background reader is stopped.
+ if r.bgReaderStatus == bgReaderStatusStopped {
+ return r.r.Read(p)
+ }
+
+ // Wait for results from the background reader
+ for len(r.readResults) == 0 {
+ r.cond.Wait()
+ }
+ return r.readFromReadResults(p)
+}
+
+// readBackgroundResults reads a result previously read by the background reader. r.cond.L must be held.
+func (r *BGReader) readFromReadResults(p []byte) (int, error) {
+ buf := r.readResults[0].buf
+ var err error
+
+ n := copy(p, *buf)
+ if n == len(*buf) {
+ err = r.readResults[0].err
+ iobufpool.Put(buf)
+ if len(r.readResults) == 1 {
+ r.readResults = nil
+ } else {
+ r.readResults = r.readResults[1:]
+ }
+ } else {
+ *buf = (*buf)[n:]
+ r.readResults[0].buf = buf
+ }
+
+ return n, err
+}
+
+func New(r io.Reader) *BGReader {
+ return &BGReader{
+ r: r,
+ cond: &sync.Cond{
+ L: &sync.Mutex{},
+ },
+ }
+}
diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go b/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go
index 969675fd2..3c1af3477 100644
--- a/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go
+++ b/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go
@@ -63,7 +63,7 @@ func (c *PgConn) gssAuth() error {
Data: nextData,
}
c.frontend.Send(gssResponse)
- err = c.frontend.Flush()
+ err = c.flushWithPotentialWriteReadDeadlock()
if err != nil {
return err
}
diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go
index 8656ea518..9f84605fe 100644
--- a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go
+++ b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go
@@ -13,11 +13,12 @@ import (
"net"
"strconv"
"strings"
+ "sync"
"time"
"github.com/jackc/pgx/v5/internal/iobufpool"
- "github.com/jackc/pgx/v5/internal/nbconn"
"github.com/jackc/pgx/v5/internal/pgio"
+ "github.com/jackc/pgx/v5/pgconn/internal/bgreader"
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
"github.com/jackc/pgx/v5/pgproto3"
)
@@ -65,17 +66,24 @@ type NotificationHandler func(*PgConn, *Notification)
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
type PgConn struct {
- conn nbconn.Conn // the non-blocking wrapper for the underlying TCP or unix domain socket connection
+ conn net.Conn
pid uint32 // backend pid
secretKey uint32 // key to use to send a cancel query message to the server
parameterStatuses map[string]string // parameters that have been reported by the server
txStatus byte
frontend *pgproto3.Frontend
+ bgReader *bgreader.BGReader
+ slowWriteTimer *time.Timer
config *Config
status byte // One of connStatus* constants
+ bufferingReceive bool
+ bufferingReceiveMux sync.Mutex
+ bufferingReceiveMsg pgproto3.BackendMessage
+ bufferingReceiveErr error
+
peekedMsg pgproto3.BackendMessage
// Reusable / preallocated resources
@@ -266,14 +274,13 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if err != nil {
return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
}
- nbNetConn := nbconn.NewNetConn(netConn, false)
- pgConn.conn = nbNetConn
- pgConn.contextWatcher = newContextWatcher(nbNetConn)
+ pgConn.conn = netConn
+ pgConn.contextWatcher = newContextWatcher(netConn)
pgConn.contextWatcher.Watch(ctx)
if fallbackConfig.TLSConfig != nil {
- nbTLSConn, err := startTLS(nbNetConn, fallbackConfig.TLSConfig)
+ nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig)
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
if err != nil {
netConn.Close()
@@ -289,7 +296,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.parameterStatuses = make(map[string]string)
pgConn.status = connStatusConnecting
- pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn)
+ pgConn.bgReader = bgreader.New(pgConn.conn)
+ pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start)
+ pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn)
startupMsg := pgproto3.StartupMessage{
ProtocolVersion: pgproto3.ProtocolVersionNumber,
@@ -307,9 +316,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
}
pgConn.frontend.Send(&startupMsg)
- if err := pgConn.frontend.Flush(); err != nil {
+ if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
pgConn.conn.Close()
- return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
+ return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
}
for {
@@ -392,7 +401,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {
)
}
-func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, error) {
+func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
if err != nil {
return nil, err
@@ -407,17 +416,12 @@ func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, err
return nil, errors.New("server refused TLS connection")
}
- tlsConn, err := nbconn.TLSClient(conn, tlsConfig)
- if err != nil {
- return nil, err
- }
-
- return tlsConn, nil
+ return tls.Client(conn, tlsConfig), nil
}
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password})
- return pgConn.frontend.Flush()
+ return pgConn.flushWithPotentialWriteReadDeadlock()
}
func hexMD5(s string) string {
@@ -426,6 +430,24 @@ func hexMD5(s string) string {
return hex.EncodeToString(hash.Sum(nil))
}
+func (pgConn *PgConn) signalMessage() chan struct{} {
+ if pgConn.bufferingReceive {
+ panic("BUG: signalMessage when already in progress")
+ }
+
+ pgConn.bufferingReceive = true
+ pgConn.bufferingReceiveMux.Lock()
+
+ ch := make(chan struct{})
+ go func() {
+ pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
+ pgConn.bufferingReceiveMux.Unlock()
+ close(ch)
+ }()
+
+ return ch
+}
+
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
@@ -465,13 +487,25 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
return pgConn.peekedMsg, nil
}
- msg, err := pgConn.frontend.Receive()
-
- if err != nil {
- if errors.Is(err, nbconn.ErrWouldBlock) {
- return nil, err
+ var msg pgproto3.BackendMessage
+ var err error
+ if pgConn.bufferingReceive {
+ pgConn.bufferingReceiveMux.Lock()
+ msg = pgConn.bufferingReceiveMsg
+ err = pgConn.bufferingReceiveErr
+ pgConn.bufferingReceiveMux.Unlock()
+ pgConn.bufferingReceive = false
+
+ // If a timeout error happened in the background try the read again.
+ var netErr net.Error
+ if errors.As(err, &netErr) && netErr.Timeout() {
+ msg, err = pgConn.frontend.Receive()
}
+ } else {
+ msg, err = pgConn.frontend.Receive()
+ }
+ if err != nil {
// Close on anything other than timeout error - everything else is fatal
var netErr net.Error
isNetErr := errors.As(err, &netErr)
@@ -582,7 +616,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
//
// See https://github.com/jackc/pgx/issues/637
pgConn.frontend.Send(&pgproto3.Terminate{})
- pgConn.frontend.Flush()
+ pgConn.flushWithPotentialWriteReadDeadlock()
return pgConn.conn.Close()
}
@@ -609,7 +643,7 @@ func (pgConn *PgConn) asyncClose() {
pgConn.conn.SetDeadline(deadline)
pgConn.frontend.Send(&pgproto3.Terminate{})
- pgConn.frontend.Flush()
+ pgConn.flushWithPotentialWriteReadDeadlock()
}()
}
@@ -784,7 +818,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
pgConn.frontend.SendSync(&pgproto3.Sync{})
- err := pgConn.frontend.Flush()
+ err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
return nil, err
@@ -857,9 +891,28 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
// the connection config. This is important in high availability configurations where fallback connections may be
// specified or DNS may be used to load balance.
serverAddr := pgConn.conn.RemoteAddr()
- cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String())
+ var serverNetwork string
+ var serverAddress string
+ if serverAddr.Network() == "unix" {
+ // for unix sockets, RemoteAddr() calls getpeername() which returns the name the
+ // server passed to bind(). For Postgres, this is always a relative path "./.s.PGSQL.5432"
+ // so connecting to it will fail. Fall back to the config's value
+ serverNetwork, serverAddress = NetworkAddress(pgConn.config.Host, pgConn.config.Port)
+ } else {
+ serverNetwork, serverAddress = serverAddr.Network(), serverAddr.String()
+ }
+ cancelConn, err := pgConn.config.DialFunc(ctx, serverNetwork, serverAddress)
if err != nil {
- return err
+ // In case of unix sockets, RemoteAddr() returns only the file part of the path. If the
+ // first connect failed, try the config.
+ if serverAddr.Network() != "unix" {
+ return err
+ }
+ serverNetwork, serverAddr := NetworkAddress(pgConn.config.Host, pgConn.config.Port)
+ cancelConn, err = pgConn.config.DialFunc(ctx, serverNetwork, serverAddr)
+ if err != nil {
+ return err
+ }
}
defer cancelConn.Close()
@@ -877,17 +930,11 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
binary.BigEndian.PutUint32(buf[4:8], 80877102)
binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid))
binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey))
+ // Postgres will process the request and close the connection
+ // so when don't need to read the reply
+ // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.6.7.10
_, err = cancelConn.Write(buf)
- if err != nil {
- return err
- }
-
- _, err = cancelConn.Read(buf)
- if err != io.EOF {
- return err
- }
-
- return nil
+ return err
}
// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not
@@ -953,7 +1000,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
}
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
- err := pgConn.frontend.Flush()
+ err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
pgConn.contextWatcher.Unwatch()
@@ -1064,7 +1111,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) {
pgConn.frontend.SendExecute(&pgproto3.Execute{})
pgConn.frontend.SendSync(&pgproto3.Sync{})
- err := pgConn.frontend.Flush()
+ err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
result.concludeCommand(CommandTag{}, err)
@@ -1097,7 +1144,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
// Send copy to command
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
- err := pgConn.frontend.Flush()
+ err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
pgConn.unlock()
@@ -1153,85 +1200,91 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
defer pgConn.contextWatcher.Unwatch()
}
- // Send copy to command
+ // Send copy from query
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
- err := pgConn.frontend.Flush()
- if err != nil {
- pgConn.asyncClose()
- return CommandTag{}, err
- }
-
- err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline)
+ err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
return CommandTag{}, err
}
- nonblocking := true
- defer func() {
- if nonblocking {
- pgConn.conn.SetReadDeadline(time.Time{})
- }
- }()
- buf := iobufpool.Get(65536)
- defer iobufpool.Put(buf)
- (*buf)[0] = 'd'
+ // Send copy data
+ abortCopyChan := make(chan struct{})
+ copyErrChan := make(chan error, 1)
+ signalMessageChan := pgConn.signalMessage()
+ var wg sync.WaitGroup
+ wg.Add(1)
- var readErr, pgErr error
- for pgErr == nil {
- // Read chunk from r.
- var n int
- n, readErr = r.Read((*buf)[5:cap(*buf)])
+ go func() {
+ defer wg.Done()
+ buf := iobufpool.Get(65536)
+ defer iobufpool.Put(buf)
+ (*buf)[0] = 'd'
- // Send chunk to PostgreSQL.
- if n > 0 {
- *buf = (*buf)[0 : n+5]
- pgio.SetInt32((*buf)[1:], int32(n+4))
+ for {
+ n, readErr := r.Read((*buf)[5:cap(*buf)])
+ if n > 0 {
+ *buf = (*buf)[0 : n+5]
+ pgio.SetInt32((*buf)[1:], int32(n+4))
+
+ writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
+ if writeErr != nil {
+ // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. Not
+ // setting pgConn.status or closing pgConn.cleanupDone for the same reason.
+ pgConn.conn.Close()
- writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
- if writeErr != nil {
- pgConn.asyncClose()
- return CommandTag{}, err
+ copyErrChan <- writeErr
+ return
+ }
+ }
+ if readErr != nil {
+ copyErrChan <- readErr
+ return
}
- }
- // Abort loop if there was a read error.
- if readErr != nil {
- break
+ select {
+ case <-abortCopyChan:
+ return
+ default:
+ }
}
+ }()
- // Read messages until error or none available.
- for pgErr == nil {
- msg, err := pgConn.receiveMessage()
- if err != nil {
- if errors.Is(err, nbconn.ErrWouldBlock) {
- break
- }
- pgConn.asyncClose()
+ var pgErr error
+ var copyErr error
+ for copyErr == nil && pgErr == nil {
+ select {
+ case copyErr = <-copyErrChan:
+ case <-signalMessageChan:
+ // If pgConn.receiveMessage encounters an error it will call pgConn.asyncClose. But that is a race condition with
+ // the goroutine. So instead check pgConn.bufferingReceiveErr which will have been set by the signalMessage. If an
+ // error is found then forcibly close the connection without sending the Terminate message.
+ if err := pgConn.bufferingReceiveErr; err != nil {
+ pgConn.status = connStatusClosed
+ pgConn.conn.Close()
+ close(pgConn.cleanupDone)
return CommandTag{}, normalizeTimeoutError(ctx, err)
}
+ msg, _ := pgConn.receiveMessage()
switch msg := msg.(type) {
case *pgproto3.ErrorResponse:
pgErr = ErrorResponseToPgError(msg)
- break
+ default:
+ signalMessageChan = pgConn.signalMessage()
}
}
}
+ close(abortCopyChan)
+ // Make sure io goroutine finishes before writing.
+ wg.Wait()
- err = pgConn.conn.SetReadDeadline(time.Time{})
- if err != nil {
- pgConn.asyncClose()
- return CommandTag{}, err
- }
- nonblocking = false
-
- if readErr == io.EOF || pgErr != nil {
+ if copyErr == io.EOF || pgErr != nil {
pgConn.frontend.Send(&pgproto3.CopyDone{})
} else {
- pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()})
+ pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
}
- err = pgConn.frontend.Flush()
+ err = pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
return CommandTag{}, err
@@ -1426,7 +1479,8 @@ func (rr *ResultReader) NextRow() bool {
}
// FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until
-// the ResultReader is closed.
+// the ResultReader is closed. It may return nil (for example, if the query did not return a result set or an error was
+// encountered.)
func (rr *ResultReader) FieldDescriptions() []FieldDescription {
return rr.fieldDescriptions
}
@@ -1592,7 +1646,9 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
+ pgConn.enterPotentialWriteReadDeadlock()
_, err := pgConn.conn.Write(batch.buf)
+ pgConn.exitPotentialWriteReadDeadlock()
if err != nil {
multiResult.closed = true
multiResult.err = err
@@ -1620,29 +1676,72 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) {
return strings.Replace(s, "'", "''", -1), nil
}
-// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by reading and
-// buffering until the read would block or an error occurs. This can be used to check if the server has closed the
-// connection. If this is done immediately before sending a query it reduces the chances a query will be sent that fails
+// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by doing a read
+// with a very short deadline. This can be useful because a TCP connection can be broken such that a write will appear
+// to succeed even though it will never actually reach the server. Reading immediately before a write will detect this
+// condition. If this is done immediately before sending a query it reduces the chances a query will be sent that fails
// without the client knowing whether the server received it or not.
+//
+// Deprecated: CheckConn is deprecated in favor of Ping. CheckConn cannot detect all types of broken connections where
+// the write would still appear to succeed. Prefer Ping unless on a high latency connection.
func (pgConn *PgConn) CheckConn() error {
- err := pgConn.conn.BufferReadUntilBlock()
- if err != nil && !errors.Is(err, nbconn.ErrWouldBlock) {
- return err
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
+ defer cancel()
+
+ _, err := pgConn.ReceiveMessage(ctx)
+ if err != nil {
+ if !Timeout(err) {
+ return err
+ }
}
+
return nil
}
+// Ping pings the server. This can be useful because a TCP connection can be broken such that a write will appear to
+// succeed even though it will never actually reach the server. Pinging immediately before sending a query reduces the
+// chances a query will be sent that fails without the client knowing whether the server received it or not.
+func (pgConn *PgConn) Ping(ctx context.Context) error {
+ return pgConn.Exec(ctx, "-- ping").Close()
+}
+
// makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory.
func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag {
return CommandTag{s: string(buf)}
}
+// enterPotentialWriteReadDeadlock must be called before a write that could deadlock if the server is simultaneously
+// blocked writing to us.
+func (pgConn *PgConn) enterPotentialWriteReadDeadlock() {
+ // The time to wait is somewhat arbitrary. A Write should only take as long as the syscall and memcpy to the OS
+ // outbound network buffer unless the buffer is full (which potentially is a block). It needs to be long enough for
+ // the normal case, but short enough not to kill performance if a block occurs.
+ //
+ // In addition, on Windows the default timer resolution is 15.6ms. So setting the timer to less than that is
+ // ineffective.
+ pgConn.slowWriteTimer.Reset(15 * time.Millisecond)
+}
+
+// exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock.
+func (pgConn *PgConn) exitPotentialWriteReadDeadlock() {
+ if !pgConn.slowWriteTimer.Reset(time.Duration(math.MaxInt64)) {
+ pgConn.slowWriteTimer.Stop()
+ }
+}
+
+func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error {
+ pgConn.enterPotentialWriteReadDeadlock()
+ err := pgConn.frontend.Flush()
+ pgConn.exitPotentialWriteReadDeadlock()
+ return err
+}
+
// HijackedConn is the result of hijacking a connection.
//
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
// compatibility.
type HijackedConn struct {
- Conn nbconn.Conn // the non-blocking wrapper of the underlying TCP or unix domain socket connection
+ Conn net.Conn
PID uint32 // backend pid
SecretKey uint32 // key to use to send a cancel query message to the server
ParameterStatuses map[string]string // parameters that have been reported by the server
@@ -1695,6 +1794,8 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
}
pgConn.contextWatcher = newContextWatcher(pgConn.conn)
+ pgConn.bgReader = bgreader.New(pgConn.conn)
+ pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start)
return pgConn, nil
}
@@ -1817,7 +1918,7 @@ func (p *Pipeline) Flush() error {
return errors.New("pipeline closed")
}
- err := p.conn.frontend.Flush()
+ err := p.conn.flushWithPotentialWriteReadDeadlock()
if err != nil {
err = normalizeTimeoutError(p.ctx, err)