diff options
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go')
-rw-r--r-- | vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go | 67 |
1 files changed, 47 insertions, 20 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go index 1ccdc4db9..b287e0205 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go @@ -52,6 +52,12 @@ type LookupFunc func(ctx context.Context, host string) (addrs []string, err erro // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend +// PgErrorHandler is a function that handles errors returned from Postgres. This function must return true to keep +// the connection open. Returning false will cause the connection to be closed immediately. You should return +// false on any FATAL-severity errors. This will not receive network errors. The *PgConn is provided so the handler is +// aware of the origin of the error, but it must not invoke any query method. +type PgErrorHandler func(*PgConn, *PgError) bool + // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin // of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY @@ -146,11 +152,11 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er ctx := octx fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) if err != nil { - return nil, &connectError{config: config, msg: "hostname resolving error", err: err} + return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: err} } if len(fallbackConfigs) == 0 { - return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} + return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} } foundBestServer := false @@ -172,7 +178,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er foundBestServer = true break } else if pgerr, ok := err.(*PgError); ok { - err = &connectError{config: config, msg: "server error", err: pgerr} + err = &ConnectError{Config: config, msg: "server error", err: pgerr} const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist @@ -183,7 +189,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { break } - } else if cerr, ok := err.(*connectError); ok { + } else if cerr, ok := err.(*ConnectError); ok { if _, ok := cerr.err.(*NotPreferredError); ok { fallbackConfig = fc } @@ -193,7 +199,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er if !foundBestServer && fallbackConfig != nil { pgConn, err = connect(ctx, config, fallbackConfig, true) if pgerr, ok := err.(*PgError); ok { - err = &connectError{config: config, msg: "server error", err: pgerr} + err = &ConnectError{Config: config, msg: "server error", err: pgerr} } } @@ -205,7 +211,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er err := config.AfterConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "AfterConnect error", err: err} + return nil, &ConnectError{Config: config, msg: "AfterConnect error", err: err} } } @@ -277,7 +283,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) netConn, err := config.DialFunc(ctx, network, address) if err != nil { - return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} + return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} } pgConn.conn = netConn @@ -289,7 +295,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { netConn.Close() - return nil, &connectError{config: config, msg: "tls error", err: err} + return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)} } pgConn.conn = nbTLSConn @@ -330,7 +336,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.frontend.Send(&startupMsg) if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)} + return nil, &ConnectError{Config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)} } for { @@ -340,7 +346,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err, ok := err.(*PgError); ok { return nil, err } - return nil, &connectError{config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)} + return nil, &ConnectError{Config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)} } switch msg := msg.(type) { @@ -353,26 +359,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig err = pgConn.txPasswordMessage(pgConn.config.Password) if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed to write password message", err: err} + return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err} } case *pgproto3.AuthenticationMD5Password: digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) err = pgConn.txPasswordMessage(digestedPassword) if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed to write password message", err: err} + return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err} } case *pgproto3.AuthenticationSASL: err = pgConn.scramAuth(msg.AuthMechanisms) if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed SASL auth", err: err} + return nil, &ConnectError{Config: config, msg: "failed SASL auth", err: err} } case *pgproto3.AuthenticationGSS: err = pgConn.gssAuth() if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed GSS auth", err: err} + return nil, &ConnectError{Config: config, msg: "failed GSS auth", err: err} } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle @@ -390,7 +396,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return pgConn, nil } pgConn.conn.Close() - return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} + return nil, &ConnectError{Config: config, msg: "ValidateConnect failed", err: err} } } return pgConn, nil @@ -401,7 +407,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, ErrorResponseToPgError(msg) default: pgConn.conn.Close() - return nil, &connectError{config: config, msg: "received unexpected message", err: err} + return nil, &ConnectError{Config: config, msg: "received unexpected message", err: err} } } } @@ -547,11 +553,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { case *pgproto3.ParameterStatus: pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: - if msg.Severity == "FATAL" { + err := ErrorResponseToPgError(msg) + if pgConn.config.OnPgError != nil && !pgConn.config.OnPgError(pgConn, err) { pgConn.status = connStatusClosed pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. close(pgConn.cleanupDone) - return nil, ErrorResponseToPgError(msg) + return nil, err } case *pgproto3.NoticeResponse: if pgConn.config.OnNotice != nil { @@ -2046,6 +2053,13 @@ func (p *Pipeline) Flush() error { // Sync establishes a synchronization point and flushes the queued requests. func (p *Pipeline) Sync() error { + if p.closed { + if p.err != nil { + return p.err + } + return errors.New("pipeline closed") + } + p.conn.frontend.SendSync(&pgproto3.Sync{}) err := p.Flush() if err != nil { @@ -2062,10 +2076,21 @@ func (p *Pipeline) Sync() error { // *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. If no // results are available, results and err will both be nil. func (p *Pipeline) GetResults() (results any, err error) { + if p.closed { + if p.err != nil { + return nil, p.err + } + return nil, errors.New("pipeline closed") + } + if p.expectedReadyForQueryCount == 0 { return nil, nil } + return p.getResults() +} + +func (p *Pipeline) getResults() (results any, err error) { for { msg, err := p.conn.receiveMessage() if err != nil { @@ -2092,7 +2117,8 @@ func (p *Pipeline) GetResults() (results any, err error) { case *pgproto3.ParseComplete: peekedMsg, err := p.conn.peekMessage() if err != nil { - return nil, err + p.conn.asyncClose() + return nil, normalizeTimeoutError(p.ctx, err) } if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok { return p.getResultsPrepare() @@ -2152,6 +2178,7 @@ func (p *Pipeline) Close() error { if p.closed { return p.err } + p.closed = true if p.pendingSync { @@ -2164,7 +2191,7 @@ func (p *Pipeline) Close() error { } for p.expectedReadyForQueryCount > 0 { - _, err := p.GetResults() + _, err := p.getResults() if err != nil { p.err = err var pgErr *PgError |