diff options
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go')
| -rw-r--r-- | vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go | 301 |
1 files changed, 201 insertions, 100 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go index 8656ea518..9f84605fe 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go @@ -13,11 +13,12 @@ import ( "net" "strconv" "strings" + "sync" "time" "github.com/jackc/pgx/v5/internal/iobufpool" - "github.com/jackc/pgx/v5/internal/nbconn" "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgconn/internal/bgreader" "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" "github.com/jackc/pgx/v5/pgproto3" ) @@ -65,17 +66,24 @@ type NotificationHandler func(*PgConn, *Notification) // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { - conn nbconn.Conn // the non-blocking wrapper for the underlying TCP or unix domain socket connection + conn net.Conn pid uint32 // backend pid secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server txStatus byte frontend *pgproto3.Frontend + bgReader *bgreader.BGReader + slowWriteTimer *time.Timer config *Config status byte // One of connStatus* constants + bufferingReceive bool + bufferingReceiveMux sync.Mutex + bufferingReceiveMsg pgproto3.BackendMessage + bufferingReceiveErr error + peekedMsg pgproto3.BackendMessage // Reusable / preallocated resources @@ -266,14 +274,13 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err != nil { return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} } - nbNetConn := nbconn.NewNetConn(netConn, false) - pgConn.conn = nbNetConn - pgConn.contextWatcher = newContextWatcher(nbNetConn) + pgConn.conn = netConn + pgConn.contextWatcher = newContextWatcher(netConn) pgConn.contextWatcher.Watch(ctx) if fallbackConfig.TLSConfig != nil { - nbTLSConn, err := startTLS(nbNetConn, fallbackConfig.TLSConfig) + nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig) pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { netConn.Close() @@ -289,7 +296,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.parameterStatuses = make(map[string]string) pgConn.status = connStatusConnecting - pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) + pgConn.bgReader = bgreader.New(pgConn.conn) + pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start) + pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn) startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, @@ -307,9 +316,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } pgConn.frontend.Send(&startupMsg) - if err := pgConn.frontend.Flush(); err != nil { + if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed to write startup message", err: err} + return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)} } for { @@ -392,7 +401,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { ) } -func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, error) { +func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return nil, err @@ -407,17 +416,12 @@ func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, err return nil, errors.New("server refused TLS connection") } - tlsConn, err := nbconn.TLSClient(conn, tlsConfig) - if err != nil { - return nil, err - } - - return tlsConn, nil + return tls.Client(conn, tlsConfig), nil } func (pgConn *PgConn) txPasswordMessage(password string) (err error) { pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password}) - return pgConn.frontend.Flush() + return pgConn.flushWithPotentialWriteReadDeadlock() } func hexMD5(s string) string { @@ -426,6 +430,24 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } +func (pgConn *PgConn) signalMessage() chan struct{} { + if pgConn.bufferingReceive { + panic("BUG: signalMessage when already in progress") + } + + pgConn.bufferingReceive = true + pgConn.bufferingReceiveMux.Lock() + + ch := make(chan struct{}) + go func() { + pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive() + pgConn.bufferingReceiveMux.Unlock() + close(ch) + }() + + return ch +} + // ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the // connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages // are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger @@ -465,13 +487,25 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { return pgConn.peekedMsg, nil } - msg, err := pgConn.frontend.Receive() + var msg pgproto3.BackendMessage + var err error + if pgConn.bufferingReceive { + pgConn.bufferingReceiveMux.Lock() + msg = pgConn.bufferingReceiveMsg + err = pgConn.bufferingReceiveErr + pgConn.bufferingReceiveMux.Unlock() + pgConn.bufferingReceive = false - if err != nil { - if errors.Is(err, nbconn.ErrWouldBlock) { - return nil, err + // If a timeout error happened in the background try the read again. + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + msg, err = pgConn.frontend.Receive() } + } else { + msg, err = pgConn.frontend.Receive() + } + if err != nil { // Close on anything other than timeout error - everything else is fatal var netErr net.Error isNetErr := errors.As(err, &netErr) @@ -582,7 +616,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error { // // See https://github.com/jackc/pgx/issues/637 pgConn.frontend.Send(&pgproto3.Terminate{}) - pgConn.frontend.Flush() + pgConn.flushWithPotentialWriteReadDeadlock() return pgConn.conn.Close() } @@ -609,7 +643,7 @@ func (pgConn *PgConn) asyncClose() { pgConn.conn.SetDeadline(deadline) pgConn.frontend.Send(&pgproto3.Terminate{}) - pgConn.frontend.Flush() + pgConn.flushWithPotentialWriteReadDeadlock() }() } @@ -784,7 +818,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) pgConn.frontend.SendSync(&pgproto3.Sync{}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() return nil, err @@ -857,9 +891,28 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. serverAddr := pgConn.conn.RemoteAddr() - cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + var serverNetwork string + var serverAddress string + if serverAddr.Network() == "unix" { + // for unix sockets, RemoteAddr() calls getpeername() which returns the name the + // server passed to bind(). For Postgres, this is always a relative path "./.s.PGSQL.5432" + // so connecting to it will fail. Fall back to the config's value + serverNetwork, serverAddress = NetworkAddress(pgConn.config.Host, pgConn.config.Port) + } else { + serverNetwork, serverAddress = serverAddr.Network(), serverAddr.String() + } + cancelConn, err := pgConn.config.DialFunc(ctx, serverNetwork, serverAddress) if err != nil { - return err + // In case of unix sockets, RemoteAddr() returns only the file part of the path. If the + // first connect failed, try the config. + if serverAddr.Network() != "unix" { + return err + } + serverNetwork, serverAddr := NetworkAddress(pgConn.config.Host, pgConn.config.Port) + cancelConn, err = pgConn.config.DialFunc(ctx, serverNetwork, serverAddr) + if err != nil { + return err + } } defer cancelConn.Close() @@ -877,17 +930,11 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { binary.BigEndian.PutUint32(buf[4:8], 80877102) binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid)) binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) + // Postgres will process the request and close the connection + // so when don't need to read the reply + // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.6.7.10 _, err = cancelConn.Write(buf) - if err != nil { - return err - } - - _, err = cancelConn.Read(buf) - if err != io.EOF { - return err - } - - return nil + return err } // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not @@ -953,7 +1000,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { } pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() pgConn.contextWatcher.Unwatch() @@ -1064,7 +1111,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { pgConn.frontend.SendExecute(&pgproto3.Execute{}) pgConn.frontend.SendSync(&pgproto3.Sync{}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() result.concludeCommand(CommandTag{}, err) @@ -1097,7 +1144,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // Send copy to command pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() pgConn.unlock() @@ -1153,85 +1200,91 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co defer pgConn.contextWatcher.Unwatch() } - // Send copy to command + // Send copy from query pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() return CommandTag{}, err } - err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline) - if err != nil { - pgConn.asyncClose() - return CommandTag{}, err - } - nonblocking := true - defer func() { - if nonblocking { - pgConn.conn.SetReadDeadline(time.Time{}) - } - }() + // Send copy data + abortCopyChan := make(chan struct{}) + copyErrChan := make(chan error, 1) + signalMessageChan := pgConn.signalMessage() + var wg sync.WaitGroup + wg.Add(1) - buf := iobufpool.Get(65536) - defer iobufpool.Put(buf) - (*buf)[0] = 'd' + go func() { + defer wg.Done() + buf := iobufpool.Get(65536) + defer iobufpool.Put(buf) + (*buf)[0] = 'd' - var readErr, pgErr error - for pgErr == nil { - // Read chunk from r. - var n int - n, readErr = r.Read((*buf)[5:cap(*buf)]) + for { + n, readErr := r.Read((*buf)[5:cap(*buf)]) + if n > 0 { + *buf = (*buf)[0 : n+5] + pgio.SetInt32((*buf)[1:], int32(n+4)) - // Send chunk to PostgreSQL. - if n > 0 { - *buf = (*buf)[0 : n+5] - pgio.SetInt32((*buf)[1:], int32(n+4)) + writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf) + if writeErr != nil { + // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. Not + // setting pgConn.status or closing pgConn.cleanupDone for the same reason. + pgConn.conn.Close() - writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf) - if writeErr != nil { - pgConn.asyncClose() - return CommandTag{}, err + copyErrChan <- writeErr + return + } + } + if readErr != nil { + copyErrChan <- readErr + return } - } - // Abort loop if there was a read error. - if readErr != nil { - break + select { + case <-abortCopyChan: + return + default: + } } + }() - // Read messages until error or none available. - for pgErr == nil { - msg, err := pgConn.receiveMessage() - if err != nil { - if errors.Is(err, nbconn.ErrWouldBlock) { - break - } - pgConn.asyncClose() + var pgErr error + var copyErr error + for copyErr == nil && pgErr == nil { + select { + case copyErr = <-copyErrChan: + case <-signalMessageChan: + // If pgConn.receiveMessage encounters an error it will call pgConn.asyncClose. But that is a race condition with + // the goroutine. So instead check pgConn.bufferingReceiveErr which will have been set by the signalMessage. If an + // error is found then forcibly close the connection without sending the Terminate message. + if err := pgConn.bufferingReceiveErr; err != nil { + pgConn.status = connStatusClosed + pgConn.conn.Close() + close(pgConn.cleanupDone) return CommandTag{}, normalizeTimeoutError(ctx, err) } + msg, _ := pgConn.receiveMessage() switch msg := msg.(type) { case *pgproto3.ErrorResponse: pgErr = ErrorResponseToPgError(msg) - break + default: + signalMessageChan = pgConn.signalMessage() } } } + close(abortCopyChan) + // Make sure io goroutine finishes before writing. + wg.Wait() - err = pgConn.conn.SetReadDeadline(time.Time{}) - if err != nil { - pgConn.asyncClose() - return CommandTag{}, err - } - nonblocking = false - - if readErr == io.EOF || pgErr != nil { + if copyErr == io.EOF || pgErr != nil { pgConn.frontend.Send(&pgproto3.CopyDone{}) } else { - pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()}) + pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()}) } - err = pgConn.frontend.Flush() + err = pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() return CommandTag{}, err @@ -1426,7 +1479,8 @@ func (rr *ResultReader) NextRow() bool { } // FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until -// the ResultReader is closed. +// the ResultReader is closed. It may return nil (for example, if the query did not return a result set or an error was +// encountered.) func (rr *ResultReader) FieldDescriptions() []FieldDescription { return rr.fieldDescriptions } @@ -1592,7 +1646,9 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) + pgConn.enterPotentialWriteReadDeadlock() _, err := pgConn.conn.Write(batch.buf) + pgConn.exitPotentialWriteReadDeadlock() if err != nil { multiResult.closed = true multiResult.err = err @@ -1620,29 +1676,72 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) { return strings.Replace(s, "'", "''", -1), nil } -// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by reading and -// buffering until the read would block or an error occurs. This can be used to check if the server has closed the -// connection. If this is done immediately before sending a query it reduces the chances a query will be sent that fails +// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by doing a read +// with a very short deadline. This can be useful because a TCP connection can be broken such that a write will appear +// to succeed even though it will never actually reach the server. Reading immediately before a write will detect this +// condition. If this is done immediately before sending a query it reduces the chances a query will be sent that fails // without the client knowing whether the server received it or not. +// +// Deprecated: CheckConn is deprecated in favor of Ping. CheckConn cannot detect all types of broken connections where +// the write would still appear to succeed. Prefer Ping unless on a high latency connection. func (pgConn *PgConn) CheckConn() error { - err := pgConn.conn.BufferReadUntilBlock() - if err != nil && !errors.Is(err, nbconn.ErrWouldBlock) { - return err + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + _, err := pgConn.ReceiveMessage(ctx) + if err != nil { + if !Timeout(err) { + return err + } } + return nil } +// Ping pings the server. This can be useful because a TCP connection can be broken such that a write will appear to +// succeed even though it will never actually reach the server. Pinging immediately before sending a query reduces the +// chances a query will be sent that fails without the client knowing whether the server received it or not. +func (pgConn *PgConn) Ping(ctx context.Context) error { + return pgConn.Exec(ctx, "-- ping").Close() +} + // makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory. func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag { return CommandTag{s: string(buf)} } +// enterPotentialWriteReadDeadlock must be called before a write that could deadlock if the server is simultaneously +// blocked writing to us. +func (pgConn *PgConn) enterPotentialWriteReadDeadlock() { + // The time to wait is somewhat arbitrary. A Write should only take as long as the syscall and memcpy to the OS + // outbound network buffer unless the buffer is full (which potentially is a block). It needs to be long enough for + // the normal case, but short enough not to kill performance if a block occurs. + // + // In addition, on Windows the default timer resolution is 15.6ms. So setting the timer to less than that is + // ineffective. + pgConn.slowWriteTimer.Reset(15 * time.Millisecond) +} + +// exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock. +func (pgConn *PgConn) exitPotentialWriteReadDeadlock() { + if !pgConn.slowWriteTimer.Reset(time.Duration(math.MaxInt64)) { + pgConn.slowWriteTimer.Stop() + } +} + +func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error { + pgConn.enterPotentialWriteReadDeadlock() + err := pgConn.frontend.Flush() + pgConn.exitPotentialWriteReadDeadlock() + return err +} + // HijackedConn is the result of hijacking a connection. // // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // compatibility. type HijackedConn struct { - Conn nbconn.Conn // the non-blocking wrapper of the underlying TCP or unix domain socket connection + Conn net.Conn PID uint32 // backend pid SecretKey uint32 // key to use to send a cancel query message to the server ParameterStatuses map[string]string // parameters that have been reported by the server @@ -1695,6 +1794,8 @@ func Construct(hc *HijackedConn) (*PgConn, error) { } pgConn.contextWatcher = newContextWatcher(pgConn.conn) + pgConn.bgReader = bgreader.New(pgConn.conn) + pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start) return pgConn, nil } @@ -1817,7 +1918,7 @@ func (p *Pipeline) Flush() error { return errors.New("pipeline closed") } - err := p.conn.frontend.Flush() + err := p.conn.flushWithPotentialWriteReadDeadlock() if err != nil { err = normalizeTimeoutError(p.ctx, err) |
