summaryrefslogtreecommitdiff
path: root/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/mux.go
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2024-08-26 18:05:54 +0200
committerLibravatar GitHub <noreply@github.com>2024-08-26 18:05:54 +0200
commit28d57d1f13ee61a6ff83ce4beaf238139d20bbac (patch)
tree7946abb44b21f9e2f4267146711760a911072c9b /vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/mux.go
parent[chore]: Bump github.com/prometheus/client_golang from 1.20.0 to 1.20.2 (#3239) (diff)
downloadgotosocial-28d57d1f13ee61a6ff83ce4beaf238139d20bbac.tar.xz
[chore] Bump all otel deps (#3241)
Diffstat (limited to 'vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/mux.go')
-rw-r--r--vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/mux.go69
1 files changed, 60 insertions, 9 deletions
diff --git a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/mux.go b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/mux.go
index ed9a7e438..60c2065dd 100644
--- a/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/mux.go
+++ b/vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/mux.go
@@ -48,12 +48,19 @@ var encodedPathSplitter = regexp.MustCompile("(/|%2F)")
// A HandlerFunc handles a specific pair of path pattern and HTTP method.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
+// A Middleware handler wraps another HandlerFunc to do some pre- and/or post-processing of the request. This is used as an alternative to gRPC interceptors when using the direct-to-implementation
+// registration methods. It is generally recommended to use gRPC client or server interceptors instead
+// where possible.
+type Middleware func(HandlerFunc) HandlerFunc
+
// ServeMux is a request multiplexer for grpc-gateway.
// It matches http requests to patterns and invokes the corresponding handler.
type ServeMux struct {
// handlers maps HTTP method to a list of handlers.
handlers map[string][]handler
+ middlewares []Middleware
forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
+ forwardResponseRewriter ForwardResponseRewriter
marshalers marshalerRegistry
incomingHeaderMatcher HeaderMatcherFunc
outgoingHeaderMatcher HeaderMatcherFunc
@@ -69,6 +76,24 @@ type ServeMux struct {
// ServeMuxOption is an option that can be given to a ServeMux on construction.
type ServeMuxOption func(*ServeMux)
+// ForwardResponseRewriter is the signature of a function that is capable of rewriting messages
+// before they are forwarded in a unary, stream, or error response.
+type ForwardResponseRewriter func(ctx context.Context, response proto.Message) (any, error)
+
+// WithForwardResponseRewriter returns a ServeMuxOption that allows for implementers to insert logic
+// that can rewrite the final response before it is forwarded.
+//
+// The response rewriter function is called during unary message forwarding, stream message
+// forwarding and when errors are being forwarded.
+//
+// NOTE: Using this option will likely make what is generated by `protoc-gen-openapiv2` incorrect.
+// Since this option involves making runtime changes to the response shape or type.
+func WithForwardResponseRewriter(fwdResponseRewriter ForwardResponseRewriter) ServeMuxOption {
+ return func(sm *ServeMux) {
+ sm.forwardResponseRewriter = fwdResponseRewriter
+ }
+}
+
// WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
//
// forwardResponseOption is an option that will be called on the relevant context.Context,
@@ -89,6 +114,15 @@ func WithUnescapingMode(mode UnescapingMode) ServeMuxOption {
}
}
+// WithMiddlewares sets server middleware for all handlers. This is useful as an alternative to gRPC
+// interceptors when using the direct-to-implementation registration methods and cannot rely
+// on gRPC interceptors. It's recommended to use gRPC interceptors instead if possible.
+func WithMiddlewares(middlewares ...Middleware) ServeMuxOption {
+ return func(serveMux *ServeMux) {
+ serveMux.middlewares = append(serveMux.middlewares, middlewares...)
+ }
+}
+
// SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters.
// Configuring this will mean the generated OpenAPI output is no longer correct, and it should be
// done with careful consideration.
@@ -277,13 +311,14 @@ func WithHealthzEndpoint(healthCheckClient grpc_health_v1.HealthClient) ServeMux
// NewServeMux returns a new ServeMux whose internal mapping is empty.
func NewServeMux(opts ...ServeMuxOption) *ServeMux {
serveMux := &ServeMux{
- handlers: make(map[string][]handler),
- forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
- marshalers: makeMarshalerMIMERegistry(),
- errorHandler: DefaultHTTPErrorHandler,
- streamErrorHandler: DefaultStreamErrorHandler,
- routingErrorHandler: DefaultRoutingErrorHandler,
- unescapingMode: UnescapingModeDefault,
+ handlers: make(map[string][]handler),
+ forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
+ forwardResponseRewriter: func(ctx context.Context, response proto.Message) (any, error) { return response, nil },
+ marshalers: makeMarshalerMIMERegistry(),
+ errorHandler: DefaultHTTPErrorHandler,
+ streamErrorHandler: DefaultStreamErrorHandler,
+ routingErrorHandler: DefaultRoutingErrorHandler,
+ unescapingMode: UnescapingModeDefault,
}
for _, opt := range opts {
@@ -305,6 +340,9 @@ func NewServeMux(opts ...ServeMuxOption) *ServeMux {
// Handle associates "h" to the pair of HTTP method and path pattern.
func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
+ if len(s.middlewares) > 0 {
+ h = chainMiddlewares(s.middlewares)(h)
+ }
s.handlers[meth] = append([]handler{{pat: pat, h: h}}, s.handlers[meth]...)
}
@@ -405,7 +443,7 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
continue
}
- h.h(w, r, pathParams)
+ s.handleHandler(h, w, r, pathParams)
return
}
@@ -458,7 +496,7 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
return
}
- h.h(w, r, pathParams)
+ s.handleHandler(h, w, r, pathParams)
return
}
_, outboundMarshaler := MarshalerForRequest(s, r)
@@ -484,3 +522,16 @@ type handler struct {
pat Pattern
h HandlerFunc
}
+
+func (s *ServeMux) handleHandler(h handler, w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
+ h.h(w, r.WithContext(withHTTPPattern(r.Context(), h.pat)), pathParams)
+}
+
+func chainMiddlewares(mws []Middleware) Middleware {
+ return func(next HandlerFunc) HandlerFunc {
+ for i := len(mws); i > 0; i-- {
+ next = mws[i-1](next)
+ }
+ return next
+ }
+}