diff options
Diffstat (limited to 'vendor/github.com/jackc/pgconn/pgconn.go')
-rw-r--r-- | vendor/github.com/jackc/pgconn/pgconn.go | 82 |
1 files changed, 51 insertions, 31 deletions
diff --git a/vendor/github.com/jackc/pgconn/pgconn.go b/vendor/github.com/jackc/pgconn/pgconn.go index 382ad33c0..7bf2f20ef 100644 --- a/vendor/github.com/jackc/pgconn/pgconn.go +++ b/vendor/github.com/jackc/pgconn/pgconn.go @@ -11,6 +11,7 @@ import ( "io" "math" "net" + "strconv" "strings" "sync" "time" @@ -44,7 +45,8 @@ type Notification struct { // DialFunc is a function that can be used to connect to a PostgreSQL server. type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) -// LookupFunc is a function that can be used to lookup IPs addrs from host. +// LookupFunc is a function that can be used to lookup IPs addrs from host. Optionally an ip:port combination can be +// returned in order to override the connection string's port. type LookupFunc func(ctx context.Context, host string) (addrs []string, err error) // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. @@ -196,11 +198,24 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba } for _, ip := range ips { - configs = append(configs, &FallbackConfig{ - Host: ip, - Port: fb.Port, - TLSConfig: fb.TLSConfig, - }) + splitIP, splitPort, err := net.SplitHostPort(ip) + if err == nil { + port, err := strconv.ParseUint(splitPort, 10, 16) + if err != nil { + return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err) + } + configs = append(configs, &FallbackConfig{ + Host: splitIP, + Port: uint16(port), + TLSConfig: fb.TLSConfig, + }) + } else { + configs = append(configs, &FallbackConfig{ + Host: ip, + Port: fb.Port, + TLSConfig: fb.TLSConfig, + }) + } } } @@ -215,7 +230,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) - pgConn.conn, err = config.DialFunc(ctx, network, address) + netConn, err := config.DialFunc(ctx, network, address) if err != nil { var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { @@ -224,24 +239,27 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, &connectError{config: config, msg: "dial error", err: err} } - pgConn.parameterStatuses = make(map[string]string) + pgConn.conn = netConn + pgConn.contextWatcher = newContextWatcher(netConn) + pgConn.contextWatcher.Watch(ctx) if fallbackConfig.TLSConfig != nil { - if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil { - pgConn.conn.Close() + tlsConn, err := startTLS(netConn, fallbackConfig.TLSConfig) + pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. + if err != nil { + netConn.Close() return nil, &connectError{config: config, msg: "tls error", err: err} } - } - pgConn.status = connStatusConnecting - pgConn.contextWatcher = ctxwatch.NewContextWatcher( - func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { pgConn.conn.SetDeadline(time.Time{}) }, - ) + pgConn.conn = tlsConn + pgConn.contextWatcher = newContextWatcher(tlsConn) + pgConn.contextWatcher.Watch(ctx) + } - pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() + pgConn.parameterStatuses = make(map[string]string) + pgConn.status = connStatusConnecting pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) startupMsg := pgproto3.StartupMessage{ @@ -317,7 +335,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } } return pgConn, nil - case *pgproto3.ParameterStatus: + case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse: // handled by ReceiveMessage case *pgproto3.ErrorResponse: pgConn.conn.Close() @@ -329,24 +347,29 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } } -func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { - err = binary.Write(pgConn.conn, binary.BigEndian, []int32{8, 80877103}) +func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { + return ctxwatch.NewContextWatcher( + func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { conn.SetDeadline(time.Time{}) }, + ) +} + +func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { + err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { - return + return nil, err } response := make([]byte, 1) - if _, err = io.ReadFull(pgConn.conn, response); err != nil { - return + if _, err = io.ReadFull(conn, response); err != nil { + return nil, err } if response[0] != 'S' { - return errors.New("server refused TLS connection") + return nil, errors.New("server refused TLS connection") } - pgConn.conn = tls.Client(pgConn.conn, tlsConfig) - - return nil + return tls.Client(conn, tlsConfig), nil } func (pgConn *PgConn) txPasswordMessage(password string) (err error) { @@ -1694,10 +1717,7 @@ func Construct(hc *HijackedConn) (*PgConn, error) { cleanupDone: make(chan struct{}), } - pgConn.contextWatcher = ctxwatch.NewContextWatcher( - func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { pgConn.conn.SetDeadline(time.Time{}) }, - ) + pgConn.contextWatcher = newContextWatcher(pgConn.conn) return pgConn, nil } |