diff options
Diffstat (limited to 'vendor/github.com/jackc/pgx/v5/pgconn')
-rw-r--r-- | vendor/github.com/jackc/pgx/v5/pgconn/config.go | 46 | ||||
-rw-r--r-- | vendor/github.com/jackc/pgx/v5/pgconn/ctxwatch/context_watcher.go (renamed from vendor/github.com/jackc/pgx/v5/pgconn/internal/ctxwatch/context_watcher.go) | 25 | ||||
-rw-r--r-- | vendor/github.com/jackc/pgx/v5/pgconn/doc.go | 16 | ||||
-rw-r--r-- | vendor/github.com/jackc/pgx/v5/pgconn/errors.go | 76 | ||||
-rw-r--r-- | vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go | 399 |
5 files changed, 353 insertions, 209 deletions
diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/config.go b/vendor/github.com/jackc/pgx/v5/pgconn/config.go index 33a722579..598917f55 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/config.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/config.go @@ -19,6 +19,7 @@ import ( "github.com/jackc/pgpassfile" "github.com/jackc/pgservicefile" + "github.com/jackc/pgx/v5/pgconn/ctxwatch" "github.com/jackc/pgx/v5/pgproto3" ) @@ -39,7 +40,12 @@ type Config struct { DialFunc DialFunc // e.g. net.Dialer.DialContext LookupFunc LookupFunc // e.g. net.Resolver.LookupHost BuildFrontend BuildFrontendFunc - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + + // BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called + // when a context passed to a PgConn method is canceled. + BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler + + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) KerberosSrvName string KerberosSpn string @@ -70,7 +76,7 @@ type Config struct { // ParseConfigOptions contains options that control how a config is built such as GetSSLPassword. type ParseConfigOptions struct { - // GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function + // GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the libpq function // PQsetSSLKeyPassHook_OpenSSL. GetSSLPassword GetSSLPasswordFunc } @@ -112,6 +118,14 @@ type FallbackConfig struct { TLSConfig *tls.Config // nil disables TLS } +// connectOneConfig is the configuration for a single attempt to connect to a single host. +type connectOneConfig struct { + network string + address string + originalHostname string // original hostname before resolving + tlsConfig *tls.Config // nil disables TLS +} + // isAbsolutePath checks if the provided value is an absolute path either // beginning with a forward slash (as on Linux-based systems) or with a capital // letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows). @@ -146,11 +160,11 @@ func NetworkAddress(host string, port uint16) (network, address string) { // ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It // uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely -// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). -// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be -// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. +// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format. See +// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be empty +// to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. // -// # Example DSN +// # Example Keyword/Value // user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca // // # Example URL @@ -169,7 +183,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb // // ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed -// via database URL or DSN: +// via database URL or keyword/value: // // PGHOST // PGPORT @@ -233,16 +247,16 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con connStringSettings := make(map[string]string) if connString != "" { var err error - // connString may be a database URL or a DSN + // connString may be a database URL or in PostgreSQL keyword/value format if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { connStringSettings, err = parseURLSettings(connString) if err != nil { return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err} } } else { - connStringSettings, err = parseDSNSettings(connString) + connStringSettings, err = parseKeywordValueSettings(connString) if err != nil { - return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as DSN", err: err} + return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as keyword/value", err: err} } } } @@ -266,6 +280,9 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend { return pgproto3.NewFrontend(r, w) }, + BuildContextWatcherHandler: func(pgConn *PgConn) ctxwatch.Handler { + return &DeadlineContextWatcherHandler{Conn: pgConn.conn} + }, OnPgError: func(_ *PgConn, pgErr *PgError) bool { // we want to automatically close any fatal errors if strings.EqualFold(pgErr.Severity, "FATAL") { @@ -517,7 +534,7 @@ func isIPOnly(host string) bool { var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} -func parseDSNSettings(s string) (map[string]string, error) { +func parseKeywordValueSettings(s string) (map[string]string, error) { settings := make(map[string]string) nameMap := map[string]string{ @@ -528,7 +545,7 @@ func parseDSNSettings(s string) (map[string]string, error) { var key, val string eqIdx := strings.IndexRune(s, '=') if eqIdx < 0 { - return nil, errors.New("invalid dsn") + return nil, errors.New("invalid keyword/value") } key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") @@ -580,7 +597,7 @@ func parseDSNSettings(s string) (map[string]string, error) { } if key == "" { - return nil, errors.New("invalid dsn") + return nil, errors.New("invalid keyword/value") } settings[key] = val @@ -800,7 +817,8 @@ func parsePort(s string) (uint16, error) { } func makeDefaultDialer() *net.Dialer { - return &net.Dialer{KeepAlive: 5 * time.Minute} + // rely on GOLANG KeepAlive settings + return &net.Dialer{} } func makeDefaultResolver() *net.Resolver { diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/internal/ctxwatch/context_watcher.go b/vendor/github.com/jackc/pgx/v5/pgconn/ctxwatch/context_watcher.go index b39cb3ee5..db8884eb8 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/internal/ctxwatch/context_watcher.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/ctxwatch/context_watcher.go @@ -8,9 +8,8 @@ import ( // ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a // time. type ContextWatcher struct { - onCancel func() - onUnwatchAfterCancel func() - unwatchChan chan struct{} + handler Handler + unwatchChan chan struct{} lock sync.Mutex watchInProgress bool @@ -20,11 +19,10 @@ type ContextWatcher struct { // NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. // OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and // onCancel called. -func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher { +func NewContextWatcher(handler Handler) *ContextWatcher { cw := &ContextWatcher{ - onCancel: onCancel, - onUnwatchAfterCancel: onUnwatchAfterCancel, - unwatchChan: make(chan struct{}), + handler: handler, + unwatchChan: make(chan struct{}), } return cw @@ -46,7 +44,7 @@ func (cw *ContextWatcher) Watch(ctx context.Context) { go func() { select { case <-ctx.Done(): - cw.onCancel() + cw.handler.HandleCancel(ctx) cw.onCancelWasCalled = true <-cw.unwatchChan case <-cw.unwatchChan: @@ -66,8 +64,17 @@ func (cw *ContextWatcher) Unwatch() { if cw.watchInProgress { cw.unwatchChan <- struct{}{} if cw.onCancelWasCalled { - cw.onUnwatchAfterCancel() + cw.handler.HandleUnwatchAfterCancel() } cw.watchInProgress = false } } + +type Handler interface { + // HandleCancel is called when the context that a ContextWatcher is currently watching is canceled. canceledCtx is the + // context that was canceled. + HandleCancel(canceledCtx context.Context) + + // HandleUnwatchAfterCancel is called when a ContextWatcher that called HandleCancel on this Handler is unwatched. + HandleUnwatchAfterCancel() +} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/doc.go b/vendor/github.com/jackc/pgx/v5/pgconn/doc.go index e3242cf4e..701375019 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/doc.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/doc.go @@ -5,8 +5,8 @@ nearly the same level is the C library libpq. Establishing a Connection -Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for -libpq style environment variables. +Use Connect to establish a connection. It accepts a connection string in URL or keyword/value format and will read the +environment for libpq style environment variables. Executing a Query @@ -20,13 +20,17 @@ result. The ReadAll method reads all query results into memory. Pipeline Mode -Pipeline mode allows sending queries without having read the results of previously sent queries. It allows -control of exactly how many and when network round trips occur. +Pipeline mode allows sending queries without having read the results of previously sent queries. It allows control of +exactly how many and when network round trips occur. Context Support -All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the -method immediately returns. In most circumstances, this will close the underlying connection. +All potentially blocking operations take a context.Context. The default behavior when a context is canceled is for the +method to immediately return. In most circumstances, this will also close the underlying connection. This behavior can +be customized by using BuildContextWatcherHandler on the Config to create a ctxwatch.Handler with different behavior. +This can be especially useful when queries that are frequently canceled and the overhead of creating new connections is +a problem. DeadlineContextWatcherHandler and CancelRequestContextWatcherHandler can be used to introduce a delay before +interrupting the query in such a way as to close the connection. The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the client to abort. diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/errors.go b/vendor/github.com/jackc/pgx/v5/pgconn/errors.go index c315739a9..ec4a6d47c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/errors.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/errors.go @@ -12,13 +12,14 @@ import ( // SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server. func SafeToRetry(err error) bool { - if e, ok := err.(interface{ SafeToRetry() bool }); ok { - return e.SafeToRetry() + var retryableErr interface{ SafeToRetry() bool } + if errors.As(err, &retryableErr) { + return retryableErr.SafeToRetry() } return false } -// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a +// Timeout checks if err was caused by a timeout. To be specific, it is true if err was caused within pgconn by a // context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. func Timeout(err error) bool { var timeoutErr *errTimeout @@ -29,23 +30,24 @@ func Timeout(err error) bool { // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. type PgError struct { - Severity string - Code string - Message string - Detail string - Hint string - Position int32 - InternalPosition int32 - InternalQuery string - Where string - SchemaName string - TableName string - ColumnName string - DataTypeName string - ConstraintName string - File string - Line int32 - Routine string + Severity string + SeverityUnlocalized string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string } func (pe *PgError) Error() string { @@ -60,23 +62,37 @@ func (pe *PgError) SQLState() string { // ConnectError is the error returned when a connection attempt fails. type ConnectError struct { Config *Config // The configuration that was used in the connection attempt. - msg string err error } func (e *ConnectError) Error() string { - sb := &strings.Builder{} - fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.Config.Host, e.Config.User, e.Config.Database, e.msg) - if e.err != nil { - fmt.Fprintf(sb, " (%s)", e.err.Error()) + prefix := fmt.Sprintf("failed to connect to `user=%s database=%s`:", e.Config.User, e.Config.Database) + details := e.err.Error() + if strings.Contains(details, "\n") { + return prefix + "\n\t" + strings.ReplaceAll(details, "\n", "\n\t") + } else { + return prefix + " " + details } - return sb.String() } func (e *ConnectError) Unwrap() error { return e.err } +type perDialConnectError struct { + address string + originalHostname string + err error +} + +func (e *perDialConnectError) Error() string { + return fmt.Sprintf("%s (%s): %s", e.address, e.originalHostname, e.err.Error()) +} + +func (e *perDialConnectError) Unwrap() error { + return e.err +} + type connLockError struct { status string } @@ -195,10 +211,10 @@ func redactPW(connString string) string { return redactURL(u) } } - quotedDSN := regexp.MustCompile(`password='[^']*'`) - connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx") - plainDSN := regexp.MustCompile(`password=[^ ]*`) - connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx") + quotedKV := regexp.MustCompile(`password='[^']*'`) + connString = quotedKV.ReplaceAllLiteralString(connString, "password=xxxxx") + plainKV := regexp.MustCompile(`password=[^ ]*`) + connString = plainKV.ReplaceAllLiteralString(connString, "password=xxxxx") brokenURL := regexp.MustCompile(`:[^:@]+?@`) connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@") return connString diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go index 0bf03f335..7efb522a4 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go @@ -18,8 +18,8 @@ import ( "github.com/jackc/pgx/v5/internal/iobufpool" "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgconn/ctxwatch" "github.com/jackc/pgx/v5/pgconn/internal/bgreader" - "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" "github.com/jackc/pgx/v5/pgproto3" ) @@ -82,6 +82,8 @@ type PgConn struct { slowWriteTimer *time.Timer bgReaderStarted chan struct{} + customData map[string]any + config *Config status byte // One of connStatus* constants @@ -103,8 +105,9 @@ type PgConn struct { cleanupDone chan struct{} } -// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) -// to provide configuration. See documentation for [ParseConfig] for details. ctx can be used to cancel a connect attempt. +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or keyword/value +// format) to provide configuration. See documentation for [ParseConfig] for details. ctx can be used to cancel a +// connect attempt. func Connect(ctx context.Context, connString string) (*PgConn, error) { config, err := ParseConfig(connString) if err != nil { @@ -114,9 +117,9 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) { return ConnectConfig(ctx, config) } -// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) -// and ParseConfigOptions to provide additional configuration. See documentation for [ParseConfig] for details. ctx can be -// used to cancel a connect attempt. +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or keyword/value +// format) and ParseConfigOptions to provide additional configuration. See documentation for [ParseConfig] for details. +// ctx can be used to cancel a connect attempt. func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) { config, err := ParseConfigWithOptions(connString, parseConfigOptions) if err != nil { @@ -131,113 +134,77 @@ func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptio // // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An // authentication error will terminate the chain of attempts (like libpq: -// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, -// if all attempts fail the last error is returned. -func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err error) { +// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. +func ConnectConfig(ctx context.Context, config *Config) (*PgConn, error) { // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from // zero values. if !config.createdByParseConfig { panic("config must be created by ParseConfig") } - // Simplify usage by treating primary config and fallbacks the same. - fallbackConfigs := []*FallbackConfig{ - { - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }, - } - fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) - ctx := octx - fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) - if err != nil { - return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: err} - } + var allErrors []error - if len(fallbackConfigs) == 0 { - return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} - } - - foundBestServer := false - var fallbackConfig *FallbackConfig - for i, fc := range fallbackConfigs { - // ConnectTimeout restricts the whole connection process. - if config.ConnectTimeout != 0 { - // create new context first time or when previous host was different - if i == 0 || (fallbackConfigs[i].Host != fallbackConfigs[i-1].Host) { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) - defer cancel() - } - } else { - ctx = octx - } - pgConn, err = connect(ctx, config, fc, false) - if err == nil { - foundBestServer = true - break - } else if pgerr, ok := err.(*PgError); ok { - err = &ConnectError{Config: config, msg: "server error", err: pgerr} - const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password - const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings - const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist - const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege - if pgerr.Code == ERRCODE_INVALID_PASSWORD || - pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && fc.TLSConfig != nil || - pgerr.Code == ERRCODE_INVALID_CATALOG_NAME || - pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { - break - } - } else if cerr, ok := err.(*ConnectError); ok { - if _, ok := cerr.err.(*NotPreferredError); ok { - fallbackConfig = fc - } - } + connectConfigs, errs := buildConnectOneConfigs(ctx, config) + if len(errs) > 0 { + allErrors = append(allErrors, errs...) } - if !foundBestServer && fallbackConfig != nil { - pgConn, err = connect(ctx, config, fallbackConfig, true) - if pgerr, ok := err.(*PgError); ok { - err = &ConnectError{Config: config, msg: "server error", err: pgerr} - } + if len(connectConfigs) == 0 { + return nil, &ConnectError{Config: config, err: fmt.Errorf("hostname resolving error: %w", errors.Join(allErrors...))} } - if err != nil { - return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError + pgConn, errs := connectPreferred(ctx, config, connectConfigs) + if len(errs) > 0 { + allErrors = append(allErrors, errs...) + return nil, &ConnectError{Config: config, err: errors.Join(allErrors...)} } if config.AfterConnect != nil { err := config.AfterConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "AfterConnect error", err: err} + return nil, &ConnectError{Config: config, err: fmt.Errorf("AfterConnect error: %w", err)} } } return pgConn, nil } -func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) { - var configs []*FallbackConfig +// buildConnectOneConfigs resolves hostnames and builds a list of connectOneConfigs to try connecting to. It returns a +// slice of successfully resolved connectOneConfigs and a slice of errors. It is possible for both slices to contain +// values if some hosts were successfully resolved and others were not. +func buildConnectOneConfigs(ctx context.Context, config *Config) ([]*connectOneConfig, []error) { + // Simplify usage by treating primary config and fallbacks the same. + fallbackConfigs := []*FallbackConfig{ + { + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + } + fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) + + var configs []*connectOneConfig - var lookupErrors []error + var allErrors []error - for _, fb := range fallbacks { + for _, fb := range fallbackConfigs { // skip resolve for unix sockets if isAbsolutePath(fb.Host) { - configs = append(configs, &FallbackConfig{ - Host: fb.Host, - Port: fb.Port, - TLSConfig: fb.TLSConfig, + network, address := NetworkAddress(fb.Host, fb.Port) + configs = append(configs, &connectOneConfig{ + network: network, + address: address, + originalHostname: fb.Host, + tlsConfig: fb.TLSConfig, }) continue } - ips, err := lookupFn(ctx, fb.Host) + ips, err := config.LookupFunc(ctx, fb.Host) if err != nil { - lookupErrors = append(lookupErrors, err) + allErrors = append(allErrors, err) continue } @@ -246,63 +213,126 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba if err == nil { port, err := strconv.ParseUint(splitPort, 10, 16) if err != nil { - return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err) + return nil, []error{fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err)} } - configs = append(configs, &FallbackConfig{ - Host: splitIP, - Port: uint16(port), - TLSConfig: fb.TLSConfig, + network, address := NetworkAddress(splitIP, uint16(port)) + configs = append(configs, &connectOneConfig{ + network: network, + address: address, + originalHostname: fb.Host, + tlsConfig: fb.TLSConfig, }) } else { - configs = append(configs, &FallbackConfig{ - Host: ip, - Port: fb.Port, - TLSConfig: fb.TLSConfig, + network, address := NetworkAddress(ip, fb.Port) + configs = append(configs, &connectOneConfig{ + network: network, + address: address, + originalHostname: fb.Host, + tlsConfig: fb.TLSConfig, }) } } } - // See https://github.com/jackc/pgx/issues/1464. When Go 1.20 can be used in pgx consider using errors.Join so all - // errors are reported. - if len(configs) == 0 && len(lookupErrors) > 0 { - return nil, lookupErrors[0] + return configs, allErrors +} + +// connectPreferred attempts to connect to the preferred host from connectOneConfigs. The connections are attempted in +// order. If a connection is successful it is returned. If no connection is successful then all errors are returned. If +// a connection attempt returns a [NotPreferredError], then that host will be used if no other hosts are successful. +func connectPreferred(ctx context.Context, config *Config, connectOneConfigs []*connectOneConfig) (*PgConn, []error) { + octx := ctx + var allErrors []error + + var fallbackConnectOneConfig *connectOneConfig + for i, c := range connectOneConfigs { + // ConnectTimeout restricts the whole connection process. + if config.ConnectTimeout != 0 { + // create new context first time or when previous host was different + if i == 0 || (connectOneConfigs[i].address != connectOneConfigs[i-1].address) { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) + defer cancel() + } + } else { + ctx = octx + } + + pgConn, err := connectOne(ctx, config, c, false) + if pgConn != nil { + return pgConn, nil + } + + allErrors = append(allErrors, err) + + var pgErr *PgError + if errors.As(err, &pgErr) { + const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password + const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings + const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist + const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege + if pgErr.Code == ERRCODE_INVALID_PASSWORD || + pgErr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && c.tlsConfig != nil || + pgErr.Code == ERRCODE_INVALID_CATALOG_NAME || + pgErr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { + return nil, allErrors + } + } + + var npErr *NotPreferredError + if errors.As(err, &npErr) { + fallbackConnectOneConfig = c + } + } + + if fallbackConnectOneConfig != nil { + pgConn, err := connectOne(ctx, config, fallbackConnectOneConfig, true) + if err == nil { + return pgConn, nil + } + allErrors = append(allErrors, err) } - return configs, nil + return nil, allErrors } -func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig, +// connectOne makes one connection attempt to a single host. +func connectOne(ctx context.Context, config *Config, connectConfig *connectOneConfig, ignoreNotPreferredErr bool, ) (*PgConn, error) { pgConn := new(PgConn) pgConn.config = config pgConn.cleanupDone = make(chan struct{}) + pgConn.customData = make(map[string]any) var err error - network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) - netConn, err := config.DialFunc(ctx, network, address) - if err != nil { - return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} + + newPerDialConnectError := func(msg string, err error) *perDialConnectError { + err = normalizeTimeoutError(ctx, err) + e := &perDialConnectError{address: connectConfig.address, originalHostname: connectConfig.originalHostname, err: fmt.Errorf("%s: %w", msg, err)} + return e } - pgConn.conn = netConn - pgConn.contextWatcher = newContextWatcher(netConn) - pgConn.contextWatcher.Watch(ctx) + pgConn.conn, err = config.DialFunc(ctx, connectConfig.network, connectConfig.address) + if err != nil { + return nil, newPerDialConnectError("dial error", err) + } - if fallbackConfig.TLSConfig != nil { - nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig) + if connectConfig.tlsConfig != nil { + pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn}) + pgConn.contextWatcher.Watch(ctx) + tlsConn, err := startTLS(pgConn.conn, connectConfig.tlsConfig) pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { - netConn.Close() - return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)} + pgConn.conn.Close() + return nil, newPerDialConnectError("tls error", err) } - pgConn.conn = nbTLSConn - pgConn.contextWatcher = newContextWatcher(nbTLSConn) - pgConn.contextWatcher.Watch(ctx) + pgConn.conn = tlsConn } + pgConn.contextWatcher = ctxwatch.NewContextWatcher(config.BuildContextWatcherHandler(pgConn)) + pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() pgConn.parameterStatuses = make(map[string]string) @@ -336,7 +366,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.frontend.Send(&startupMsg) if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)} + return nil, newPerDialConnectError("failed to write startup message", err) } for { @@ -344,9 +374,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err != nil { pgConn.conn.Close() if err, ok := err.(*PgError); ok { - return nil, err + return nil, newPerDialConnectError("server error", err) } - return nil, &ConnectError{Config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)} + return nil, newPerDialConnectError("failed to receive message", err) } switch msg := msg.(type) { @@ -359,26 +389,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig err = pgConn.txPasswordMessage(pgConn.config.Password) if err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err} + return nil, newPerDialConnectError("failed to write password message", err) } case *pgproto3.AuthenticationMD5Password: digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) err = pgConn.txPasswordMessage(digestedPassword) if err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err} + return nil, newPerDialConnectError("failed to write password message", err) } case *pgproto3.AuthenticationSASL: err = pgConn.scramAuth(msg.AuthMechanisms) if err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "failed SASL auth", err: err} + return nil, newPerDialConnectError("failed SASL auth", err) } case *pgproto3.AuthenticationGSS: err = pgConn.gssAuth() if err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "failed GSS auth", err: err} + return nil, newPerDialConnectError("failed GSS auth", err) } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle @@ -396,7 +426,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return pgConn, nil } pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "ValidateConnect failed", err: err} + return nil, newPerDialConnectError("ValidateConnect failed", err) } } return pgConn, nil @@ -404,21 +434,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig // handled by ReceiveMessage case *pgproto3.ErrorResponse: pgConn.conn.Close() - return nil, ErrorResponseToPgError(msg) + return nil, newPerDialConnectError("server error", ErrorResponseToPgError(msg)) default: pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "received unexpected message", err: err} + return nil, newPerDialConnectError("received unexpected message", err) } } } -func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { - return ctxwatch.NewContextWatcher( - func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { conn.SetDeadline(time.Time{}) }, - ) -} - func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { @@ -928,23 +951,24 @@ func (pgConn *PgConn) Deallocate(ctx context.Context, name string) error { // ErrorResponseToPgError converts a wire protocol error message to a *PgError. func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ - Severity: msg.Severity, - Code: string(msg.Code), - Message: string(msg.Message), - Detail: string(msg.Detail), - Hint: msg.Hint, - Position: msg.Position, - InternalPosition: msg.InternalPosition, - InternalQuery: string(msg.InternalQuery), - Where: string(msg.Where), - SchemaName: string(msg.SchemaName), - TableName: string(msg.TableName), - ColumnName: string(msg.ColumnName), - DataTypeName: string(msg.DataTypeName), - ConstraintName: msg.ConstraintName, - File: string(msg.File), - Line: msg.Line, - Routine: string(msg.Routine), + Severity: msg.Severity, + SeverityUnlocalized: msg.SeverityUnlocalized, + Code: string(msg.Code), + Message: string(msg.Message), + Detail: string(msg.Detail), + Hint: msg.Hint, + Position: msg.Position, + InternalPosition: msg.InternalPosition, + InternalQuery: string(msg.InternalQuery), + Where: string(msg.Where), + SchemaName: string(msg.SchemaName), + TableName: string(msg.TableName), + ColumnName: string(msg.ColumnName), + DataTypeName: string(msg.DataTypeName), + ConstraintName: msg.ConstraintName, + File: string(msg.File), + Line: msg.Line, + Routine: string(msg.Routine), } } @@ -987,10 +1011,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { defer cancelConn.Close() if ctx != context.Background() { - contextWatcher := ctxwatch.NewContextWatcher( - func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { cancelConn.SetDeadline(time.Time{}) }, - ) + contextWatcher := ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: cancelConn}) contextWatcher.Watch(ctx) defer contextWatcher.Unwatch() } @@ -1523,8 +1544,10 @@ func (rr *ResultReader) Read() *Result { values := rr.Values() row := make([][]byte, len(values)) for i := range row { - row[i] = make([]byte, len(values[i])) - copy(row[i], values[i]) + if values[i] != nil { + row[i] = make([]byte, len(values[i])) + copy(row[i], values[i]) + } } br.Rows = append(br.Rows, row) } @@ -1879,6 +1902,11 @@ func (pgConn *PgConn) SyncConn(ctx context.Context) error { return errors.New("SyncConn: conn never synchronized") } +// CustomData returns a map that can be used to associate custom data with the connection. +func (pgConn *PgConn) CustomData() map[string]any { + return pgConn.customData +} + // HijackedConn is the result of hijacking a connection. // // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning @@ -1891,6 +1919,7 @@ type HijackedConn struct { TxStatus byte Frontend *pgproto3.Frontend Config *Config + CustomData map[string]any } // Hijack extracts the internal connection data. pgConn must be in an idle state. SyncConn should be called immediately @@ -1913,6 +1942,7 @@ func (pgConn *PgConn) Hijack() (*HijackedConn, error) { TxStatus: pgConn.txStatus, Frontend: pgConn.frontend, Config: pgConn.config, + CustomData: pgConn.customData, }, nil } @@ -1932,13 +1962,14 @@ func Construct(hc *HijackedConn) (*PgConn, error) { txStatus: hc.TxStatus, frontend: hc.Frontend, config: hc.Config, + customData: hc.CustomData, status: connStatusIdle, cleanupDone: make(chan struct{}), } - pgConn.contextWatcher = newContextWatcher(pgConn.conn) + pgConn.contextWatcher = ctxwatch.NewContextWatcher(hc.Config.BuildContextWatcherHandler(pgConn)) pgConn.bgReader = bgreader.New(pgConn.conn) pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), func() { @@ -2245,3 +2276,71 @@ func (p *Pipeline) Close() error { return p.err } + +// DeadlineContextWatcherHandler handles canceled contexts by setting a deadline on a net.Conn. +type DeadlineContextWatcherHandler struct { + Conn net.Conn + + // DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled. + DeadlineDelay time.Duration +} + +func (h *DeadlineContextWatcherHandler) HandleCancel(ctx context.Context) { + h.Conn.SetDeadline(time.Now().Add(h.DeadlineDelay)) +} + +func (h *DeadlineContextWatcherHandler) HandleUnwatchAfterCancel() { + h.Conn.SetDeadline(time.Time{}) +} + +// CancelRequestContextWatcherHandler handles canceled contexts by sending a cancel request to the server. It also sets +// a deadline on a net.Conn as a fallback. +type CancelRequestContextWatcherHandler struct { + Conn *PgConn + + // CancelRequestDelay is the delay before sending the cancel request to the server. + CancelRequestDelay time.Duration + + // DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled. + DeadlineDelay time.Duration + + cancelFinishedChan chan struct{} + handleUnwatchAfterCancelCalled func() +} + +func (h *CancelRequestContextWatcherHandler) HandleCancel(context.Context) { + h.cancelFinishedChan = make(chan struct{}) + var handleUnwatchedAfterCancelCalledCtx context.Context + handleUnwatchedAfterCancelCalledCtx, h.handleUnwatchAfterCancelCalled = context.WithCancel(context.Background()) + + deadline := time.Now().Add(h.DeadlineDelay) + h.Conn.conn.SetDeadline(deadline) + + go func() { + defer close(h.cancelFinishedChan) + + select { + case <-handleUnwatchedAfterCancelCalledCtx.Done(): + return + case <-time.After(h.CancelRequestDelay): + } + + cancelRequestCtx, cancel := context.WithDeadline(handleUnwatchedAfterCancelCalledCtx, deadline) + defer cancel() + h.Conn.CancelRequest(cancelRequestCtx) + + // CancelRequest is inherently racy. Even though the cancel request has been received by the server at this point, + // it hasn't necessarily been delivered to the other connection. If we immediately return and the connection is + // immediately used then it is possible the CancelRequest will actually cancel our next query. The + // TestCancelRequestContextWatcherHandler Stress test can produce this error without the sleep below. The sleep time + // is arbitrary, but should be sufficient to prevent this error case. + time.Sleep(100 * time.Millisecond) + }() +} + +func (h *CancelRequestContextWatcherHandler) HandleUnwatchAfterCancel() { + h.handleUnwatchAfterCancelCalled() + <-h.cancelFinishedChan + + h.Conn.conn.SetDeadline(time.Time{}) +} |