summaryrefslogtreecommitdiff
path: root/vendor/google.golang.org/grpc/internal/transport/http2_server.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/google.golang.org/grpc/internal/transport/http2_server.go')
-rw-r--r--vendor/google.golang.org/grpc/internal/transport/http2_server.go135
1 files changed, 66 insertions, 69 deletions
diff --git a/vendor/google.golang.org/grpc/internal/transport/http2_server.go b/vendor/google.golang.org/grpc/internal/transport/http2_server.go
index 6fa1eb419..a206e2eef 100644
--- a/vendor/google.golang.org/grpc/internal/transport/http2_server.go
+++ b/vendor/google.golang.org/grpc/internal/transport/http2_server.go
@@ -68,18 +68,15 @@ var serverConnectionCounter uint64
// http2Server implements the ServerTransport interface with HTTP2.
type http2Server struct {
- lastRead int64 // Keep this field 64-bit aligned. Accessed atomically.
- ctx context.Context
- done chan struct{}
- conn net.Conn
- loopy *loopyWriter
- readerDone chan struct{} // sync point to enable testing.
- writerDone chan struct{} // sync point to enable testing.
- remoteAddr net.Addr
- localAddr net.Addr
- authInfo credentials.AuthInfo // auth info about the connection
- inTapHandle tap.ServerInHandle
- framer *framer
+ lastRead int64 // Keep this field 64-bit aligned. Accessed atomically.
+ done chan struct{}
+ conn net.Conn
+ loopy *loopyWriter
+ readerDone chan struct{} // sync point to enable testing.
+ loopyWriterDone chan struct{}
+ peer peer.Peer
+ inTapHandle tap.ServerInHandle
+ framer *framer
// The max number of concurrent streams.
maxStreams uint32
// controlBuf delivers all the control related tasks (e.g., window
@@ -243,16 +240,18 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
}
done := make(chan struct{})
+ peer := peer.Peer{
+ Addr: conn.RemoteAddr(),
+ LocalAddr: conn.LocalAddr(),
+ AuthInfo: authInfo,
+ }
t := &http2Server{
- ctx: setConnection(context.Background(), rawConn),
done: done,
conn: conn,
- remoteAddr: conn.RemoteAddr(),
- localAddr: conn.LocalAddr(),
- authInfo: authInfo,
+ peer: peer,
framer: framer,
readerDone: make(chan struct{}),
- writerDone: make(chan struct{}),
+ loopyWriterDone: make(chan struct{}),
maxStreams: config.MaxStreams,
inTapHandle: config.InTapHandle,
fc: &trInFlow{limit: uint32(icwz)},
@@ -267,8 +266,6 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
bufferPool: newBufferPool(),
}
t.logger = prefixLoggerForServerTransport(t)
- // Add peer information to the http2server context.
- t.ctx = peer.NewContext(t.ctx, t.getPeer())
t.controlBuf = newControlBuffer(t.done)
if dynamicWindow {
@@ -277,15 +274,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
updateFlowControl: t.updateFlowControl,
}
}
- for _, sh := range t.stats {
- t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{
- RemoteAddr: t.remoteAddr,
- LocalAddr: t.localAddr,
- })
- connBegin := &stats.ConnBegin{}
- sh.HandleConn(t.ctx, connBegin)
- }
- t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr))
+ t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.peer.Addr, t.peer.LocalAddr))
if err != nil {
return nil, err
}
@@ -333,8 +322,24 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
go func() {
t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger)
t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler
- t.loopy.run()
- close(t.writerDone)
+ err := t.loopy.run()
+ close(t.loopyWriterDone)
+ if !isIOError(err) {
+ // Close the connection if a non-I/O error occurs (for I/O errors
+ // the reader will also encounter the error and close). Wait 1
+ // second before closing the connection, or when the reader is done
+ // (i.e. the client already closed the connection or a connection
+ // error occurred). This avoids the potential problem where there
+ // is unread data on the receive side of the connection, which, if
+ // closed, would lead to a TCP RST instead of FIN, and the client
+ // encountering errors. For more info:
+ // https://github.com/grpc/grpc-go/issues/5358
+ select {
+ case <-t.readerDone:
+ case <-time.After(time.Second):
+ }
+ t.conn.Close()
+ }
}()
go t.keepalive()
return t, nil
@@ -342,7 +347,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
// operateHeaders takes action on the decoded headers. Returns an error if fatal
// error encountered and transport needs to close, otherwise returns nil.
-func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) error {
+func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeadersFrame, handle func(*Stream)) error {
// Acquire max stream ID lock for entire duration
t.maxStreamMu.Lock()
defer t.maxStreamMu.Unlock()
@@ -369,10 +374,11 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
buf := newRecvBuffer()
s := &Stream{
- id: streamID,
- st: t,
- buf: buf,
- fc: &inFlow{limit: uint32(t.initialWindowSize)},
+ id: streamID,
+ st: t,
+ buf: buf,
+ fc: &inFlow{limit: uint32(t.initialWindowSize)},
+ headerWireLength: int(frame.Header().Length),
}
var (
// if false, content-type was missing or invalid
@@ -511,9 +517,9 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.state = streamReadDone
}
if timeoutSet {
- s.ctx, s.cancel = context.WithTimeout(t.ctx, timeout)
+ s.ctx, s.cancel = context.WithTimeout(ctx, timeout)
} else {
- s.ctx, s.cancel = context.WithCancel(t.ctx)
+ s.ctx, s.cancel = context.WithCancel(ctx)
}
// Attach the received metadata to the context.
@@ -592,18 +598,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n))
}
- for _, sh := range t.stats {
- s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
- inHeader := &stats.InHeader{
- FullMethod: s.method,
- RemoteAddr: t.remoteAddr,
- LocalAddr: t.localAddr,
- Compression: s.recvCompress,
- WireLength: int(frame.Header().Length),
- Header: mdata.Copy(),
- }
- sh.HandleRPC(s.ctx, inHeader)
- }
s.ctxDone = s.ctx.Done()
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
s.trReader = &transportReader{
@@ -629,8 +623,11 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
// HandleStreams receives incoming streams using the given handler. This is
// typically run in a separate goroutine.
// traceCtx attaches trace to ctx and returns the new context.
-func (t *http2Server) HandleStreams(handle func(*Stream)) {
- defer close(t.readerDone)
+func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) {
+ defer func() {
+ close(t.readerDone)
+ <-t.loopyWriterDone
+ }()
for {
t.controlBuf.throttle()
frame, err := t.framer.fr.ReadFrame()
@@ -664,7 +661,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
}
switch frame := frame.(type) {
case *http2.MetaHeadersFrame:
- if err := t.operateHeaders(frame, handle); err != nil {
+ if err := t.operateHeaders(ctx, frame, handle); err != nil {
t.Close(err)
break
}
@@ -979,7 +976,12 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
}
}
if err := t.writeHeaderLocked(s); err != nil {
- return status.Convert(err).Err()
+ switch e := err.(type) {
+ case ConnectionError:
+ return status.Error(codes.Unavailable, e.Desc)
+ default:
+ return status.Convert(err).Err()
+ }
}
return nil
}
@@ -1242,10 +1244,6 @@ func (t *http2Server) Close(err error) {
for _, s := range streams {
s.cancel()
}
- for _, sh := range t.stats {
- connEnd := &stats.ConnEnd{}
- sh.HandleConn(t.ctx, connEnd)
- }
}
// deleteStream deletes the stream s from transport's active streams.
@@ -1311,10 +1309,6 @@ func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eo
})
}
-func (t *http2Server) RemoteAddr() net.Addr {
- return t.remoteAddr
-}
-
func (t *http2Server) Drain(debugData string) {
t.mu.Lock()
defer t.mu.Unlock()
@@ -1351,6 +1345,7 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) {
if err := t.framer.fr.WriteGoAway(sid, g.code, g.debugData); err != nil {
return false, err
}
+ t.framer.writer.Flush()
if retErr != nil {
return false, retErr
}
@@ -1371,7 +1366,7 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) {
return false, err
}
go func() {
- timer := time.NewTimer(time.Minute)
+ timer := time.NewTimer(5 * time.Second)
defer timer.Stop()
select {
case <-t.drainEvent.Done():
@@ -1397,11 +1392,11 @@ func (t *http2Server) ChannelzMetric() *channelz.SocketInternalMetric {
LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)),
LocalFlowControlWindow: int64(t.fc.getSize()),
SocketOptions: channelz.GetSocketOption(t.conn),
- LocalAddr: t.localAddr,
- RemoteAddr: t.remoteAddr,
+ LocalAddr: t.peer.LocalAddr,
+ RemoteAddr: t.peer.Addr,
// RemoteName :
}
- if au, ok := t.authInfo.(credentials.ChannelzSecurityInfo); ok {
+ if au, ok := t.peer.AuthInfo.(credentials.ChannelzSecurityInfo); ok {
s.Security = au.GetSecurityValue()
}
s.RemoteFlowControlWindow = t.getOutFlowWindow()
@@ -1433,10 +1428,12 @@ func (t *http2Server) getOutFlowWindow() int64 {
}
}
-func (t *http2Server) getPeer() *peer.Peer {
+// Peer returns the peer of the transport.
+func (t *http2Server) Peer() *peer.Peer {
return &peer.Peer{
- Addr: t.remoteAddr,
- AuthInfo: t.authInfo, // Can be nil
+ Addr: t.peer.Addr,
+ LocalAddr: t.peer.LocalAddr,
+ AuthInfo: t.peer.AuthInfo, // Can be nil
}
}
@@ -1461,6 +1458,6 @@ func GetConnection(ctx context.Context) net.Conn {
// SetConnection adds the connection to the context to be able to get
// information about the destination ip and port for an incoming RPC. This also
// allows any unary or streaming interceptors to see the connection.
-func setConnection(ctx context.Context, conn net.Conn) context.Context {
+func SetConnection(ctx context.Context, conn net.Conn) context.Context {
return context.WithValue(ctx, connectionKey{}, conn)
}