summaryrefslogtreecommitdiff
path: root/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go')
-rw-r--r--vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go67
1 files changed, 47 insertions, 20 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go
index 1ccdc4db9..b287e0205 100644
--- a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go
+++ b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go
@@ -52,6 +52,12 @@ type LookupFunc func(ctx context.Context, host string) (addrs []string, err erro
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend
+// PgErrorHandler is a function that handles errors returned from Postgres. This function must return true to keep
+// the connection open. Returning false will cause the connection to be closed immediately. You should return
+// false on any FATAL-severity errors. This will not receive network errors. The *PgConn is provided so the handler is
+// aware of the origin of the error, but it must not invoke any query method.
+type PgErrorHandler func(*PgConn, *PgError) bool
+
// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY
@@ -146,11 +152,11 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
ctx := octx
fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs)
if err != nil {
- return nil, &connectError{config: config, msg: "hostname resolving error", err: err}
+ return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: err}
}
if len(fallbackConfigs) == 0 {
- return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
+ return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
}
foundBestServer := false
@@ -172,7 +178,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
foundBestServer = true
break
} else if pgerr, ok := err.(*PgError); ok {
- err = &connectError{config: config, msg: "server error", err: pgerr}
+ err = &ConnectError{Config: config, msg: "server error", err: pgerr}
const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
@@ -183,7 +189,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
break
}
- } else if cerr, ok := err.(*connectError); ok {
+ } else if cerr, ok := err.(*ConnectError); ok {
if _, ok := cerr.err.(*NotPreferredError); ok {
fallbackConfig = fc
}
@@ -193,7 +199,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
if !foundBestServer && fallbackConfig != nil {
pgConn, err = connect(ctx, config, fallbackConfig, true)
if pgerr, ok := err.(*PgError); ok {
- err = &connectError{config: config, msg: "server error", err: pgerr}
+ err = &ConnectError{Config: config, msg: "server error", err: pgerr}
}
}
@@ -205,7 +211,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
err := config.AfterConnect(ctx, pgConn)
if err != nil {
pgConn.conn.Close()
- return nil, &connectError{config: config, msg: "AfterConnect error", err: err}
+ return nil, &ConnectError{Config: config, msg: "AfterConnect error", err: err}
}
}
@@ -277,7 +283,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
netConn, err := config.DialFunc(ctx, network, address)
if err != nil {
- return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
+ return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
}
pgConn.conn = netConn
@@ -289,7 +295,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
if err != nil {
netConn.Close()
- return nil, &connectError{config: config, msg: "tls error", err: err}
+ return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
}
pgConn.conn = nbTLSConn
@@ -330,7 +336,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.frontend.Send(&startupMsg)
if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
pgConn.conn.Close()
- return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
+ return nil, &ConnectError{Config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
}
for {
@@ -340,7 +346,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if err, ok := err.(*PgError); ok {
return nil, err
}
- return nil, &connectError{config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
+ return nil, &ConnectError{Config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
}
switch msg := msg.(type) {
@@ -353,26 +359,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
err = pgConn.txPasswordMessage(pgConn.config.Password)
if err != nil {
pgConn.conn.Close()
- return nil, &connectError{config: config, msg: "failed to write password message", err: err}
+ return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
}
case *pgproto3.AuthenticationMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
err = pgConn.txPasswordMessage(digestedPassword)
if err != nil {
pgConn.conn.Close()
- return nil, &connectError{config: config, msg: "failed to write password message", err: err}
+ return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
}
case *pgproto3.AuthenticationSASL:
err = pgConn.scramAuth(msg.AuthMechanisms)
if err != nil {
pgConn.conn.Close()
- return nil, &connectError{config: config, msg: "failed SASL auth", err: err}
+ return nil, &ConnectError{Config: config, msg: "failed SASL auth", err: err}
}
case *pgproto3.AuthenticationGSS:
err = pgConn.gssAuth()
if err != nil {
pgConn.conn.Close()
- return nil, &connectError{config: config, msg: "failed GSS auth", err: err}
+ return nil, &ConnectError{Config: config, msg: "failed GSS auth", err: err}
}
case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle
@@ -390,7 +396,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return pgConn, nil
}
pgConn.conn.Close()
- return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
+ return nil, &ConnectError{Config: config, msg: "ValidateConnect failed", err: err}
}
}
return pgConn, nil
@@ -401,7 +407,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return nil, ErrorResponseToPgError(msg)
default:
pgConn.conn.Close()
- return nil, &connectError{config: config, msg: "received unexpected message", err: err}
+ return nil, &ConnectError{Config: config, msg: "received unexpected message", err: err}
}
}
}
@@ -547,11 +553,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
case *pgproto3.ParameterStatus:
pgConn.parameterStatuses[msg.Name] = msg.Value
case *pgproto3.ErrorResponse:
- if msg.Severity == "FATAL" {
+ err := ErrorResponseToPgError(msg)
+ if pgConn.config.OnPgError != nil && !pgConn.config.OnPgError(pgConn, err) {
pgConn.status = connStatusClosed
pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return.
close(pgConn.cleanupDone)
- return nil, ErrorResponseToPgError(msg)
+ return nil, err
}
case *pgproto3.NoticeResponse:
if pgConn.config.OnNotice != nil {
@@ -2046,6 +2053,13 @@ func (p *Pipeline) Flush() error {
// Sync establishes a synchronization point and flushes the queued requests.
func (p *Pipeline) Sync() error {
+ if p.closed {
+ if p.err != nil {
+ return p.err
+ }
+ return errors.New("pipeline closed")
+ }
+
p.conn.frontend.SendSync(&pgproto3.Sync{})
err := p.Flush()
if err != nil {
@@ -2062,10 +2076,21 @@ func (p *Pipeline) Sync() error {
// *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. If no
// results are available, results and err will both be nil.
func (p *Pipeline) GetResults() (results any, err error) {
+ if p.closed {
+ if p.err != nil {
+ return nil, p.err
+ }
+ return nil, errors.New("pipeline closed")
+ }
+
if p.expectedReadyForQueryCount == 0 {
return nil, nil
}
+ return p.getResults()
+}
+
+func (p *Pipeline) getResults() (results any, err error) {
for {
msg, err := p.conn.receiveMessage()
if err != nil {
@@ -2092,7 +2117,8 @@ func (p *Pipeline) GetResults() (results any, err error) {
case *pgproto3.ParseComplete:
peekedMsg, err := p.conn.peekMessage()
if err != nil {
- return nil, err
+ p.conn.asyncClose()
+ return nil, normalizeTimeoutError(p.ctx, err)
}
if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok {
return p.getResultsPrepare()
@@ -2152,6 +2178,7 @@ func (p *Pipeline) Close() error {
if p.closed {
return p.err
}
+
p.closed = true
if p.pendingSync {
@@ -2164,7 +2191,7 @@ func (p *Pipeline) Close() error {
}
for p.expectedReadyForQueryCount > 0 {
- _, err := p.GetResults()
+ _, err := p.getResults()
if err != nil {
p.err = err
var pgErr *PgError