diff options
Diffstat (limited to 'vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go')
-rw-r--r-- | vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go | 28 |
1 files changed, 22 insertions, 6 deletions
diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go index 31553e784..2f2b34243 100644 --- a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go +++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go @@ -49,6 +49,7 @@ var malformedHTTPHeaders = map[string]struct{}{ type ( rpcMethodKey struct{} httpPathPatternKey struct{} + httpPatternKey struct{} AnnotateContextOption func(ctx context.Context) context.Context ) @@ -148,6 +149,12 @@ func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcM var pairs []string for key, vals := range req.Header { key = textproto.CanonicalMIMEHeaderKey(key) + switch key { + case xForwardedFor, xForwardedHost: + // Handled separately below + continue + } + for _, val := range vals { // For backwards-compatibility, pass through 'authorization' header with no prefix. if key == "Authorization" { @@ -181,18 +188,17 @@ func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcM pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host) } + xff := req.Header.Values(xForwardedFor) if addr := req.RemoteAddr; addr != "" { if remoteIP, _, err := net.SplitHostPort(addr); err == nil { - if fwd := req.Header.Get(xForwardedFor); fwd == "" { - pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP) - } else { - pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP)) - } + xff = append(xff, remoteIP) } } + if len(xff) > 0 { + pairs = append(pairs, strings.ToLower(xForwardedFor), strings.Join(xff, ", ")) + } if timeout != 0 { - //nolint:govet // The context outlives this function ctx, _ = context.WithTimeout(ctx, timeout) } if len(pairs) == 0 { @@ -399,3 +405,13 @@ func HTTPPathPattern(ctx context.Context) (string, bool) { func withHTTPPathPattern(ctx context.Context, httpPathPattern string) context.Context { return context.WithValue(ctx, httpPathPatternKey{}, httpPathPattern) } + +// HTTPPattern returns the HTTP path pattern struct relating to the HTTP handler, if one exists. +func HTTPPattern(ctx context.Context) (Pattern, bool) { + v, ok := ctx.Value(httpPatternKey{}).(Pattern) + return v, ok +} + +func withHTTPPattern(ctx context.Context, httpPattern Pattern) context.Context { + return context.WithValue(ctx, httpPatternKey{}, httpPattern) +} |