diff options
Diffstat (limited to 'vendor/google.golang.org/grpc/server.go')
-rw-r--r-- | vendor/google.golang.org/grpc/server.go | 264 |
1 files changed, 182 insertions, 82 deletions
diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go index 8f60d4214..e89c5ac61 100644 --- a/vendor/google.golang.org/grpc/server.go +++ b/vendor/google.golang.org/grpc/server.go @@ -70,9 +70,10 @@ func init() { internal.GetServerCredentials = func(srv *Server) credentials.TransportCredentials { return srv.opts.creds } - internal.DrainServerTransports = func(srv *Server, addr string) { - srv.drainServerTransports(addr) + internal.IsRegisteredMethod = func(srv *Server, method string) bool { + return srv.isRegisteredMethod(method) } + internal.ServerFromContext = serverFromContext internal.AddGlobalServerOptions = func(opt ...ServerOption) { globalServerOptions = append(globalServerOptions, opt...) } @@ -81,6 +82,7 @@ func init() { } internal.BinaryLogger = binaryLogger internal.JoinServerOptions = newJoinServerOption + internal.RecvBufferPool = recvBufferPool } var statusOK = status.New(codes.OK, "") @@ -134,12 +136,14 @@ type Server struct { quit *grpcsync.Event done *grpcsync.Event channelzRemoveOnce sync.Once - serveWG sync.WaitGroup // counts active Serve goroutines for GracefulStop + serveWG sync.WaitGroup // counts active Serve goroutines for Stop/GracefulStop + handlersWG sync.WaitGroup // counts active method handler goroutines channelzID *channelz.Identifier czData *channelzData - serverWorkerChannel chan func() + serverWorkerChannel chan func() + serverWorkerChannelClose func() } type serverOptions struct { @@ -170,6 +174,7 @@ type serverOptions struct { headerTableSize *uint32 numServerWorkers uint32 recvBufferPool SharedBufferPool + waitForHandlers bool } var defaultServerOptions = serverOptions{ @@ -567,6 +572,21 @@ func NumStreamWorkers(numServerWorkers uint32) ServerOption { }) } +// WaitForHandlers cause Stop to wait until all outstanding method handlers have +// exited before returning. If false, Stop will return as soon as all +// connections have closed, but method handlers may still be running. By +// default, Stop does not wait for method handlers to return. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func WaitForHandlers(w bool) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.waitForHandlers = w + }) +} + // RecvBufferPool returns a ServerOption that configures the server // to use the provided shared buffer pool for parsing incoming messages. Depending // on the application's workload, this could result in reduced memory allocation. @@ -578,11 +598,13 @@ func NumStreamWorkers(numServerWorkers uint32) ServerOption { // options are used: StatsHandler, EnableTracing, or binary logging. In such // cases, the shared buffer pool will be ignored. // -// # Experimental -// -// Notice: This API is EXPERIMENTAL and may be changed or removed in a -// later release. +// Deprecated: use experimental.WithRecvBufferPool instead. Will be deleted in +// v1.60.0 or later. func RecvBufferPool(bufferPool SharedBufferPool) ServerOption { + return recvBufferPool(bufferPool) +} + +func recvBufferPool(bufferPool SharedBufferPool) ServerOption { return newFuncServerOption(func(o *serverOptions) { o.recvBufferPool = bufferPool }) @@ -616,15 +638,14 @@ func (s *Server) serverWorker() { // connections to reduce the time spent overall on runtime.morestack. func (s *Server) initServerWorkers() { s.serverWorkerChannel = make(chan func()) + s.serverWorkerChannelClose = grpcsync.OnceFunc(func() { + close(s.serverWorkerChannel) + }) for i := uint32(0); i < s.opts.numServerWorkers; i++ { go s.serverWorker() } } -func (s *Server) stopServerWorkers() { - close(s.serverWorkerChannel) -} - // NewServer creates a gRPC server which has no service registered and has not // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { @@ -806,6 +827,18 @@ func (l *listenSocket) Close() error { // Serve returns when lis.Accept fails with fatal errors. lis will be closed when // this method returns. // Serve will return a non-nil error unless Stop or GracefulStop is called. +// +// Note: All supported releases of Go (as of December 2023) override the OS +// defaults for TCP keepalive time and interval to 15s. To enable TCP keepalive +// with OS defaults for keepalive time and interval, callers need to do the +// following two things: +// - pass a net.Listener created by calling the Listen method on a +// net.ListenConfig with the `KeepAlive` field set to a negative value. This +// will result in the Go standard library not overriding OS defaults for TCP +// keepalive interval and time. But this will also result in the Go standard +// library not enabling TCP keepalives by default. +// - override the Accept method on the passed in net.Listener and set the +// SO_KEEPALIVE socket option to enable TCP keepalives, with OS defaults. func (s *Server) Serve(lis net.Listener) error { s.mu.Lock() s.printf("serving") @@ -913,24 +946,21 @@ func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) { return } + if cc, ok := rawConn.(interface { + PassServerTransport(transport.ServerTransport) + }); ok { + cc.PassServerTransport(st) + } + if !s.addConn(lisAddr, st) { return } go func() { - s.serveStreams(st) + s.serveStreams(context.Background(), st, rawConn) s.removeConn(lisAddr, st) }() } -func (s *Server) drainServerTransports(addr string) { - s.mu.Lock() - conns := s.conns[addr] - for st := range conns { - st.Drain("") - } - s.mu.Unlock() -} - // newHTTP2Transport sets up a http/2 transport (using the // gRPC http2 server transport in transport/http2_server.go). func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport { @@ -971,18 +1001,31 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport { return st } -func (s *Server) serveStreams(st transport.ServerTransport) { - defer st.Close(errors.New("finished serving streams for the server transport")) - var wg sync.WaitGroup +func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, rawConn net.Conn) { + ctx = transport.SetConnection(ctx, rawConn) + ctx = peer.NewContext(ctx, st.Peer()) + for _, sh := range s.opts.statsHandlers { + ctx = sh.TagConn(ctx, &stats.ConnTagInfo{ + RemoteAddr: st.Peer().Addr, + LocalAddr: st.Peer().LocalAddr, + }) + sh.HandleConn(ctx, &stats.ConnBegin{}) + } - streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams) - st.HandleStreams(func(stream *transport.Stream) { - wg.Add(1) + defer func() { + st.Close(errors.New("finished serving streams for the server transport")) + for _, sh := range s.opts.statsHandlers { + sh.HandleConn(ctx, &stats.ConnEnd{}) + } + }() + streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams) + st.HandleStreams(ctx, func(stream *transport.Stream) { + s.handlersWG.Add(1) streamQuota.acquire() f := func() { defer streamQuota.release() - defer wg.Done() + defer s.handlersWG.Done() s.handleStream(st, stream) } @@ -996,7 +1039,6 @@ func (s *Server) serveStreams(st transport.ServerTransport) { } go f() }) - wg.Wait() } var _ http.Handler = (*Server)(nil) @@ -1040,7 +1082,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } defer s.removeConn(listenerAddressForServeHTTP, st) - s.serveStreams(st) + s.serveStreams(r.Context(), st, nil) } func (s *Server) addConn(addr string, st transport.ServerTransport) bool { @@ -1689,6 +1731,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTran func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) { ctx := stream.Context() + ctx = contextWithServer(ctx, s) var ti *traceInfo if EnableTracing { tr := trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()) @@ -1697,7 +1740,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str tr: tr, firstLine: firstLine{ client: false, - remoteAddr: t.RemoteAddr(), + remoteAddr: t.Peer().Addr, }, } if dl, ok := ctx.Deadline(); ok { @@ -1731,6 +1774,22 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str service := sm[:pos] method := sm[pos+1:] + md, _ := metadata.FromIncomingContext(ctx) + for _, sh := range s.opts.statsHandlers { + ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()}) + sh.HandleRPC(ctx, &stats.InHeader{ + FullMethod: stream.Method(), + RemoteAddr: t.Peer().Addr, + LocalAddr: t.Peer().LocalAddr, + Compression: stream.RecvCompress(), + WireLength: stream.HeaderWireLength(), + Header: md, + }) + } + // To have calls in stream callouts work. Will delete once all stats handler + // calls come from the gRPC layer. + stream.SetContext(ctx) + srv, knownService := s.services[service] if knownService { if md, ok := srv.methods[method]; ok { @@ -1820,62 +1879,72 @@ func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream // pending RPCs on the client side will get notified by connection // errors. func (s *Server) Stop() { - s.quit.Fire() + s.stop(false) +} - defer func() { - s.serveWG.Wait() - s.done.Fire() - }() +// GracefulStop stops the gRPC server gracefully. It stops the server from +// accepting new connections and RPCs and blocks until all the pending RPCs are +// finished. +func (s *Server) GracefulStop() { + s.stop(true) +} + +func (s *Server) stop(graceful bool) { + s.quit.Fire() + defer s.done.Fire() s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelzID) }) s.mu.Lock() - listeners := s.lis - s.lis = nil - conns := s.conns - s.conns = nil - // interrupt GracefulStop if Stop and GracefulStop are called concurrently. - s.cv.Broadcast() + s.closeListenersLocked() + // Wait for serving threads to be ready to exit. Only then can we be sure no + // new conns will be created. s.mu.Unlock() + s.serveWG.Wait() - for lis := range listeners { - lis.Close() + s.mu.Lock() + defer s.mu.Unlock() + + if graceful { + s.drainAllServerTransportsLocked() + } else { + s.closeServerTransportsLocked() } - for _, cs := range conns { - for st := range cs { - st.Close(errors.New("Server.Stop called")) - } + + for len(s.conns) != 0 { + s.cv.Wait() } + s.conns = nil + if s.opts.numServerWorkers > 0 { - s.stopServerWorkers() + // Closing the channel (only once, via grpcsync.OnceFunc) after all the + // connections have been closed above ensures that there are no + // goroutines executing the callback passed to st.HandleStreams (where + // the channel is written to). + s.serverWorkerChannelClose() + } + + if graceful || s.opts.waitForHandlers { + s.handlersWG.Wait() } - s.mu.Lock() if s.events != nil { s.events.Finish() s.events = nil } - s.mu.Unlock() } -// GracefulStop stops the gRPC server gracefully. It stops the server from -// accepting new connections and RPCs and blocks until all the pending RPCs are -// finished. -func (s *Server) GracefulStop() { - s.quit.Fire() - defer s.done.Fire() - - s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelzID) }) - s.mu.Lock() - if s.conns == nil { - s.mu.Unlock() - return +// s.mu must be held by the caller. +func (s *Server) closeServerTransportsLocked() { + for _, conns := range s.conns { + for st := range conns { + st.Close(errors.New("Server.Stop called")) + } } +} - for lis := range s.lis { - lis.Close() - } - s.lis = nil +// s.mu must be held by the caller. +func (s *Server) drainAllServerTransportsLocked() { if !s.drain { for _, conns := range s.conns { for st := range conns { @@ -1884,22 +1953,14 @@ func (s *Server) GracefulStop() { } s.drain = true } +} - // Wait for serving threads to be ready to exit. Only then can we be sure no - // new conns will be created. - s.mu.Unlock() - s.serveWG.Wait() - s.mu.Lock() - - for len(s.conns) != 0 { - s.cv.Wait() - } - s.conns = nil - if s.events != nil { - s.events.Finish() - s.events = nil +// s.mu must be held by the caller. +func (s *Server) closeListenersLocked() { + for lis := range s.lis { + lis.Close() } - s.mu.Unlock() + s.lis = nil } // contentSubtype must be lowercase @@ -1913,11 +1974,50 @@ func (s *Server) getCodec(contentSubtype string) baseCodec { } codec := encoding.GetCodec(contentSubtype) if codec == nil { + logger.Warningf("Unsupported codec %q. Defaulting to %q for now. This will start to fail in future releases.", contentSubtype, proto.Name) return encoding.GetCodec(proto.Name) } return codec } +type serverKey struct{} + +// serverFromContext gets the Server from the context. +func serverFromContext(ctx context.Context) *Server { + s, _ := ctx.Value(serverKey{}).(*Server) + return s +} + +// contextWithServer sets the Server in the context. +func contextWithServer(ctx context.Context, server *Server) context.Context { + return context.WithValue(ctx, serverKey{}, server) +} + +// isRegisteredMethod returns whether the passed in method is registered as a +// method on the server. /service/method and service/method will match if the +// service and method are registered on the server. +func (s *Server) isRegisteredMethod(serviceMethod string) bool { + if serviceMethod != "" && serviceMethod[0] == '/' { + serviceMethod = serviceMethod[1:] + } + pos := strings.LastIndex(serviceMethod, "/") + if pos == -1 { // Invalid method name syntax. + return false + } + service := serviceMethod[:pos] + method := serviceMethod[pos+1:] + srv, knownService := s.services[service] + if knownService { + if _, ok := srv.methods[method]; ok { + return true + } + if _, ok := srv.streams[method]; ok { + return true + } + } + return false +} + // SetHeader sets the header metadata to be sent from the server to the client. // The context provided must be the context passed to the server's handler. // |