summaryrefslogtreecommitdiff
path: root/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go')
-rw-r--r--vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go28
1 files changed, 22 insertions, 6 deletions
diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go
index 31553e784..2f2b34243 100644
--- a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go
+++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go
@@ -49,6 +49,7 @@ var malformedHTTPHeaders = map[string]struct{}{
type (
rpcMethodKey struct{}
httpPathPatternKey struct{}
+ httpPatternKey struct{}
AnnotateContextOption func(ctx context.Context) context.Context
)
@@ -148,6 +149,12 @@ func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcM
var pairs []string
for key, vals := range req.Header {
key = textproto.CanonicalMIMEHeaderKey(key)
+ switch key {
+ case xForwardedFor, xForwardedHost:
+ // Handled separately below
+ continue
+ }
+
for _, val := range vals {
// For backwards-compatibility, pass through 'authorization' header with no prefix.
if key == "Authorization" {
@@ -181,18 +188,17 @@ func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcM
pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host)
}
+ xff := req.Header.Values(xForwardedFor)
if addr := req.RemoteAddr; addr != "" {
if remoteIP, _, err := net.SplitHostPort(addr); err == nil {
- if fwd := req.Header.Get(xForwardedFor); fwd == "" {
- pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP)
- } else {
- pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP))
- }
+ xff = append(xff, remoteIP)
}
}
+ if len(xff) > 0 {
+ pairs = append(pairs, strings.ToLower(xForwardedFor), strings.Join(xff, ", "))
+ }
if timeout != 0 {
- //nolint:govet // The context outlives this function
ctx, _ = context.WithTimeout(ctx, timeout)
}
if len(pairs) == 0 {
@@ -399,3 +405,13 @@ func HTTPPathPattern(ctx context.Context) (string, bool) {
func withHTTPPathPattern(ctx context.Context, httpPathPattern string) context.Context {
return context.WithValue(ctx, httpPathPatternKey{}, httpPathPattern)
}
+
+// HTTPPattern returns the HTTP path pattern struct relating to the HTTP handler, if one exists.
+func HTTPPattern(ctx context.Context) (Pattern, bool) {
+ v, ok := ctx.Value(httpPatternKey{}).(Pattern)
+ return v, ok
+}
+
+func withHTTPPattern(ctx context.Context, httpPattern Pattern) context.Context {
+ return context.WithValue(ctx, httpPatternKey{}, httpPattern)
+}