diff options
Diffstat (limited to 'vendor/google.golang.org/grpc/internal/resolver/dns/dns_resolver.go')
-rw-r--r-- | vendor/google.golang.org/grpc/internal/resolver/dns/dns_resolver.go | 69 |
1 files changed, 20 insertions, 49 deletions
diff --git a/vendor/google.golang.org/grpc/internal/resolver/dns/dns_resolver.go b/vendor/google.golang.org/grpc/internal/resolver/dns/dns_resolver.go index 99e1e5b36..b66dcb213 100644 --- a/vendor/google.golang.org/grpc/internal/resolver/dns/dns_resolver.go +++ b/vendor/google.golang.org/grpc/internal/resolver/dns/dns_resolver.go @@ -23,7 +23,6 @@ package dns import ( "context" "encoding/json" - "errors" "fmt" "net" "os" @@ -37,6 +36,7 @@ import ( "google.golang.org/grpc/internal/backoff" "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/grpcrand" + "google.golang.org/grpc/internal/resolver/dns/internal" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" ) @@ -47,15 +47,11 @@ var EnableSRVLookups = false var logger = grpclog.Component("dns") -// Globals to stub out in tests. TODO: Perhaps these two can be combined into a -// single variable for testing the resolver? -var ( - newTimer = time.NewTimer - newTimerDNSResRate = time.NewTimer -) - func init() { resolver.Register(NewBuilder()) + internal.TimeAfterFunc = time.After + internal.NewNetResolver = newNetResolver + internal.AddressDialer = addressDialer } const ( @@ -70,23 +66,6 @@ const ( txtAttribute = "grpc_config=" ) -var ( - errMissingAddr = errors.New("dns resolver: missing address") - - // Addresses ending with a colon that is supposed to be the separator - // between host and port is not allowed. E.g. "::" is a valid address as - // it is an IPv6 address (host only) and "[::]:" is invalid as it ends with - // a colon as the host and port separator - errEndsWithColon = errors.New("dns resolver: missing port after port-separator colon") -) - -var ( - defaultResolver netResolver = net.DefaultResolver - // To prevent excessive re-resolution, we enforce a rate limit on DNS - // resolution requests. - minDNSResRate = 30 * time.Second -) - var addressDialer = func(address string) func(context.Context, string, string) (net.Conn, error) { return func(ctx context.Context, network, _ string) (net.Conn, error) { var dialer net.Dialer @@ -94,7 +73,11 @@ var addressDialer = func(address string) func(context.Context, string, string) ( } } -var newNetResolver = func(authority string) (netResolver, error) { +var newNetResolver = func(authority string) (internal.NetResolver, error) { + if authority == "" { + return net.DefaultResolver, nil + } + host, port, err := parseTarget(authority, defaultDNSSvrPort) if err != nil { return nil, err @@ -104,7 +87,7 @@ var newNetResolver = func(authority string) (netResolver, error) { return &net.Resolver{ PreferGo: true, - Dial: addressDialer(authorityWithPort), + Dial: internal.AddressDialer(authorityWithPort), }, nil } @@ -142,13 +125,9 @@ func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts disableServiceConfig: opts.DisableServiceConfig, } - if target.URL.Host == "" { - d.resolver = defaultResolver - } else { - d.resolver, err = newNetResolver(target.URL.Host) - if err != nil { - return nil, err - } + d.resolver, err = internal.NewNetResolver(target.URL.Host) + if err != nil { + return nil, err } d.wg.Add(1) @@ -161,12 +140,6 @@ func (b *dnsBuilder) Scheme() string { return "dns" } -type netResolver interface { - LookupHost(ctx context.Context, host string) (addrs []string, err error) - LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error) - LookupTXT(ctx context.Context, name string) (txts []string, err error) -} - // deadResolver is a resolver that does nothing. type deadResolver struct{} @@ -178,7 +151,7 @@ func (deadResolver) Close() {} type dnsResolver struct { host string port string - resolver netResolver + resolver internal.NetResolver ctx context.Context cancel context.CancelFunc cc resolver.ClientConn @@ -223,29 +196,27 @@ func (d *dnsResolver) watcher() { err = d.cc.UpdateState(*state) } - var timer *time.Timer + var waitTime time.Duration if err == nil { // Success resolving, wait for the next ResolveNow. However, also wait 30 // seconds at the very least to prevent constantly re-resolving. backoffIndex = 1 - timer = newTimerDNSResRate(minDNSResRate) + waitTime = internal.MinResolutionRate select { case <-d.ctx.Done(): - timer.Stop() return case <-d.rn: } } else { // Poll on an error found in DNS Resolver or an error received from // ClientConn. - timer = newTimer(backoff.DefaultExponential.Backoff(backoffIndex)) + waitTime = backoff.DefaultExponential.Backoff(backoffIndex) backoffIndex++ } select { case <-d.ctx.Done(): - timer.Stop() return - case <-timer.C: + case <-internal.TimeAfterFunc(waitTime): } } } @@ -387,7 +358,7 @@ func formatIP(addr string) (addrIP string, ok bool) { // target: ":80" defaultPort: "443" returns host: "localhost", port: "80" func parseTarget(target, defaultPort string) (host, port string, err error) { if target == "" { - return "", "", errMissingAddr + return "", "", internal.ErrMissingAddr } if ip := net.ParseIP(target); ip != nil { // target is an IPv4 or IPv6(without brackets) address @@ -397,7 +368,7 @@ func parseTarget(target, defaultPort string) (host, port string, err error) { if port == "" { // If the port field is empty (target ends with colon), e.g. "[::1]:", // this is an error. - return "", "", errEndsWithColon + return "", "", internal.ErrEndsWithColon } // target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port if host == "" { |