summaryrefslogtreecommitdiff
path: root/vendor/google.golang.org/grpc/stream.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/google.golang.org/grpc/stream.go')
-rw-r--r--vendor/google.golang.org/grpc/stream.go36
1 files changed, 31 insertions, 5 deletions
diff --git a/vendor/google.golang.org/grpc/stream.go b/vendor/google.golang.org/grpc/stream.go
index d58bb6471..d9bbd4c57 100644
--- a/vendor/google.golang.org/grpc/stream.go
+++ b/vendor/google.golang.org/grpc/stream.go
@@ -469,8 +469,9 @@ func (cs *clientStream) newAttemptLocked(isTransparent bool) (*csAttempt, error)
func (a *csAttempt) getTransport() error {
cs := a.cs
- var err error
- a.transport, a.pickResult, err = cs.cc.getTransport(a.ctx, cs.callInfo.failFast, cs.callHdr.Method)
+ pickInfo := balancer.PickInfo{Ctx: a.ctx, FullMethodName: cs.callHdr.Method}
+ pick, err := cs.cc.pickerWrapper.pick(a.ctx, cs.callInfo.failFast, pickInfo)
+ a.transport, a.pickResult = pick.transport, pick.result
if err != nil {
if de, ok := err.(dropError); ok {
err = de.error
@@ -481,6 +482,11 @@ func (a *csAttempt) getTransport() error {
if a.trInfo != nil {
a.trInfo.firstLine.SetRemoteAddr(a.transport.RemoteAddr())
}
+ if pick.blocked {
+ for _, sh := range a.statsHandlers {
+ sh.HandleRPC(a.ctx, &stats.DelayedPickComplete{})
+ }
+ }
return nil
}
@@ -1171,7 +1177,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
} else if err != nil {
return toRPCErr(err)
}
- return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
+ return status.Errorf(codes.Internal, "cardinality violation: expected <EOF> for non server-streaming RPCs, but received another message")
}
func (a *csAttempt) finish(err error) {
@@ -1495,7 +1501,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
} else if err != nil {
return toRPCErr(err)
}
- return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
+ return status.Errorf(codes.Internal, "cardinality violation: expected <EOF> for non server-streaming RPCs, but received another message")
}
func (as *addrConnStream) finish(err error) {
@@ -1580,6 +1586,7 @@ type serverStream struct {
s *transport.ServerStream
p *parser
codec baseCodec
+ desc *StreamDesc
compressorV0 Compressor
compressorV1 encoding.Compressor
@@ -1588,6 +1595,8 @@ type serverStream struct {
sendCompressorName string
+ recvFirstMsg bool // set after the first message is received
+
maxReceiveMessageSize int
maxSendMessageSize int
trInfo *traceInfo
@@ -1774,6 +1783,10 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
binlog.Log(ss.ctx, chc)
}
}
+ // Received no request msg for non-client streaming rpcs.
+ if !ss.desc.ClientStreams && !ss.recvFirstMsg {
+ return status.Error(codes.Internal, "cardinality violation: received no request message from non-client-streaming RPC")
+ }
return err
}
if err == io.ErrUnexpectedEOF {
@@ -1781,6 +1794,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
}
return toRPCErr(err)
}
+ ss.recvFirstMsg = true
if len(ss.statsHandler) != 0 {
for _, sh := range ss.statsHandler {
sh.HandleRPC(ss.s.Context(), &stats.InPayload{
@@ -1800,7 +1814,19 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
binlog.Log(ss.ctx, cm)
}
}
- return nil
+
+ if ss.desc.ClientStreams {
+ // Subsequent messages should be received by subsequent RecvMsg calls.
+ return nil
+ }
+ // Special handling for non-client-stream rpcs.
+ // This recv expects EOF or errors, so we don't collect inPayload.
+ if err := recv(ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true); err == io.EOF {
+ return nil
+ } else if err != nil {
+ return err
+ }
+ return status.Error(codes.Internal, "cardinality violation: received multiple request messages for non-client-streaming RPC")
}
// MethodFromServerStream returns the method string for the input stream.