diff options
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/pgconn')
-rw-r--r-- | vendor/github.com/jackc/pgx/v5/pgconn/config.go | 26 | ||||
-rw-r--r-- | vendor/github.com/jackc/pgx/v5/pgconn/errors.go | 30 | ||||
-rw-r--r-- | vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go | 67 |
3 files changed, 84 insertions, 39 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/config.go b/vendor/github.com/jackc/pgx/v5/pgconn/config.go index db0170e02..ddde89bd5 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/config.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/config.go @@ -60,6 +60,11 @@ type Config struct { // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. OnNotification NotificationHandler + // OnPgError is a callback function called when a Postgres error is received by the server. The default handler will close + // the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure + // that you close on FATAL errors by returning false. + OnPgError PgErrorHandler + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -232,12 +237,12 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { connStringSettings, err = parseURLSettings(connString) if err != nil { - return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} + return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err} } } else { connStringSettings, err = parseDSNSettings(connString) if err != nil { - return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} + return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as DSN", err: err} } } } @@ -246,7 +251,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con if service, present := settings["service"]; present { serviceSettings, err := parseServiceSettings(settings["servicefile"], service) if err != nil { - return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err} + return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err} } settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings) @@ -261,12 +266,19 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend { return pgproto3.NewFrontend(r, w) }, + OnPgError: func(_ *PgConn, pgErr *PgError) bool { + // we want to automatically close any fatal errors + if strings.EqualFold(pgErr.Severity, "FATAL") { + return false + } + return true + }, } if connectTimeoutSetting, present := settings["connect_timeout"]; present { connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting) if err != nil { - return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} + return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err} } config.ConnectTimeout = connectTimeout config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout) @@ -328,7 +340,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con port, err := parsePort(portStr) if err != nil { - return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err} + return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err} } var tlsConfigs []*tls.Config @@ -340,7 +352,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con var err error tlsConfigs, err = configTLS(settings, host, options) if err != nil { - return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} + return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err} } } @@ -384,7 +396,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con case "any": // do nothing default: - return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} + return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} } return config, nil diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/errors.go b/vendor/github.com/jackc/pgx/v5/pgconn/errors.go index 3c54bbec0..c315739a9 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/errors.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/errors.go @@ -57,22 +57,23 @@ func (pe *PgError) SQLState() string { return pe.Code } -type connectError struct { - config *Config +// ConnectError is the error returned when a connection attempt fails. +type ConnectError struct { + Config *Config // The configuration that was used in the connection attempt. msg string err error } -func (e *connectError) Error() string { +func (e *ConnectError) Error() string { sb := &strings.Builder{} - fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg) + fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.Config.Host, e.Config.User, e.Config.Database, e.msg) if e.err != nil { fmt.Fprintf(sb, " (%s)", e.err.Error()) } return sb.String() } -func (e *connectError) Unwrap() error { +func (e *ConnectError) Unwrap() error { return e.err } @@ -88,33 +89,38 @@ func (e *connLockError) Error() string { return e.status } -type parseConfigError struct { - connString string +// ParseConfigError is the error returned when a connection string cannot be parsed. +type ParseConfigError struct { + ConnString string // The connection string that could not be parsed. msg string err error } -func (e *parseConfigError) Error() string { - connString := redactPW(e.connString) +func (e *ParseConfigError) Error() string { + // Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better only + // return a static string. That would ensure that the error message cannot leak a password. The ConnString field would + // allow access to the original string if desired and Unwrap would allow access to the underlying error. + connString := redactPW(e.ConnString) if e.err == nil { return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg) } return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error()) } -func (e *parseConfigError) Unwrap() error { +func (e *ParseConfigError) Unwrap() error { return e.err } func normalizeTimeoutError(ctx context.Context, err error) error { - if err, ok := err.(net.Error); ok && err.Timeout() { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { if ctx.Err() == context.Canceled { // Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error. return context.Canceled } else if ctx.Err() == context.DeadlineExceeded { return &errTimeout{err: ctx.Err()} } else { - return &errTimeout{err: err} + return &errTimeout{err: netErr} } } return err diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go index 1ccdc4db9..b287e0205 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go @@ -52,6 +52,12 @@ type LookupFunc func(ctx context.Context, host string) (addrs []string, err erro // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend +// PgErrorHandler is a function that handles errors returned from Postgres. This function must return true to keep +// the connection open. Returning false will cause the connection to be closed immediately. You should return +// false on any FATAL-severity errors. This will not receive network errors. The *PgConn is provided so the handler is +// aware of the origin of the error, but it must not invoke any query method. +type PgErrorHandler func(*PgConn, *PgError) bool + // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin // of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY @@ -146,11 +152,11 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er ctx := octx fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) if err != nil { - return nil, &connectError{config: config, msg: "hostname resolving error", err: err} + return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: err} } if len(fallbackConfigs) == 0 { - return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} + return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} } foundBestServer := false @@ -172,7 +178,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er foundBestServer = true break } else if pgerr, ok := err.(*PgError); ok { - err = &connectError{config: config, msg: "server error", err: pgerr} + err = &ConnectError{Config: config, msg: "server error", err: pgerr} const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist @@ -183,7 +189,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { break } - } else if cerr, ok := err.(*connectError); ok { + } else if cerr, ok := err.(*ConnectError); ok { if _, ok := cerr.err.(*NotPreferredError); ok { fallbackConfig = fc } @@ -193,7 +199,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er if !foundBestServer && fallbackConfig != nil { pgConn, err = connect(ctx, config, fallbackConfig, true) if pgerr, ok := err.(*PgError); ok { - err = &connectError{config: config, msg: "server error", err: pgerr} + err = &ConnectError{Config: config, msg: "server error", err: pgerr} } } @@ -205,7 +211,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er err := config.AfterConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "AfterConnect error", err: err} + return nil, &ConnectError{Config: config, msg: "AfterConnect error", err: err} } } @@ -277,7 +283,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) netConn, err := config.DialFunc(ctx, network, address) if err != nil { - return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} + return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} } pgConn.conn = netConn @@ -289,7 +295,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { netConn.Close() - return nil, &connectError{config: config, msg: "tls error", err: err} + return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)} } pgConn.conn = nbTLSConn @@ -330,7 +336,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.frontend.Send(&startupMsg) if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)} + return nil, &ConnectError{Config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)} } for { @@ -340,7 +346,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err, ok := err.(*PgError); ok { return nil, err } - return nil, &connectError{config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)} + return nil, &ConnectError{Config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)} } switch msg := msg.(type) { @@ -353,26 +359,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig err = pgConn.txPasswordMessage(pgConn.config.Password) if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed to write password message", err: err} + return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err} } case *pgproto3.AuthenticationMD5Password: digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) err = pgConn.txPasswordMessage(digestedPassword) if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed to write password message", err: err} + return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err} } case *pgproto3.AuthenticationSASL: err = pgConn.scramAuth(msg.AuthMechanisms) if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed SASL auth", err: err} + return nil, &ConnectError{Config: config, msg: "failed SASL auth", err: err} } case *pgproto3.AuthenticationGSS: err = pgConn.gssAuth() if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed GSS auth", err: err} + return nil, &ConnectError{Config: config, msg: "failed GSS auth", err: err} } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle @@ -390,7 +396,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return pgConn, nil } pgConn.conn.Close() - return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} + return nil, &ConnectError{Config: config, msg: "ValidateConnect failed", err: err} } } return pgConn, nil @@ -401,7 +407,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, ErrorResponseToPgError(msg) default: pgConn.conn.Close() - return nil, &connectError{config: config, msg: "received unexpected message", err: err} + return nil, &ConnectError{Config: config, msg: "received unexpected message", err: err} } } } @@ -547,11 +553,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { case *pgproto3.ParameterStatus: pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: - if msg.Severity == "FATAL" { + err := ErrorResponseToPgError(msg) + if pgConn.config.OnPgError != nil && !pgConn.config.OnPgError(pgConn, err) { pgConn.status = connStatusClosed pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. close(pgConn.cleanupDone) - return nil, ErrorResponseToPgError(msg) + return nil, err } case *pgproto3.NoticeResponse: if pgConn.config.OnNotice != nil { @@ -2046,6 +2053,13 @@ func (p *Pipeline) Flush() error { // Sync establishes a synchronization point and flushes the queued requests. func (p *Pipeline) Sync() error { + if p.closed { + if p.err != nil { + return p.err + } + return errors.New("pipeline closed") + } + p.conn.frontend.SendSync(&pgproto3.Sync{}) err := p.Flush() if err != nil { @@ -2062,10 +2076,21 @@ func (p *Pipeline) Sync() error { // *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. If no // results are available, results and err will both be nil. func (p *Pipeline) GetResults() (results any, err error) { + if p.closed { + if p.err != nil { + return nil, p.err + } + return nil, errors.New("pipeline closed") + } + if p.expectedReadyForQueryCount == 0 { return nil, nil } + return p.getResults() +} + +func (p *Pipeline) getResults() (results any, err error) { for { msg, err := p.conn.receiveMessage() if err != nil { @@ -2092,7 +2117,8 @@ func (p *Pipeline) GetResults() (results any, err error) { case *pgproto3.ParseComplete: peekedMsg, err := p.conn.peekMessage() if err != nil { - return nil, err + p.conn.asyncClose() + return nil, normalizeTimeoutError(p.ctx, err) } if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok { return p.getResultsPrepare() @@ -2152,6 +2178,7 @@ func (p *Pipeline) Close() error { if p.closed { return p.err } + p.closed = true if p.pendingSync { @@ -2164,7 +2191,7 @@ func (p *Pipeline) Close() error { } for p.expectedReadyForQueryCount > 0 { - _, err := p.GetResults() + _, err := p.getResults() if err != nil { p.err = err var pgErr *PgError |