summaryrefslogtreecommitdiff
path: root/vendor/google.golang.org/grpc/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/google.golang.org/grpc/server.go')
-rw-r--r--vendor/google.golang.org/grpc/server.go264
1 files changed, 182 insertions, 82 deletions
diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go
index 8f60d4214..e89c5ac61 100644
--- a/vendor/google.golang.org/grpc/server.go
+++ b/vendor/google.golang.org/grpc/server.go
@@ -70,9 +70,10 @@ func init() {
internal.GetServerCredentials = func(srv *Server) credentials.TransportCredentials {
return srv.opts.creds
}
- internal.DrainServerTransports = func(srv *Server, addr string) {
- srv.drainServerTransports(addr)
+ internal.IsRegisteredMethod = func(srv *Server, method string) bool {
+ return srv.isRegisteredMethod(method)
}
+ internal.ServerFromContext = serverFromContext
internal.AddGlobalServerOptions = func(opt ...ServerOption) {
globalServerOptions = append(globalServerOptions, opt...)
}
@@ -81,6 +82,7 @@ func init() {
}
internal.BinaryLogger = binaryLogger
internal.JoinServerOptions = newJoinServerOption
+ internal.RecvBufferPool = recvBufferPool
}
var statusOK = status.New(codes.OK, "")
@@ -134,12 +136,14 @@ type Server struct {
quit *grpcsync.Event
done *grpcsync.Event
channelzRemoveOnce sync.Once
- serveWG sync.WaitGroup // counts active Serve goroutines for GracefulStop
+ serveWG sync.WaitGroup // counts active Serve goroutines for Stop/GracefulStop
+ handlersWG sync.WaitGroup // counts active method handler goroutines
channelzID *channelz.Identifier
czData *channelzData
- serverWorkerChannel chan func()
+ serverWorkerChannel chan func()
+ serverWorkerChannelClose func()
}
type serverOptions struct {
@@ -170,6 +174,7 @@ type serverOptions struct {
headerTableSize *uint32
numServerWorkers uint32
recvBufferPool SharedBufferPool
+ waitForHandlers bool
}
var defaultServerOptions = serverOptions{
@@ -567,6 +572,21 @@ func NumStreamWorkers(numServerWorkers uint32) ServerOption {
})
}
+// WaitForHandlers cause Stop to wait until all outstanding method handlers have
+// exited before returning. If false, Stop will return as soon as all
+// connections have closed, but method handlers may still be running. By
+// default, Stop does not wait for method handlers to return.
+//
+// # Experimental
+//
+// Notice: This API is EXPERIMENTAL and may be changed or removed in a
+// later release.
+func WaitForHandlers(w bool) ServerOption {
+ return newFuncServerOption(func(o *serverOptions) {
+ o.waitForHandlers = w
+ })
+}
+
// RecvBufferPool returns a ServerOption that configures the server
// to use the provided shared buffer pool for parsing incoming messages. Depending
// on the application's workload, this could result in reduced memory allocation.
@@ -578,11 +598,13 @@ func NumStreamWorkers(numServerWorkers uint32) ServerOption {
// options are used: StatsHandler, EnableTracing, or binary logging. In such
// cases, the shared buffer pool will be ignored.
//
-// # Experimental
-//
-// Notice: This API is EXPERIMENTAL and may be changed or removed in a
-// later release.
+// Deprecated: use experimental.WithRecvBufferPool instead. Will be deleted in
+// v1.60.0 or later.
func RecvBufferPool(bufferPool SharedBufferPool) ServerOption {
+ return recvBufferPool(bufferPool)
+}
+
+func recvBufferPool(bufferPool SharedBufferPool) ServerOption {
return newFuncServerOption(func(o *serverOptions) {
o.recvBufferPool = bufferPool
})
@@ -616,15 +638,14 @@ func (s *Server) serverWorker() {
// connections to reduce the time spent overall on runtime.morestack.
func (s *Server) initServerWorkers() {
s.serverWorkerChannel = make(chan func())
+ s.serverWorkerChannelClose = grpcsync.OnceFunc(func() {
+ close(s.serverWorkerChannel)
+ })
for i := uint32(0); i < s.opts.numServerWorkers; i++ {
go s.serverWorker()
}
}
-func (s *Server) stopServerWorkers() {
- close(s.serverWorkerChannel)
-}
-
// NewServer creates a gRPC server which has no service registered and has not
// started to accept requests yet.
func NewServer(opt ...ServerOption) *Server {
@@ -806,6 +827,18 @@ func (l *listenSocket) Close() error {
// Serve returns when lis.Accept fails with fatal errors. lis will be closed when
// this method returns.
// Serve will return a non-nil error unless Stop or GracefulStop is called.
+//
+// Note: All supported releases of Go (as of December 2023) override the OS
+// defaults for TCP keepalive time and interval to 15s. To enable TCP keepalive
+// with OS defaults for keepalive time and interval, callers need to do the
+// following two things:
+// - pass a net.Listener created by calling the Listen method on a
+// net.ListenConfig with the `KeepAlive` field set to a negative value. This
+// will result in the Go standard library not overriding OS defaults for TCP
+// keepalive interval and time. But this will also result in the Go standard
+// library not enabling TCP keepalives by default.
+// - override the Accept method on the passed in net.Listener and set the
+// SO_KEEPALIVE socket option to enable TCP keepalives, with OS defaults.
func (s *Server) Serve(lis net.Listener) error {
s.mu.Lock()
s.printf("serving")
@@ -913,24 +946,21 @@ func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) {
return
}
+ if cc, ok := rawConn.(interface {
+ PassServerTransport(transport.ServerTransport)
+ }); ok {
+ cc.PassServerTransport(st)
+ }
+
if !s.addConn(lisAddr, st) {
return
}
go func() {
- s.serveStreams(st)
+ s.serveStreams(context.Background(), st, rawConn)
s.removeConn(lisAddr, st)
}()
}
-func (s *Server) drainServerTransports(addr string) {
- s.mu.Lock()
- conns := s.conns[addr]
- for st := range conns {
- st.Drain("")
- }
- s.mu.Unlock()
-}
-
// newHTTP2Transport sets up a http/2 transport (using the
// gRPC http2 server transport in transport/http2_server.go).
func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
@@ -971,18 +1001,31 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
return st
}
-func (s *Server) serveStreams(st transport.ServerTransport) {
- defer st.Close(errors.New("finished serving streams for the server transport"))
- var wg sync.WaitGroup
+func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, rawConn net.Conn) {
+ ctx = transport.SetConnection(ctx, rawConn)
+ ctx = peer.NewContext(ctx, st.Peer())
+ for _, sh := range s.opts.statsHandlers {
+ ctx = sh.TagConn(ctx, &stats.ConnTagInfo{
+ RemoteAddr: st.Peer().Addr,
+ LocalAddr: st.Peer().LocalAddr,
+ })
+ sh.HandleConn(ctx, &stats.ConnBegin{})
+ }
- streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
- st.HandleStreams(func(stream *transport.Stream) {
- wg.Add(1)
+ defer func() {
+ st.Close(errors.New("finished serving streams for the server transport"))
+ for _, sh := range s.opts.statsHandlers {
+ sh.HandleConn(ctx, &stats.ConnEnd{})
+ }
+ }()
+ streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
+ st.HandleStreams(ctx, func(stream *transport.Stream) {
+ s.handlersWG.Add(1)
streamQuota.acquire()
f := func() {
defer streamQuota.release()
- defer wg.Done()
+ defer s.handlersWG.Done()
s.handleStream(st, stream)
}
@@ -996,7 +1039,6 @@ func (s *Server) serveStreams(st transport.ServerTransport) {
}
go f()
})
- wg.Wait()
}
var _ http.Handler = (*Server)(nil)
@@ -1040,7 +1082,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
defer s.removeConn(listenerAddressForServeHTTP, st)
- s.serveStreams(st)
+ s.serveStreams(r.Context(), st, nil)
}
func (s *Server) addConn(addr string, st transport.ServerTransport) bool {
@@ -1689,6 +1731,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTran
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) {
ctx := stream.Context()
+ ctx = contextWithServer(ctx, s)
var ti *traceInfo
if EnableTracing {
tr := trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method())
@@ -1697,7 +1740,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
tr: tr,
firstLine: firstLine{
client: false,
- remoteAddr: t.RemoteAddr(),
+ remoteAddr: t.Peer().Addr,
},
}
if dl, ok := ctx.Deadline(); ok {
@@ -1731,6 +1774,22 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
service := sm[:pos]
method := sm[pos+1:]
+ md, _ := metadata.FromIncomingContext(ctx)
+ for _, sh := range s.opts.statsHandlers {
+ ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()})
+ sh.HandleRPC(ctx, &stats.InHeader{
+ FullMethod: stream.Method(),
+ RemoteAddr: t.Peer().Addr,
+ LocalAddr: t.Peer().LocalAddr,
+ Compression: stream.RecvCompress(),
+ WireLength: stream.HeaderWireLength(),
+ Header: md,
+ })
+ }
+ // To have calls in stream callouts work. Will delete once all stats handler
+ // calls come from the gRPC layer.
+ stream.SetContext(ctx)
+
srv, knownService := s.services[service]
if knownService {
if md, ok := srv.methods[method]; ok {
@@ -1820,62 +1879,72 @@ func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream
// pending RPCs on the client side will get notified by connection
// errors.
func (s *Server) Stop() {
- s.quit.Fire()
+ s.stop(false)
+}
- defer func() {
- s.serveWG.Wait()
- s.done.Fire()
- }()
+// GracefulStop stops the gRPC server gracefully. It stops the server from
+// accepting new connections and RPCs and blocks until all the pending RPCs are
+// finished.
+func (s *Server) GracefulStop() {
+ s.stop(true)
+}
+
+func (s *Server) stop(graceful bool) {
+ s.quit.Fire()
+ defer s.done.Fire()
s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelzID) })
s.mu.Lock()
- listeners := s.lis
- s.lis = nil
- conns := s.conns
- s.conns = nil
- // interrupt GracefulStop if Stop and GracefulStop are called concurrently.
- s.cv.Broadcast()
+ s.closeListenersLocked()
+ // Wait for serving threads to be ready to exit. Only then can we be sure no
+ // new conns will be created.
s.mu.Unlock()
+ s.serveWG.Wait()
- for lis := range listeners {
- lis.Close()
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if graceful {
+ s.drainAllServerTransportsLocked()
+ } else {
+ s.closeServerTransportsLocked()
}
- for _, cs := range conns {
- for st := range cs {
- st.Close(errors.New("Server.Stop called"))
- }
+
+ for len(s.conns) != 0 {
+ s.cv.Wait()
}
+ s.conns = nil
+
if s.opts.numServerWorkers > 0 {
- s.stopServerWorkers()
+ // Closing the channel (only once, via grpcsync.OnceFunc) after all the
+ // connections have been closed above ensures that there are no
+ // goroutines executing the callback passed to st.HandleStreams (where
+ // the channel is written to).
+ s.serverWorkerChannelClose()
+ }
+
+ if graceful || s.opts.waitForHandlers {
+ s.handlersWG.Wait()
}
- s.mu.Lock()
if s.events != nil {
s.events.Finish()
s.events = nil
}
- s.mu.Unlock()
}
-// GracefulStop stops the gRPC server gracefully. It stops the server from
-// accepting new connections and RPCs and blocks until all the pending RPCs are
-// finished.
-func (s *Server) GracefulStop() {
- s.quit.Fire()
- defer s.done.Fire()
-
- s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelzID) })
- s.mu.Lock()
- if s.conns == nil {
- s.mu.Unlock()
- return
+// s.mu must be held by the caller.
+func (s *Server) closeServerTransportsLocked() {
+ for _, conns := range s.conns {
+ for st := range conns {
+ st.Close(errors.New("Server.Stop called"))
+ }
}
+}
- for lis := range s.lis {
- lis.Close()
- }
- s.lis = nil
+// s.mu must be held by the caller.
+func (s *Server) drainAllServerTransportsLocked() {
if !s.drain {
for _, conns := range s.conns {
for st := range conns {
@@ -1884,22 +1953,14 @@ func (s *Server) GracefulStop() {
}
s.drain = true
}
+}
- // Wait for serving threads to be ready to exit. Only then can we be sure no
- // new conns will be created.
- s.mu.Unlock()
- s.serveWG.Wait()
- s.mu.Lock()
-
- for len(s.conns) != 0 {
- s.cv.Wait()
- }
- s.conns = nil
- if s.events != nil {
- s.events.Finish()
- s.events = nil
+// s.mu must be held by the caller.
+func (s *Server) closeListenersLocked() {
+ for lis := range s.lis {
+ lis.Close()
}
- s.mu.Unlock()
+ s.lis = nil
}
// contentSubtype must be lowercase
@@ -1913,11 +1974,50 @@ func (s *Server) getCodec(contentSubtype string) baseCodec {
}
codec := encoding.GetCodec(contentSubtype)
if codec == nil {
+ logger.Warningf("Unsupported codec %q. Defaulting to %q for now. This will start to fail in future releases.", contentSubtype, proto.Name)
return encoding.GetCodec(proto.Name)
}
return codec
}
+type serverKey struct{}
+
+// serverFromContext gets the Server from the context.
+func serverFromContext(ctx context.Context) *Server {
+ s, _ := ctx.Value(serverKey{}).(*Server)
+ return s
+}
+
+// contextWithServer sets the Server in the context.
+func contextWithServer(ctx context.Context, server *Server) context.Context {
+ return context.WithValue(ctx, serverKey{}, server)
+}
+
+// isRegisteredMethod returns whether the passed in method is registered as a
+// method on the server. /service/method and service/method will match if the
+// service and method are registered on the server.
+func (s *Server) isRegisteredMethod(serviceMethod string) bool {
+ if serviceMethod != "" && serviceMethod[0] == '/' {
+ serviceMethod = serviceMethod[1:]
+ }
+ pos := strings.LastIndex(serviceMethod, "/")
+ if pos == -1 { // Invalid method name syntax.
+ return false
+ }
+ service := serviceMethod[:pos]
+ method := serviceMethod[pos+1:]
+ srv, knownService := s.services[service]
+ if knownService {
+ if _, ok := srv.methods[method]; ok {
+ return true
+ }
+ if _, ok := srv.streams[method]; ok {
+ return true
+ }
+ }
+ return false
+}
+
// SetHeader sets the header metadata to be sent from the server to the client.
// The context provided must be the context passed to the server's handler.
//